Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support of Mistral AI Moderation API #2201

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package org.springframework.ai.mistralai.api;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;

import java.util.function.Consumer;

/**
* MistralAI Moderation API.
*
* @author Ricken Bazolo
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a>
*/
public class MistralAiModerationApi {

private static final String DEFAULT_BASE_URL = "https://api.mistral.ai";

private final RestClient restClient;

public MistralAiModerationApi(String mistralAiApiKey) {
this(DEFAULT_BASE_URL, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}

public MistralAiModerationApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Consumer<HttpHeaders> jsonContentHeaders = headers -> {
headers.setBearerAuth(mistralAiApiKey);
headers.setContentType(MediaType.APPLICATION_JSON);
};

this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();
}

public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationRequest mistralAiModerationRequest) {
Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null.");
Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty.");
Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null.");

return this.restClient.post()
.uri("v1/moderations")
.body(mistralAiModerationRequest)
.retrieve()
.toEntity(MistralAiModerationResponse.class);
}

public enum Model {

// @formatter:off
MISTRAL_MODERATION("mistral-moderation-latest");
// @formatter:on

private final String value;

Model(String value) {
this.value = value;
}

public String getValue() {
return this.value;
}

}

// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationRequest(
@JsonProperty("input") String prompt,
@JsonProperty("model") String model
) {

public MistralAiModerationRequest(String prompt) {
this(prompt, null);
}
}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationResponse(
@JsonProperty("id") String id,
@JsonProperty("model") String model,
@JsonProperty("results") MistralAiModerationResult[] results) {

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationResult(
@JsonProperty("categories") Categories categories,
@JsonProperty("category_scores") CategoryScores categoryScores) {

public boolean flagged() {
return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats()
|| categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health()
|| categories.financial() || categories.law() || categories.pii());
}

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record Categories(
@JsonProperty("sexual") boolean sexual,
@JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination,
@JsonProperty("violence_and_threats") boolean violenceAndThreats,
@JsonProperty("selfharm") boolean selfHarm,
@JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent,
@JsonProperty("health") boolean health,
@JsonProperty("financial") boolean financial,
@JsonProperty("law") boolean law,
@JsonProperty("pii") boolean pii) {

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record CategoryScores(
@JsonProperty("sexual") double sexual,
@JsonProperty("hate_and_discrimination") double hateAndDiscrimination,
@JsonProperty("violence_and_threats") double violenceAndThreats,
@JsonProperty("selfharm") double selfHarm,
@JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent,
@JsonProperty("health") double health,
@JsonProperty("financial") double financial,
@JsonProperty("law") double law,
@JsonProperty("pii") double pii) {

}
// @formatter:onn

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package org.springframework.ai.mistralai.moderation;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.moderation.*;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest;
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse;
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult;

import java.util.ArrayList;
import java.util.List;

/**
* @author Ricken Bazolo
*/
public class MistralAiModerationModel implements ModerationModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

private final MistralAiModerationApi mistralAiModerationApi;

private final RetryTemplate retryTemplate;

private final MistralAiModerationOptions defaultOptions;

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) {
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE,
MistralAiModerationOptions.builder()
.model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue())
.build());
}

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) {
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options);
}

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate,
MistralAiModerationOptions options) {
Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(options, "options must not be null");
this.mistralAiModerationApi = mistralAiModerationApi;
this.retryTemplate = retryTemplate;
this.defaultOptions = options;
}

@Override
public ModerationResponse call(ModerationPrompt moderationPrompt) {
return this.retryTemplate.execute(ctx -> {

var instructions = moderationPrompt.getInstructions().getText();

var moderationRequest = new MistralAiModerationRequest(instructions);

if (this.defaultOptions != null) {
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest,
MistralAiModerationRequest.class);
}
else {
// moderationPrompt.getOptions() never null but model can be empty, cause
// by ModerationPrompt constructor
moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()),
moderationRequest, MistralAiModerationRequest.class);
}

var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest);

return convertResponse(moderationResponseEntity, moderationRequest);
});
}

private ModerationResponse convertResponse(ResponseEntity<MistralAiModerationResponse> moderationResponseEntity,
MistralAiModerationRequest openAiModerationRequest) {
var moderationApiResponse = moderationResponseEntity.getBody();
if (moderationApiResponse == null) {
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
return new ModerationResponse(new Generation());
}

List<ModerationResult> moderationResults = new ArrayList<>();
if (moderationApiResponse.results() != null) {

for (MistralAiModerationResult result : moderationApiResponse.results()) {
Categories categories = null;
CategoryScores categoryScores = null;
if (result.categories() != null) {
categories = Categories.builder()
.sexual(result.categories().sexual())
.pii(result.categories().pii())
.law(result.categories().law())
.financial(result.categories().financial())
.health(result.categories().health())
.dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent())
.violence(result.categories().violenceAndThreats())
.hate(result.categories().hateAndDiscrimination())
.selfHarm(result.categories().selfHarm())
.build();
}
if (result.categoryScores() != null) {
categoryScores = CategoryScores.builder()
.sexual(result.categoryScores().sexual())
.pii(result.categoryScores().pii())
.law(result.categoryScores().law())
.financial(result.categoryScores().financial())
.health(result.categoryScores().health())
.dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent())
.violence(result.categoryScores().violenceAndThreats())
.hate(result.categoryScores().hateAndDiscrimination())
.selfHarm(result.categoryScores().selfHarm())
.build();
}
var moderationResult = ModerationResult.builder()
.categories(categories)
.categoryScores(categoryScores)
.flagged(result.flagged())
.build();
moderationResults.add(moderationResult);
}

}

var moderation = Moderation.builder()
.id(moderationApiResponse.id())
.model(moderationApiResponse.model())
.results(moderationResults)
.build();

return new ModerationResponse(new Generation(moderation));
}

private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) {
var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder();
if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) {
mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel());
}
return mistralAiModerationOptionsBuilder.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.mistralai.moderation;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.moderation.ModerationOptions;

/**
* @author Ricken Bazolo
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class MistralAiModerationOptions implements ModerationOptions {

private static final String DEFAULT_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue();

/**
* The model to use for moderation generation.
*/
@JsonProperty("model")
private String model = DEFAULT_MODEL;

public static Builder builder() {
return new Builder();
}

@Override
public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

public static final class Builder {

private final MistralAiModerationOptions options;

private Builder() {
this.options = new MistralAiModerationOptions();
}

public Builder model(String model) {
this.options.setModel(model);
return this;
}

public MistralAiModerationOptions build() {
return this.options;
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.mistralai;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
import org.springframework.ai.moderation.Moderation;
import org.springframework.ai.moderation.ModerationPrompt;
import org.springframework.ai.moderation.ModerationResult;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Ricken Bazolo
*/
@SpringBootTest(classes = MistralAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
public class MistralAiModerationModelIT {

private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class);

@Autowired
private MistralAiModerationModel mistralAiModerationModel;

@Test
void moderationAsPositiveTest() {
var instructions = """
I want to kill them.!".""";

var moderationPrompt = new ModerationPrompt(instructions);

var moderationResponse = this.mistralAiModerationModel.call(moderationPrompt);

assertThat(moderationResponse.getResults()).hasSize(1);

var generation = moderationResponse.getResult();
Moderation moderation = generation.getOutput();
assertThat(moderation.getId()).isNotEmpty();
assertThat(moderation.getResults()).isNotNull();
assertThat(moderation.getResults().size()).isNotZero();
logger.info(moderation.getResults().toString());

assertThat(moderation.getId()).isNotNull();
assertThat(moderation.getModel()).isNotNull();

ModerationResult result = moderation.getResults().get(0);
assertThat(result.isFlagged()).isTrue();
assertThat(result.getCategories().isViolence()).isTrue();
}

}
Loading