Skip to content

Commit

Permalink
Add Support of Mistral AI Moderation API
Browse files Browse the repository at this point in the history
Signed-off-by: Ricken Bazolo <[email protected]>
  • Loading branch information
ricken07 committed Feb 9, 2025
1 parent 4874374 commit f31ace1
Show file tree
Hide file tree
Showing 7 changed files with 526 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package org.springframework.ai.mistralai.api;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;

import java.util.function.Consumer;

/**
* MistralAI Moderation API.
*
* @author Ricken Bazolo
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a>
*/
public class MistralAiModerationApi {

private static final String DEFAULT_BASE_URL = "https://api.mistral.ai";

private final RestClient restClient;

public MistralAiModerationApi(String mistralAiApiKey) {
this(DEFAULT_BASE_URL, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}

public MistralAiModerationApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Consumer<HttpHeaders> jsonContentHeaders = headers -> {
headers.setBearerAuth(mistralAiApiKey);
headers.setContentType(MediaType.APPLICATION_JSON);
};

this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();
}

public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationRequest mistralAiModerationRequest) {
Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null.");
Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty.");
Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null.");

return this.restClient.post()
.uri("v1/moderations")
.body(mistralAiModerationRequest)
.retrieve()
.toEntity(MistralAiModerationResponse.class);
}

public enum Model {

// @formatter:off
MISTRAL_MODERATION("mistral-moderation-latest");
// @formatter:on

private final String value;

Model(String value) {
this.value = value;
}

public String getValue() {
return this.value;
}

}

// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationRequest(
@JsonProperty("input") String prompt,
@JsonProperty("model") String model
) {

public MistralAiModerationRequest(String prompt) {
this(prompt, null);
}
}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationResponse(
@JsonProperty("id") String id,
@JsonProperty("model") String model,
@JsonProperty("results") MistralAiModerationResult[] results) {

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record MistralAiModerationResult(
@JsonProperty("categories") Categories categories,
@JsonProperty("category_scores") CategoryScores categoryScores) {

public boolean flagged() {
return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats()
|| categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health()
|| categories.financial() || categories.law() || categories.pii());
}

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record Categories(
@JsonProperty("sexual") boolean sexual,
@JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination,
@JsonProperty("violence_and_threats") boolean violenceAndThreats,
@JsonProperty("selfharm") boolean selfHarm,
@JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent,
@JsonProperty("health") boolean health,
@JsonProperty("financial") boolean financial,
@JsonProperty("law") boolean law,
@JsonProperty("pii") boolean pii) {

}

@JsonInclude(JsonInclude.Include.NON_NULL)
public record CategoryScores(
@JsonProperty("sexual") double sexual,
@JsonProperty("hate_and_discrimination") double hateAndDiscrimination,
@JsonProperty("violence_and_threats") double violenceAndThreats,
@JsonProperty("selfharm") double selfHarm,
@JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent,
@JsonProperty("health") double health,
@JsonProperty("financial") double financial,
@JsonProperty("law") double law,
@JsonProperty("pii") double pii) {

}
// @formatter:onn

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package org.springframework.ai.mistralai.moderation;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.moderation.*;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest;
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse;
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult;

import java.util.ArrayList;
import java.util.List;

/**
* @author Ricken Bazolo
*/
public class MistralAiModerationModel implements ModerationModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

private final MistralAiModerationApi mistralAiModerationApi;

private final RetryTemplate retryTemplate;

private final MistralAiModerationOptions defaultOptions;

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) {
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE,
MistralAiModerationOptions.builder()
.model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue())
.build());
}

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) {
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options);
}

public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate,
MistralAiModerationOptions options) {
Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");
Assert.notNull(options, "options must not be null");
this.mistralAiModerationApi = mistralAiModerationApi;
this.retryTemplate = retryTemplate;
this.defaultOptions = options;
}

@Override
public ModerationResponse call(ModerationPrompt moderationPrompt) {
return this.retryTemplate.execute(ctx -> {

var instructions = moderationPrompt.getInstructions().getText();

var moderationRequest = new MistralAiModerationRequest(instructions);

if (this.defaultOptions != null) {
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest,
MistralAiModerationRequest.class);
}
else {
// moderationPrompt.getOptions() never null but model can be empty, cause
// by ModerationPrompt constructor
moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()),
moderationRequest, MistralAiModerationRequest.class);
}

var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest);

return convertResponse(moderationResponseEntity, moderationRequest);
});
}

private ModerationResponse convertResponse(ResponseEntity<MistralAiModerationResponse> moderationResponseEntity,
MistralAiModerationRequest openAiModerationRequest) {
var moderationApiResponse = moderationResponseEntity.getBody();
if (moderationApiResponse == null) {
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
return new ModerationResponse(new Generation());
}

List<ModerationResult> moderationResults = new ArrayList<>();
if (moderationApiResponse.results() != null) {

for (MistralAiModerationResult result : moderationApiResponse.results()) {
Categories categories = null;
CategoryScores categoryScores = null;
if (result.categories() != null) {
categories = Categories.builder()
.sexual(result.categories().sexual())
.pii(result.categories().pii())
.law(result.categories().law())
.financial(result.categories().financial())
.health(result.categories().health())
.dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent())
.violence(result.categories().violenceAndThreats())
.hate(result.categories().hateAndDiscrimination())
.selfHarm(result.categories().selfHarm())
.build();
}
if (result.categoryScores() != null) {
categoryScores = CategoryScores.builder()
.sexual(result.categoryScores().sexual())
.pii(result.categoryScores().pii())
.law(result.categoryScores().law())
.financial(result.categoryScores().financial())
.health(result.categoryScores().health())
.dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent())
.violence(result.categoryScores().violenceAndThreats())
.hate(result.categoryScores().hateAndDiscrimination())
.selfHarm(result.categoryScores().selfHarm())
.build();
}
var moderationResult = ModerationResult.builder()
.categories(categories)
.categoryScores(categoryScores)
.flagged(result.flagged())
.build();
moderationResults.add(moderationResult);
}

}

var moderation = Moderation.builder()
.id(moderationApiResponse.id())
.model(moderationApiResponse.model())
.results(moderationResults)
.build();

return new ModerationResponse(new Generation(moderation));
}

private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) {
var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder();
if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) {
mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel());
}
return mistralAiModerationOptionsBuilder.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package org.springframework.ai.mistralai.moderation;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
import org.springframework.ai.moderation.ModerationOptions;

/**
* @author Ricken Bazolo
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class MistralAiModerationOptions implements ModerationOptions {

private static final String DEFAULT_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue();

/**
* The model to use for moderation generation.
*/
@JsonProperty("model")
private String model = DEFAULT_MODEL;

public static Builder builder() {
return new Builder();
}

@Override
public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

public static final class Builder {

private final MistralAiModerationOptions options;

private Builder() {
this.options = new MistralAiModerationOptions();
}

public Builder model(String model) {
this.options.setModel(model);
return this;
}

public MistralAiModerationOptions build() {
return this.options;
}

}

}
Loading

0 comments on commit f31ace1

Please sign in to comment.