Skip to content

Commit

Permalink
Add tests for Mistral moderation API
Browse files Browse the repository at this point in the history
Signed-off-by: Ricken Bazolo <[email protected]>
  • Loading branch information
ricken07 committed Feb 9, 2025
1 parent ca283da commit aeeb029
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.mistralai;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
import org.springframework.ai.moderation.Moderation;
import org.springframework.ai.moderation.ModerationPrompt;
import org.springframework.ai.moderation.ModerationResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Ricken Bazolo
*/
@SpringBootTest(classes = MistralAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
public class MistralAiModerationModelIT {

private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class);

@Autowired
private MistralAiModerationModel mistralAiModerationModel;

@Test
void moderationAsPositiveTest() {
var instructions = """
I want to kill them.!".""";

var moderationPrompt = new ModerationPrompt(instructions);

var moderationResponse = this.mistralAiModerationModel.call(moderationPrompt);

assertThat(moderationResponse.getResults()).hasSize(1);

var generation = moderationResponse.getResult();
Moderation moderation = generation.getOutput();
assertThat(moderation.getId()).isNotEmpty();
assertThat(moderation.getResults()).isNotNull();
assertThat(moderation.getResults().size()).isNotZero();
logger.info(moderation.getResults().toString());

assertThat(moderation.getId()).isNotNull();
assertThat(moderation.getModel()).isNotNull();

ModerationResult result = moderation.getResults().get(0);
assertThat(result.isFlagged()).isTrue();
assertThat(result.getCategories().isViolence()).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;
Expand All @@ -35,6 +37,16 @@ public MistralAiApi mistralAiApi() {
return new MistralAiApi(apiKey);
}

@Bean
public MistralAiModerationApi mistralAiModerationApi() {
var apiKey = System.getenv("MISTRAL_AI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"Missing MISTRAL_AI_API_KEY environment variable. Please set it to your Mistral AI API key.");
}
return new MistralAiModerationApi(apiKey);
}

@Bean
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
return new MistralAiEmbeddingModel(api,
Expand All @@ -47,4 +59,9 @@ public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build());
}

@Bean
public MistralAiModerationModel mistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) {
return new MistralAiModerationModel(mistralAiModerationApi);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,17 @@ public void embeddingOptionsTest() {
});
}

@Test
public void moderationOptionsTest() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123",
"spring.ai.mistralai.moderation.options.model=MODERATION_MODEL")
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
.run(context -> {
var moderationProperties = context.getBean(MistralAiModerationProperties.class);
assertThat(moderationProperties.getOptions().getModel()).isEqualTo("MODERATION_MODEL");
});
}

}

0 comments on commit aeeb029

Please sign in to comment.