-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Support of Mistral AI Moderation API
Signed-off-by: Ricken Bazolo <[email protected]>
- Loading branch information
Showing
7 changed files
with
526 additions
and
4 deletions.
There are no files selected for viewing
137 changes: 137 additions & 0 deletions
137
...mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package org.springframework.ai.mistralai.api; | ||
|
||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import org.springframework.ai.retry.RetryUtils; | ||
import org.springframework.http.HttpHeaders; | ||
import org.springframework.http.MediaType; | ||
import org.springframework.http.ResponseEntity; | ||
import org.springframework.util.Assert; | ||
import org.springframework.web.client.ResponseErrorHandler; | ||
import org.springframework.web.client.RestClient; | ||
|
||
import java.util.function.Consumer; | ||
|
||
/** | ||
* MistralAI Moderation API. | ||
* | ||
* @author Ricken Bazolo | ||
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a> | ||
*/ | ||
public class MistralAiModerationApi { | ||
|
||
private static final String DEFAULT_BASE_URL = "https://api.mistral.ai"; | ||
|
||
private final RestClient restClient; | ||
|
||
public MistralAiModerationApi(String mistralAiApiKey) { | ||
this(DEFAULT_BASE_URL, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); | ||
} | ||
|
||
public MistralAiModerationApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder, | ||
ResponseErrorHandler responseErrorHandler) { | ||
|
||
Consumer<HttpHeaders> jsonContentHeaders = headers -> { | ||
headers.setBearerAuth(mistralAiApiKey); | ||
headers.setContentType(MediaType.APPLICATION_JSON); | ||
}; | ||
|
||
this.restClient = restClientBuilder.baseUrl(baseUrl) | ||
.defaultHeaders(jsonContentHeaders) | ||
.defaultStatusHandler(responseErrorHandler) | ||
.build(); | ||
} | ||
|
||
public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationRequest mistralAiModerationRequest) { | ||
Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null."); | ||
Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty."); | ||
Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null."); | ||
|
||
return this.restClient.post() | ||
.uri("v1/moderations") | ||
.body(mistralAiModerationRequest) | ||
.retrieve() | ||
.toEntity(MistralAiModerationResponse.class); | ||
} | ||
|
||
public enum Model { | ||
|
||
// @formatter:off | ||
MISTRAL_MODERATION("mistral-moderation-latest"); | ||
// @formatter:on | ||
|
||
private final String value; | ||
|
||
Model(String value) { | ||
this.value = value; | ||
} | ||
|
||
public String getValue() { | ||
return this.value; | ||
} | ||
|
||
} | ||
|
||
// @formatter:off | ||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public record MistralAiModerationRequest( | ||
@JsonProperty("input") String prompt, | ||
@JsonProperty("model") String model | ||
) { | ||
|
||
public MistralAiModerationRequest(String prompt) { | ||
this(prompt, null); | ||
} | ||
} | ||
|
||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public record MistralAiModerationResponse( | ||
@JsonProperty("id") String id, | ||
@JsonProperty("model") String model, | ||
@JsonProperty("results") MistralAiModerationResult[] results) { | ||
|
||
} | ||
|
||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public record MistralAiModerationResult( | ||
@JsonProperty("categories") Categories categories, | ||
@JsonProperty("category_scores") CategoryScores categoryScores) { | ||
|
||
public boolean flagged() { | ||
return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats() | ||
|| categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health() | ||
|| categories.financial() || categories.law() || categories.pii()); | ||
} | ||
|
||
} | ||
|
||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public record Categories( | ||
@JsonProperty("sexual") boolean sexual, | ||
@JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination, | ||
@JsonProperty("violence_and_threats") boolean violenceAndThreats, | ||
@JsonProperty("selfharm") boolean selfHarm, | ||
@JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent, | ||
@JsonProperty("health") boolean health, | ||
@JsonProperty("financial") boolean financial, | ||
@JsonProperty("law") boolean law, | ||
@JsonProperty("pii") boolean pii) { | ||
|
||
} | ||
|
||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public record CategoryScores( | ||
@JsonProperty("sexual") double sexual, | ||
@JsonProperty("hate_and_discrimination") double hateAndDiscrimination, | ||
@JsonProperty("violence_and_threats") double violenceAndThreats, | ||
@JsonProperty("selfharm") double selfHarm, | ||
@JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent, | ||
@JsonProperty("health") double health, | ||
@JsonProperty("financial") double financial, | ||
@JsonProperty("law") double law, | ||
@JsonProperty("pii") double pii) { | ||
|
||
} | ||
// @formatter:onn | ||
|
||
} |
146 changes: 146 additions & 0 deletions
146
...i/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
package org.springframework.ai.mistralai.moderation; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.ai.mistralai.api.MistralAiModerationApi; | ||
import org.springframework.ai.model.ModelOptionsUtils; | ||
import org.springframework.ai.moderation.*; | ||
import org.springframework.ai.retry.RetryUtils; | ||
import org.springframework.http.ResponseEntity; | ||
import org.springframework.retry.support.RetryTemplate; | ||
import org.springframework.util.Assert; | ||
|
||
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest; | ||
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse; | ||
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
/** | ||
* @author Ricken Bazolo | ||
*/ | ||
public class MistralAiModerationModel implements ModerationModel { | ||
|
||
private final Logger logger = LoggerFactory.getLogger(getClass()); | ||
|
||
private final MistralAiModerationApi mistralAiModerationApi; | ||
|
||
private final RetryTemplate retryTemplate; | ||
|
||
private final MistralAiModerationOptions defaultOptions; | ||
|
||
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) { | ||
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, | ||
MistralAiModerationOptions.builder() | ||
.model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue()) | ||
.build()); | ||
} | ||
|
||
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) { | ||
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options); | ||
} | ||
|
||
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate, | ||
MistralAiModerationOptions options) { | ||
Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null"); | ||
Assert.notNull(retryTemplate, "retryTemplate must not be null"); | ||
Assert.notNull(options, "options must not be null"); | ||
this.mistralAiModerationApi = mistralAiModerationApi; | ||
this.retryTemplate = retryTemplate; | ||
this.defaultOptions = options; | ||
} | ||
|
||
@Override | ||
public ModerationResponse call(ModerationPrompt moderationPrompt) { | ||
return this.retryTemplate.execute(ctx -> { | ||
|
||
var instructions = moderationPrompt.getInstructions().getText(); | ||
|
||
var moderationRequest = new MistralAiModerationRequest(instructions); | ||
|
||
if (this.defaultOptions != null) { | ||
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, | ||
MistralAiModerationRequest.class); | ||
} | ||
else { | ||
// moderationPrompt.getOptions() never null but model can be empty, cause | ||
// by ModerationPrompt constructor | ||
moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()), | ||
moderationRequest, MistralAiModerationRequest.class); | ||
} | ||
|
||
var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); | ||
|
||
return convertResponse(moderationResponseEntity, moderationRequest); | ||
}); | ||
} | ||
|
||
private ModerationResponse convertResponse(ResponseEntity<MistralAiModerationResponse> moderationResponseEntity, | ||
MistralAiModerationRequest openAiModerationRequest) { | ||
var moderationApiResponse = moderationResponseEntity.getBody(); | ||
if (moderationApiResponse == null) { | ||
logger.warn("No moderation response returned for request: {}", openAiModerationRequest); | ||
return new ModerationResponse(new Generation()); | ||
} | ||
|
||
List<ModerationResult> moderationResults = new ArrayList<>(); | ||
if (moderationApiResponse.results() != null) { | ||
|
||
for (MistralAiModerationResult result : moderationApiResponse.results()) { | ||
Categories categories = null; | ||
CategoryScores categoryScores = null; | ||
if (result.categories() != null) { | ||
categories = Categories.builder() | ||
.sexual(result.categories().sexual()) | ||
.pii(result.categories().pii()) | ||
.law(result.categories().law()) | ||
.financial(result.categories().financial()) | ||
.health(result.categories().health()) | ||
.dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent()) | ||
.violence(result.categories().violenceAndThreats()) | ||
.hate(result.categories().hateAndDiscrimination()) | ||
.selfHarm(result.categories().selfHarm()) | ||
.build(); | ||
} | ||
if (result.categoryScores() != null) { | ||
categoryScores = CategoryScores.builder() | ||
.sexual(result.categoryScores().sexual()) | ||
.pii(result.categoryScores().pii()) | ||
.law(result.categoryScores().law()) | ||
.financial(result.categoryScores().financial()) | ||
.health(result.categoryScores().health()) | ||
.dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent()) | ||
.violence(result.categoryScores().violenceAndThreats()) | ||
.hate(result.categoryScores().hateAndDiscrimination()) | ||
.selfHarm(result.categoryScores().selfHarm()) | ||
.build(); | ||
} | ||
var moderationResult = ModerationResult.builder() | ||
.categories(categories) | ||
.categoryScores(categoryScores) | ||
.flagged(result.flagged()) | ||
.build(); | ||
moderationResults.add(moderationResult); | ||
} | ||
|
||
} | ||
|
||
var moderation = Moderation.builder() | ||
.id(moderationApiResponse.id()) | ||
.model(moderationApiResponse.model()) | ||
.results(moderationResults) | ||
.build(); | ||
|
||
return new ModerationResponse(new Generation(moderation)); | ||
} | ||
|
||
private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) { | ||
var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder(); | ||
if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) { | ||
mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel()); | ||
} | ||
return mistralAiModerationOptionsBuilder.build(); | ||
} | ||
|
||
} |
54 changes: 54 additions & 0 deletions
54
...src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationOptions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package org.springframework.ai.mistralai.moderation; | ||
|
||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import org.springframework.ai.mistralai.api.MistralAiModerationApi; | ||
import org.springframework.ai.moderation.ModerationOptions; | ||
|
||
/** | ||
* @author Ricken Bazolo | ||
*/ | ||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public class MistralAiModerationOptions implements ModerationOptions { | ||
|
||
private static final String DEFAULT_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue(); | ||
|
||
/** | ||
* The model to use for moderation generation. | ||
*/ | ||
@JsonProperty("model") | ||
private String model = DEFAULT_MODEL; | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
@Override | ||
public String getModel() { | ||
return this.model; | ||
} | ||
|
||
public void setModel(String model) { | ||
this.model = model; | ||
} | ||
|
||
public static final class Builder { | ||
|
||
private final MistralAiModerationOptions options; | ||
|
||
private Builder() { | ||
this.options = new MistralAiModerationOptions(); | ||
} | ||
|
||
public Builder model(String model) { | ||
this.options.setModel(model); | ||
return this; | ||
} | ||
|
||
public MistralAiModerationOptions build() { | ||
return this.options; | ||
} | ||
|
||
} | ||
|
||
} |
Oops, something went wrong.