From f31ace1388c546363c44e3e21f869533ab391efe Mon Sep 17 00:00:00 2001 From: Ricken Bazolo Date: Sun, 9 Feb 2025 01:49:04 +0100 Subject: [PATCH 1/2] Add Support of Mistral AI Moderation API Signed-off-by: Ricken Bazolo --- .../mistralai/api/MistralAiModerationApi.java | 137 ++++++++++++++++ .../moderation/MistralAiModerationModel.java | 146 ++++++++++++++++++ .../MistralAiModerationOptions.java | 54 +++++++ .../ai/moderation/Categories.java | 82 +++++++++- .../ai/moderation/CategoryScores.java | 51 ++++++ .../mistralai/MistralAiAutoConfiguration.java | 25 ++- .../MistralAiModerationProperties.java | 35 +++++ 7 files changed, 526 insertions(+), 4 deletions(-) create mode 100644 models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java create mode 100644 models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java create mode 100644 models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationOptions.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiModerationProperties.java diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java new file mode 100644 index 0000000000..02fcef6f56 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiModerationApi.java @@ -0,0 +1,137 @@ +package org.springframework.ai.mistralai.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import java.util.function.Consumer; + +/** + * MistralAI Moderation API. + * + * @author Ricken Bazolo + * @see jsonContentHeaders = headers -> { + headers.setBearerAuth(mistralAiApiKey); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(jsonContentHeaders) + .defaultStatusHandler(responseErrorHandler) + .build(); + } + + public ResponseEntity moderate(MistralAiModerationRequest mistralAiModerationRequest) { + Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null."); + Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty."); + Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null."); + + return this.restClient.post() + .uri("v1/moderations") + .body(mistralAiModerationRequest) + .retrieve() + .toEntity(MistralAiModerationResponse.class); + } + + public enum Model { + + // @formatter:off + MISTRAL_MODERATION("mistral-moderation-latest"); + // @formatter:on + + private final String value; + + Model(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + // @formatter:off + @JsonInclude(JsonInclude.Include.NON_NULL) + public record MistralAiModerationRequest( + @JsonProperty("input") String prompt, + @JsonProperty("model") String model + ) { + + public MistralAiModerationRequest(String prompt) { + this(prompt, null); + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record MistralAiModerationResponse( + @JsonProperty("id") String id, + @JsonProperty("model") String model, + @JsonProperty("results") MistralAiModerationResult[] results) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record MistralAiModerationResult( + @JsonProperty("categories") Categories categories, + @JsonProperty("category_scores") CategoryScores categoryScores) { + + public boolean flagged() { + return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats() + || categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health() + || categories.financial() || categories.law() || categories.pii()); + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Categories( + @JsonProperty("sexual") boolean sexual, + @JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination, + @JsonProperty("violence_and_threats") boolean violenceAndThreats, + @JsonProperty("selfharm") boolean selfHarm, + @JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent, + @JsonProperty("health") boolean health, + @JsonProperty("financial") boolean financial, + @JsonProperty("law") boolean law, + @JsonProperty("pii") boolean pii) { + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record CategoryScores( + @JsonProperty("sexual") double sexual, + @JsonProperty("hate_and_discrimination") double hateAndDiscrimination, + @JsonProperty("violence_and_threats") double violenceAndThreats, + @JsonProperty("selfharm") double selfHarm, + @JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent, + @JsonProperty("health") double health, + @JsonProperty("financial") double financial, + @JsonProperty("law") double law, + @JsonProperty("pii") double pii) { + + } + // @formatter:onn + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java new file mode 100644 index 0000000000..1df7b0c259 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java @@ -0,0 +1,146 @@ +package org.springframework.ai.mistralai.moderation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.moderation.*; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest; +import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse; +import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author Ricken Bazolo + */ +public class MistralAiModerationModel implements ModerationModel { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final MistralAiModerationApi mistralAiModerationApi; + + private final RetryTemplate retryTemplate; + + private final MistralAiModerationOptions defaultOptions; + + public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) { + this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, + MistralAiModerationOptions.builder() + .model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue()) + .build()); + } + + public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) { + this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options); + } + + public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate, + MistralAiModerationOptions options) { + Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(options, "options must not be null"); + this.mistralAiModerationApi = mistralAiModerationApi; + this.retryTemplate = retryTemplate; + this.defaultOptions = options; + } + + @Override + public ModerationResponse call(ModerationPrompt moderationPrompt) { + return this.retryTemplate.execute(ctx -> { + + var instructions = moderationPrompt.getInstructions().getText(); + + var moderationRequest = new MistralAiModerationRequest(instructions); + + if (this.defaultOptions != null) { + moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, + MistralAiModerationRequest.class); + } + else { + // moderationPrompt.getOptions() never null but model can be empty, cause + // by ModerationPrompt constructor + moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()), + moderationRequest, MistralAiModerationRequest.class); + } + + var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); + + return convertResponse(moderationResponseEntity, moderationRequest); + }); + } + + private ModerationResponse convertResponse(ResponseEntity moderationResponseEntity, + MistralAiModerationRequest openAiModerationRequest) { + var moderationApiResponse = moderationResponseEntity.getBody(); + if (moderationApiResponse == null) { + logger.warn("No moderation response returned for request: {}", openAiModerationRequest); + return new ModerationResponse(new Generation()); + } + + List moderationResults = new ArrayList<>(); + if (moderationApiResponse.results() != null) { + + for (MistralAiModerationResult result : moderationApiResponse.results()) { + Categories categories = null; + CategoryScores categoryScores = null; + if (result.categories() != null) { + categories = Categories.builder() + .sexual(result.categories().sexual()) + .pii(result.categories().pii()) + .law(result.categories().law()) + .financial(result.categories().financial()) + .health(result.categories().health()) + .dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent()) + .violence(result.categories().violenceAndThreats()) + .hate(result.categories().hateAndDiscrimination()) + .selfHarm(result.categories().selfHarm()) + .build(); + } + if (result.categoryScores() != null) { + categoryScores = CategoryScores.builder() + .sexual(result.categoryScores().sexual()) + .pii(result.categoryScores().pii()) + .law(result.categoryScores().law()) + .financial(result.categoryScores().financial()) + .health(result.categoryScores().health()) + .dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent()) + .violence(result.categoryScores().violenceAndThreats()) + .hate(result.categoryScores().hateAndDiscrimination()) + .selfHarm(result.categoryScores().selfHarm()) + .build(); + } + var moderationResult = ModerationResult.builder() + .categories(categories) + .categoryScores(categoryScores) + .flagged(result.flagged()) + .build(); + moderationResults.add(moderationResult); + } + + } + + var moderation = Moderation.builder() + .id(moderationApiResponse.id()) + .model(moderationApiResponse.model()) + .results(moderationResults) + .build(); + + return new ModerationResponse(new Generation(moderation)); + } + + private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) { + var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder(); + if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) { + mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel()); + } + return mistralAiModerationOptionsBuilder.build(); + } + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationOptions.java new file mode 100644 index 0000000000..435651409a --- /dev/null +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationOptions.java @@ -0,0 +1,54 @@ +package org.springframework.ai.mistralai.moderation; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.ai.moderation.ModerationOptions; + +/** + * @author Ricken Bazolo + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class MistralAiModerationOptions implements ModerationOptions { + + private static final String DEFAULT_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue(); + + /** + * The model to use for moderation generation. + */ + @JsonProperty("model") + private String model = DEFAULT_MODEL; + + public static Builder builder() { + return new Builder(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public static final class Builder { + + private final MistralAiModerationOptions options; + + private Builder() { + this.options = new MistralAiModerationOptions(); + } + + public Builder model(String model) { + this.options.setModel(model); + return this; + } + + public MistralAiModerationOptions build() { + return this.options; + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java index be1c1191bc..23fa93cca6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/Categories.java @@ -25,6 +25,7 @@ * * @author Ahmed Yousri * @author Ilayaperumal Gopinathan + * @author Ricken Bazolo * @since 1.0.0 */ public final class Categories { @@ -51,6 +52,16 @@ public final class Categories { private final boolean violence; + private final boolean dangerousAndCriminalContent; + + private final boolean health; + + private final boolean financial; + + private final boolean law; + + private final boolean pii; + private Categories(Builder builder) { this.sexual = builder.sexual; this.hate = builder.hate; @@ -63,6 +74,11 @@ private Categories(Builder builder) { this.selfHarmInstructions = builder.selfHarmInstructions; this.harassmentThreatening = builder.harassmentThreatening; this.violence = builder.violence; + this.dangerousAndCriminalContent = builder.dangerousAndCriminalContent; + this.health = builder.health; + this.financial = builder.financial; + this.law = builder.law; + this.pii = builder.pii; } public static Builder builder() { @@ -113,6 +129,26 @@ public boolean isViolence() { return this.violence; } + public boolean isDangerousAndCriminalContent() { + return this.dangerousAndCriminalContent; + } + + public boolean isHealth() { + return this.health; + } + + public boolean isFinancial() { + return this.financial; + } + + public boolean isLaw() { + return this.law; + } + + public boolean isPii() { + return this.pii; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -126,14 +162,17 @@ public boolean equals(Object o) { && this.selfHarm == that.selfHarm && this.sexualMinors == that.sexualMinors && this.hateThreatening == that.hateThreatening && this.violenceGraphic == that.violenceGraphic && this.selfHarmIntent == that.selfHarmIntent && this.selfHarmInstructions == that.selfHarmInstructions - && this.harassmentThreatening == that.harassmentThreatening && this.violence == that.violence; + && this.harassmentThreatening == that.harassmentThreatening && this.violence == that.violence + && this.dangerousAndCriminalContent == that.dangerousAndCriminalContent && this.health == that.health + && this.financial == that.financial && this.law == that.law && this.pii == that.pii; } @Override public int hashCode() { return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, - this.harassmentThreatening, this.violence); + this.harassmentThreatening, this.violence, this.dangerousAndCriminalContent, this.health, + this.financial, this.law, this.pii); } @Override @@ -142,7 +181,9 @@ public String toString() { + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions - + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + + ", dangerousAndCriminalContent=" + this.dangerousAndCriminalContent + ", health=" + this.health + + ", financial=" + this.financial + ", law=" + this.law + ", pii=" + this.pii + '}'; } public static class Builder { @@ -169,6 +210,16 @@ public static class Builder { private boolean violence; + private boolean dangerousAndCriminalContent; + + private boolean health; + + private boolean financial; + + private boolean law; + + private boolean pii; + public Builder sexual(boolean sexual) { this.sexual = sexual; return this; @@ -224,6 +275,31 @@ public Builder violence(boolean violence) { return this; } + public Builder dangerousAndCriminalContent(boolean dangerousAndCriminalContent) { + this.dangerousAndCriminalContent = dangerousAndCriminalContent; + return this; + } + + public Builder health(boolean health) { + this.health = health; + return this; + } + + public Builder financial(boolean financial) { + this.financial = financial; + return this; + } + + public Builder law(boolean law) { + this.law = law; + return this; + } + + public Builder pii(boolean pii) { + this.pii = pii; + return this; + } + public Categories build() { return new Categories(this); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java index f4ee845d53..4a9c8d1a2d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/moderation/CategoryScores.java @@ -25,6 +25,7 @@ * * @author Ahmed Yousri * @author Ilayaperumal Gopinathan + * @author Ricken Bazolo * @since 1.0.0 */ public final class CategoryScores { @@ -51,6 +52,16 @@ public final class CategoryScores { private final double violence; + private final double dangerousAndCriminalContent; + + private final double health; + + private final double financial; + + private final double law; + + private final double pii; + private CategoryScores(Builder builder) { this.sexual = builder.sexual; this.hate = builder.hate; @@ -63,6 +74,11 @@ private CategoryScores(Builder builder) { this.selfHarmInstructions = builder.selfHarmInstructions; this.harassmentThreatening = builder.harassmentThreatening; this.violence = builder.violence; + this.dangerousAndCriminalContent = builder.dangerousAndCriminalContent; + this.health = builder.health; + this.financial = builder.financial; + this.law = builder.law; + this.pii = builder.pii; } public static Builder builder() { @@ -174,6 +190,16 @@ public static class Builder { private double violence; + private double dangerousAndCriminalContent; + + private double health; + + private double financial; + + private double law; + + private double pii; + public Builder sexual(double sexual) { this.sexual = sexual; return this; @@ -229,6 +255,31 @@ public Builder violence(double violence) { return this; } + public Builder dangerousAndCriminalContent(double dangerousAndCriminalContent) { + this.dangerousAndCriminalContent = dangerousAndCriminalContent; + return this; + } + + public Builder health(double health) { + this.health = health; + return this; + } + + public Builder financial(double financial) { + this.financial = financial; + return this; + } + + public Builder law(double law) { + this.law = law; + return this; + } + + public Builder pii(double pii) { + this.pii = pii; + return this; + } + public CategoryScores build() { return new CategoryScores(this); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index e380e897cc..93296971bb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -25,7 +25,9 @@ import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.mistralai.MistralAiChatModel; import org.springframework.ai.mistralai.MistralAiEmbeddingModel; +import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; @@ -56,7 +58,7 @@ */ @AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) @EnableConfigurationProperties({ MistralAiEmbeddingProperties.class, MistralAiCommonProperties.class, - MistralAiChatProperties.class }) + MistralAiChatProperties.class, MistralAiModerationProperties.class }) @ConditionalOnClass(MistralAiApi.class) @ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class, WebClientAutoConfiguration.class }) @@ -108,6 +110,27 @@ public MistralAiChatModel mistralAiChatModel(MistralAiCommonProperties commonPro return chatModel; } + @Bean + @ConditionalOnMissingBean + public MistralAiModerationModel mistralAiModerationModel(MistralAiCommonProperties commonProperties, + MistralAiModerationProperties moderationProperties, RetryTemplate retryTemplate, + ObjectProvider restClientBuilderProvider, ResponseErrorHandler responseErrorHandler) { + + var apiKey = moderationProperties.getApiKey(); + var baseUrl = moderationProperties.getBaseUrl(); + + var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonProperties.getApiKey(); + var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonProperties.getBaseUrl(); + + Assert.hasText(resolvedApiKey, "Mistral API key must be set"); + Assert.hasText(resoledBaseUrl, "Mistral base URL must be set"); + + var mistralAiModerationAi = new MistralAiModerationApi(resoledBaseUrl, resolvedApiKey, + restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); + + return new MistralAiModerationModel(mistralAiModerationAi, retryTemplate, moderationProperties.getOptions()); + } + private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiModerationProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiModerationProperties.java new file mode 100644 index 0000000000..2c6c4d39e9 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiModerationProperties.java @@ -0,0 +1,35 @@ +package org.springframework.ai.autoconfigure.mistralai; + +import org.springframework.ai.mistralai.moderation.MistralAiModerationOptions; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * @author Ricken Bazolo + */ +@ConfigurationProperties(MistralAiModerationProperties.CONFIG_PREFIX) +public class MistralAiModerationProperties extends MistralAiParentProperties { + + public static final String CONFIG_PREFIX = "spring.ai.mistralai.moderation"; + + private static final String DEFAULT_MODERATION_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue(); + + @NestedConfigurationProperty + private MistralAiModerationOptions options = MistralAiModerationOptions.builder() + .model(DEFAULT_MODERATION_MODEL) + .build(); + + public MistralAiModerationProperties() { + super.setBaseUrl(MistralAiCommonProperties.DEFAULT_BASE_URL); + } + + public MistralAiModerationOptions getOptions() { + return this.options; + } + + public void setOptions(MistralAiModerationOptions options) { + this.options = options; + } + +} From 3856aef6a69b63361cb5a5e699d70309f999a927 Mon Sep 17 00:00:00 2001 From: Ricken Bazolo Date: Sun, 9 Feb 2025 02:10:48 +0100 Subject: [PATCH 2/2] Add tests for Mistral moderation API Signed-off-by: Ricken Bazolo --- .../mistralai/MistralAiModerationModelIT.java | 54 +++++++++++++++++++ .../mistralai/MistralAiTestConfiguration.java | 17 ++++++ .../mistralai/MistralAiPropertiesTests.java | 13 +++++ 3 files changed, 84 insertions(+) create mode 100644 models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java new file mode 100644 index 0000000000..9eb0f86ed3 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java @@ -0,0 +1,54 @@ +package org.springframework.ai.mistralai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = MistralAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") +public class MistralAiModerationModelIT { + + private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class); + + @Autowired + private MistralAiModerationModel mistralAiModerationModel; + + @Test + void moderationAsPositiveTest() { + var instructions = """ + I want to kill them.!"."""; + + var moderationPrompt = new ModerationPrompt(instructions); + + var moderationResponse = this.mistralAiModerationModel.call(moderationPrompt); + + assertThat(moderationResponse.getResults()).hasSize(1); + + var generation = moderationResponse.getResult(); + Moderation moderation = generation.getOutput(); + assertThat(moderation.getId()).isNotEmpty(); + assertThat(moderation.getResults()).isNotNull(); + assertThat(moderation.getResults().size()).isNotZero(); + logger.info(moderation.getResults().toString()); + + assertThat(moderation.getId()).isNotNull(); + assertThat(moderation.getModel()).isNotNull(); + + ModerationResult result = moderation.getResults().get(0); + assertThat(result.isFlagged()).isTrue(); + assertThat(result.getCategories().isViolence()).isTrue(); + } + +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index 11c084ff61..893f19892e 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -18,6 +18,8 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -35,6 +37,16 @@ public MistralAiApi mistralAiApi() { return new MistralAiApi(apiKey); } + @Bean + public MistralAiModerationApi mistralAiModerationApi() { + var apiKey = System.getenv("MISTRAL_AI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key."); + } + return new MistralAiModerationApi(apiKey); + } + @Bean public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) { return new MistralAiEmbeddingModel(api, @@ -47,4 +59,9 @@ public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) { MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build()); } + @Bean + public MistralAiModerationModel mistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) { + return new MistralAiModerationModel(mistralAiModerationApi); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java index 46c814bc38..6c615c7674 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java @@ -145,4 +145,17 @@ public void embeddingOptionsTest() { }); } + @Test + public void moderationOptionsTest() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123", + "spring.ai.mistralai.moderation.options.model=MODERATION_MODEL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) + .run(context -> { + var moderationProperties = context.getBean(MistralAiModerationProperties.class); + assertThat(moderationProperties.getOptions().getModel()).isEqualTo("MODERATION_MODEL"); + }); + } + }