From 1e67d32550846fe59df90ca4b3abcac3e4f24c98 Mon Sep 17 00:00:00 2001 From: Ricken Bazolo Date: Sun, 9 Feb 2025 02:10:48 +0100 Subject: [PATCH] Add tests for Mistral moderation API --- .../mistralai/MistralAiModerationModelIT.java | 54 +++++++++++++++++++ .../mistralai/MistralAiTestConfiguration.java | 17 ++++++ .../mistralai/MistralAiPropertiesTests.java | 13 +++++ 3 files changed, 84 insertions(+) create mode 100644 models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java new file mode 100644 index 00000000000..9eb0f86ed3c --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java @@ -0,0 +1,54 @@ +package org.springframework.ai.mistralai; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; +import org.springframework.ai.moderation.Moderation; +import org.springframework.ai.moderation.ModerationPrompt; +import org.springframework.ai.moderation.ModerationResult; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ricken Bazolo + */ +@SpringBootTest(classes = MistralAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") +public class MistralAiModerationModelIT { + + private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class); + + @Autowired + private MistralAiModerationModel mistralAiModerationModel; + + @Test + void moderationAsPositiveTest() { + var instructions = """ + I want to kill them.!"."""; + + var moderationPrompt = new ModerationPrompt(instructions); + + var moderationResponse = this.mistralAiModerationModel.call(moderationPrompt); + + assertThat(moderationResponse.getResults()).hasSize(1); + + var generation = moderationResponse.getResult(); + Moderation moderation = generation.getOutput(); + assertThat(moderation.getId()).isNotEmpty(); + assertThat(moderation.getResults()).isNotNull(); + assertThat(moderation.getResults().size()).isNotZero(); + logger.info(moderation.getResults().toString()); + + assertThat(moderation.getId()).isNotNull(); + assertThat(moderation.getModel()).isNotNull(); + + ModerationResult result = moderation.getResults().get(0); + assertThat(result.isFlagged()).isTrue(); + assertThat(result.getCategories().isViolence()).isTrue(); + } + +} diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index 11c084ff613..893f19892ed 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -18,6 +18,8 @@ import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -35,6 +37,16 @@ public MistralAiApi mistralAiApi() { return new MistralAiApi(apiKey); } + @Bean + public MistralAiModerationApi mistralAiModerationApi() { + var apiKey = System.getenv("MISTRAL_AI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key."); + } + return new MistralAiModerationApi(apiKey); + } + @Bean public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) { return new MistralAiEmbeddingModel(api, @@ -47,4 +59,9 @@ public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) { MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build()); } + @Bean + public MistralAiModerationModel mistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) { + return new MistralAiModerationModel(mistralAiModerationApi); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java index 46c814bc380..6c615c7674e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java @@ -145,4 +145,17 @@ public void embeddingOptionsTest() { }); } + @Test + public void moderationOptionsTest() { + new ApplicationContextRunner() + .withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123", + "spring.ai.mistralai.moderation.options.model=MODERATION_MODEL") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) + .run(context -> { + var moderationProperties = context.getBean(MistralAiModerationProperties.class); + assertThat(moderationProperties.getOptions().getModel()).isEqualTo("MODERATION_MODEL"); + }); + } + }