diff --git a/models/spring-ai-vertex-ai-imagen/pom.xml b/models/spring-ai-vertex-ai-imagen/pom.xml new file mode 100644 index 0000000000..973c053c0d --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/pom.xml @@ -0,0 +1,109 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-M5 + ../../pom.xml + + + spring-ai-vertex-ai-imagen + jar + Spring AI Model - Vertex AI Imagen + Vertex AI Imagen models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + + com.google.cloud + libraries-bom + ${com.google.cloud.version} + pom + import + + + + + + + + com.google.cloud + google-cloud-aiplatform + + + commons-logging + commons-logging + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework + spring-context-support + + + + org.springframework.boot + spring-boot-starter-logging + + + + io.micrometer + micrometer-observation-test + test + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + + diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java new file mode 100644 index 0000000000..a466b95f57 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenConnectionDetails.java @@ -0,0 +1,181 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.io.IOException; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; + +import org.springframework.util.StringUtils; + + +/** + * VertexAiImagenConnectionDetails represents the details of a connection to the Vertex AI imagen service. + * It provides methods to access the project ID, location, publisher, and PredictionServiceSettings. + * + * @author Sami Marzouki + */ +public class VertexAiImagenConnectionDetails { + + public static final String DEFAULT_ENDPOINT = "us-central1-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_ENDPOINT_SUFFIX = "-aiplatform.googleapis.com:443"; + + public static final String DEFAULT_PUBLISHER = "google"; + + private static final String DEFAULT_LOCATION = "us-central1"; + + /** + * Your project ID. + */ + private final String projectId; + + /** + * A location is a region + * you can specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private final String location; + + private final String publisher; + + private final PredictionServiceSettings predictionServiceSettings; + + public VertexAiImagenConnectionDetails(String projectId, String location, String publisher, + PredictionServiceSettings predictionServiceSettings) { + this.projectId = projectId; + this.location = location; + this.publisher = publisher; + this.predictionServiceSettings = predictionServiceSettings; + } + + public static Builder builder() { + return new Builder(); + } + + public String getProjectId() { + return this.projectId; + } + + public String getLocation() { + return this.location; + } + + public String getPublisher() { + return this.publisher; + } + + public EndpointName getEndpointName(String modelName) { + return EndpointName.ofProjectLocationPublisherModelName(this.projectId, this.location, this.publisher, + modelName); + } + + public com.google.cloud.aiplatform.v1.PredictionServiceSettings getPredictionServiceSettings() { + return this.predictionServiceSettings; + } + + public static class Builder { + + /** + * The Vertex AI embedding endpoint. + */ + private String endpoint; + + /** + * Your project ID. + */ + private String projectId; + + /** + * A location is a + * region you can + * specify in a request to control where data is stored at rest. For a list of + * available regions, see Generative + * AI on Vertex AI locations. + */ + private String location; + + /** + * + */ + private String publisher; + + /** + * Allows the connection settings to be customised + */ + private PredictionServiceSettings predictionServiceSettings; + + public Builder apiEndpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder location(String location) { + this.location = location; + return this; + } + + public Builder publisher(String publisher) { + this.publisher = publisher; + return this; + } + + public Builder predictionServiceSettings(PredictionServiceSettings predictionServiceSettings) { + this.predictionServiceSettings = predictionServiceSettings; + return this; + } + + public VertexAiImagenConnectionDetails build() { + if (!StringUtils.hasText(this.endpoint)) { + if (!StringUtils.hasText(this.location)) { + this.endpoint = DEFAULT_ENDPOINT; + this.location = DEFAULT_LOCATION; + } else { + this.endpoint = this.location + DEFAULT_ENDPOINT_SUFFIX; + } + } + + if (!StringUtils.hasText(this.publisher)) { + this.publisher = DEFAULT_PUBLISHER; + } + + if (this.predictionServiceSettings == null) { + try { + this.predictionServiceSettings = PredictionServiceSettings.newBuilder() + .setEndpoint(this.endpoint) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + return new VertexAiImagenConnectionDetails(this.projectId, this.location, this.publisher, + this.predictionServiceSettings); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java new file mode 100644 index 0000000000..837bc3747c --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModel.java @@ -0,0 +1,256 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Value; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.image.Image; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageGenerationMetadata; +import org.springframework.ai.image.ImageModel; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationContext; +import org.springframework.ai.image.observation.ImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationDocumentation; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.ImageInstanceBuilder; +import org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.ImageParametersBuilder; +import org.springframework.ai.vertexai.imagen.metadata.VertexAiImagenImageGenerationMetadata; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +/** + * VertexAiImagenImageModel is a class that implements the ImageModel interface. It + * provides a client for calling the Imagen on Vertex AI image generation API. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageModel implements ImageModel { + + private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention(); + + /** + * The default options used for the image completion requests. + */ + private final VertexAiImagenImageOptions defaultOptions; + + /** + * The connection details for Imagen on Vertex AI. + */ + private final VertexAiImagenConnectionDetails connectionDetails; + + /** + * The retry template used to retry the Imagen on Vertex AI Image API calls. + */ + private final RetryTemplate retryTemplate; + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions) { + this(connectionDetails, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + this(connectionDetails, defaultOptions, retryTemplate, ObservationRegistry.NOOP); + } + + public VertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate, + ObservationRegistry observationRegistry) { + Assert.notNull(defaultOptions, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + this.connectionDetails = connectionDetails; + this.defaultOptions = defaultOptions; + this.retryTemplate = retryTemplate; + this.observationRegistry = observationRegistry; + } + + private static ImageParametersBuilder getImageParametersBuilder(VertexAiImagenImageOptions finalOptions) { + ImageParametersBuilder parametersBuilder = ImageParametersBuilder.of(); + + if (finalOptions.getN() != null) { + parametersBuilder.sampleCount(finalOptions.getN()); + } + if (finalOptions.getSeed() != null) { + parametersBuilder.seed(finalOptions.getSeed()); + } + if (finalOptions.getNegativePrompt() != null) { + parametersBuilder.negativePrompt(finalOptions.getNegativePrompt()); + } + if (finalOptions.getAspectRatio() != null) { + parametersBuilder.aspectRatio(finalOptions.getAspectRatio()); + } + if (finalOptions.getAddWatermark() != null) { + parametersBuilder.addWatermark(finalOptions.getAddWatermark()); + } + if (finalOptions.getStorageUri() != null) { + parametersBuilder.storageUri(finalOptions.getStorageUri()); + } + if (finalOptions.getPersonGeneration() != null) { + parametersBuilder.personGeneration(finalOptions.getPersonGeneration()); + } + if (finalOptions.getSafetySetting() != null) { + parametersBuilder.safetySetting(finalOptions.getSafetySetting()); + } + if (finalOptions.getOutputOptions() != null) { + + ImageParametersBuilder.OutputOptions outputOptions = ImageParametersBuilder.OutputOptions.of(); + if (finalOptions.getOutputOptions().getMimeType() != null) { + outputOptions.mimeType(finalOptions.getOutputOptions().getMimeType()); + } + if (finalOptions.getOutputOptions().getCompressionQuality() != null) { + outputOptions.compressionQuality(finalOptions.getOutputOptions().getCompressionQuality()); + } + + parametersBuilder.outputOptions(outputOptions.build()); + } + + return parametersBuilder; + } + + @Override + public ImageResponse call(ImagePrompt imagePrompt) { + VertexAiImagenImageOptions finalOptions = mergedOptions(imagePrompt); + + var observationContext = ImageModelObservationContext.builder() + .imagePrompt(imagePrompt) + .provider(AiProvider.VERTEX_AI.value()) + .requestOptions(finalOptions) + .build(); + + return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + PredictionServiceClient client = createPredictionServiceClient(); + + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(imagePrompt, endpointName, + finalOptions); + + PredictResponse imageResponse = this.retryTemplate + .execute(context -> getPredictResponse(client, predictRequestBuilder)); + + List imageGenerationList = new ArrayList<>(); + for (Value prediction : imageResponse.getPredictionsList()) { + Value bytesBase64Encoded = prediction.getStructValue().getFieldsOrThrow("bytesBase64Encoded"); + Value mimeType = prediction.getStructValue().getFieldsOrThrow("mimeType"); + ImageGenerationMetadata metadata = new VertexAiImagenImageGenerationMetadata( + imagePrompt.getInstructions().get(0).getText(), finalOptions.getModel(), + mimeType.getStringValue()); + Image image = new Image(null, bytesBase64Encoded.getStringValue()); + imageGenerationList.add(new ImageGeneration(image, metadata)); + } + ImageResponse response = new ImageResponse(imageGenerationList); + + observationContext.setResponse(response); + + return response; + + }); + } + + private VertexAiImagenImageOptions mergedOptions(ImagePrompt imagePrompt) { + + VertexAiImagenImageOptions mergedOptions = this.defaultOptions; + + if (imagePrompt.getOptions() != null) { + var defaultOptionsCopy = VertexAiImagenImageOptions.builder().from(this.defaultOptions).build(); + mergedOptions = ModelOptionsUtils.merge(imagePrompt.getOptions(), defaultOptionsCopy, + VertexAiImagenImageOptions.class); + } + + return mergedOptions; + } + + protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, + VertexAiImagenImageOptions finalOptions) { + PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString()); + + ImageParametersBuilder parametersBuilder = getImageParametersBuilder(finalOptions); + if (finalOptions.getOutputOptions() != null) { + ImageParametersBuilder.OutputOptions outputOptionsBuilder = ImageParametersBuilder.OutputOptions.of(); + if (finalOptions.getResponseFormat() != null) { + outputOptionsBuilder.mimeType(finalOptions.getResponseFormat()); + } + if (finalOptions.getCompressionQuality() != null) { + outputOptionsBuilder.compressionQuality(finalOptions.getCompressionQuality()); + } + parametersBuilder.outputOptions(outputOptionsBuilder.build()); + } + + predictRequestBuilder.setParameters(VertexAiImagenUtils.valueOf(parametersBuilder.build())); + + for (int i = 0; i < imagePrompt.getInstructions().size(); i++) { + + ImageInstanceBuilder instanceBuilder = ImageInstanceBuilder + .of(imagePrompt.getInstructions().get(i).getText()); + predictRequestBuilder.addInstances(VertexAiImagenUtils.valueOf(instanceBuilder.build())); + } + return predictRequestBuilder; + } + + // for testing + protected PredictionServiceClient createPredictionServiceClient() { + try { + return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + // for testing + protected PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + return client.predict(predictRequestBuilder.build()); + } + + /** + * Use the provided convention for reporting observation data. + * + * @param observationConvention The provided convention + */ + public void setObservationConvention(ImageModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java new file mode 100644 index 0000000000..f44cadd4d5 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageModelName.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +/** + * Imagen on VertexAI Models: + * - Image generation + * + * @author Sami Marzouki + */ +public enum VertexAiImagenImageModelName { + + IMAGEN_3("imagen-3.0-generate-001"), + + IMAGEN_3_FAST("imagen-3.0-fast-generate-001"), + + IMAGEN_3_CUSTOMIZATION_AND_EDITING("imagen-3.0-capability-001"), + + IMAGEN_2_V006("imagegeneration@006"), + + IMAGEN_2_V005("imagegeneration@005"), + + IMAGEN_1_V002("imagegeneration@002"); + + private final String value; + + VertexAiImagenImageModelName(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java new file mode 100644 index 0000000000..2e0cabaef4 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenImageOptions.java @@ -0,0 +1,435 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vertexai.imagen; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.image.ImageOptions; + +import static org.springframework.ai.vertexai.imagen.VertexAiImagenUtils.calculateSizeFromAspectRatio; + +/** + * Options for the Vertex AI Image service. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageOptions implements ImageOptions { + + public static final String DEFAULT_MODEL_NAME = VertexAiImagenImageModelName.IMAGEN_2_V006.getValue(); + + /** + * Required: int + * The number of images to generate. The default value is 4. The + * imagen-3.0-generate-001 model supports values 1 through 4. The + * imagen-3.0-fast-generate-001 model supports values 1 through 4. The + * imagegeneration@006 model supports values 1 through 4. The imagegeneration@005 + * model supports values 1 through 4. The imagegeneration@002 model supports values 1 + * through 8. + */ + @JsonProperty("sampleCount") + private Integer n; + + /** + * The model to use for image generation. + */ + @JsonProperty("model") + private String model; + + /** + * Optional: Uint32 + * The random seed for image generation. This is not available when addWatermark is set to true. + */ + @JsonProperty("seed") + private Integer seed; + + /** + * Optional: string + * A description of what to discourage in the generated images. + * The imagen-3.0-generate-001 model supports up to 480 tokens. + * The imagen-3.0-fast-generate-001 model supports up to 480 tokens. + * The imagegeneration@006 model supports up to 128 tokens. + * The imagegeneration@005 model supports up to 128 tokens. + * The imagegeneration@002 model supports up to 64 tokens. + */ + @JsonProperty("negativePrompt") + private String negativePrompt; + + /** + * Optional: string + * The aspect ratio for the image. The default value is "1:1". + * The imagen-3.0-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagen-3.0-fast-generate-001 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagegeneration@006 model supports "1:1", "9:16", "16:9", "3:4", or "4:3". + * The imagegeneration@005 model supports "1:1" or "9:16". + * The imagegeneration@002 model supports "1:1". + */ + @JsonProperty("aspectRatio") + private String aspectRatio; + + /** + * Optional: outputOptions + * Describes the output image format in an outputOptions object. + * + * @see OutputOptions + */ + @JsonProperty("outputOptions") + private OutputOptions outputOptions; + + /** + * Optional: string (imagegeneration@002 only) + * Describes the style for the generated images. The following values are supported: + * "photograph", "digital_art", "landscape", "sketch", "watercolor", "cyberpunk", "pop_art". + */ + @JsonProperty("sampleImageStyle") + private String style; + + /** + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Allow generation of people by the model. The following values are supported: + * "dont_allow": Disallow the inclusion of people or faces in images. + * "allow_adult": Allow generation of adults only. + * "allow_all": Allow generation of people of all ages. + * The default value is "allow_adult". + */ + @JsonProperty("personGeneration") + private String personGeneration; + + /** + * Optional: string (imagen-3.0-generate-001, imagen-3.0-fast-generate-001, and imagegeneration@006 only) + * Adds a filter level to safety filtering. The following values are supported: + * "block_low_and_above": Strongest filtering level, most strict blocking. Deprecated value: "block_most". + * "block_medium_and_above": Block some problematic prompts and responses. Deprecated value: "block_some". + * "block_only_high": Reduces the number of requests blocked due to safety filters. May increase objectionable + * content generated by Imagen. Deprecated value: "block_few". + * "block_none": Block very few problematic prompts and responses. Access to this feature is restricted. + * Previous field value: "block_fewest". + * The default value is "block_medium_and_above". + */ + @JsonProperty("safetySetting") + private String safetySetting; + + /** + * Optional: bool + * Add an invisible watermark to the generated images. + * The default value is false for the imagegeneration@002 and imagegeneration@005 models, + * and true for the imagen-3.0-fast-generate-001, imagegeneration@006, and imagegeneration@006 models. + */ + @JsonProperty("addWatermark") + private Boolean addWatermark; + + /** + * Optional: string + * Cloud Storage URI to store the generated images. + */ + @JsonProperty("storageUri") + private String storageUri; + + private List size; + + public static Builder builder() { + return new Builder(); + } + + @Override + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getWidth() { + if (this.size == null || this.size.isEmpty()) { + return null; + } + return this.size.get(0); + } + + @Override + public Integer getHeight() { + if (this.size == null || this.size.isEmpty()) { + return null; + } + return this.size.get(1); + } + + @Override + public String getStyle() { + return this.style; + } + + public void setStyle(String style) { + this.style = style; + } + + @Override + public String getResponseFormat() { + if (this.outputOptions == null) { + return null; + } + return this.outputOptions.mimeType; + } + + public Integer getCompressionQuality() { + if (this.outputOptions == null) { + return null; + } + return this.outputOptions.compressionQuality; + } + + public OutputOptions getOutputOptions() { + return this.outputOptions; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + public String getNegativePrompt() { + return negativePrompt; + } + + public void setNegativePrompt(String negativePrompt) { + this.negativePrompt = negativePrompt; + } + + public String getAspectRatio() { + return aspectRatio; + } + + public void setAspectRatio(String aspectRatio) { + this.aspectRatio = aspectRatio; + } + + public String getPersonGeneration() { + return personGeneration; + } + + public void setPersonGeneration(String personGeneration) { + this.personGeneration = personGeneration; + } + + public String getSafetySetting() { + return safetySetting; + } + + public void setSafetySetting(String safetySetting) { + this.safetySetting = safetySetting; + } + + public Boolean getAddWatermark() { + return addWatermark; + } + + public void setAddWatermark(Boolean addWatermark) { + this.addWatermark = addWatermark; + } + + public String getStorageUri() { + return storageUri; + } + + public void setStorageUri(String storageUri) { + this.storageUri = storageUri; + } + + public void setSize(List size) { + this.size = size; + } + + public static final class OutputOptions { + + @JsonProperty("mimeType") + private String mimeType; + + @JsonProperty("compressionQuality") + private Integer compressionQuality; + + public static Builder builder() { + return new Builder(); + } + + public String getMimeType() { + return mimeType; + } + + public void setMimeType(String mimeType) { + this.mimeType = mimeType; + } + + public Integer getCompressionQuality() { + return compressionQuality; + } + + public void setCompressionQuality(Integer compressionQuality) { + this.compressionQuality = compressionQuality; + } + + public static final class Builder { + + private final OutputOptions options; + + private Builder() { + this.options = new OutputOptions(); + } + + public Builder mimeType(String format) { + this.options.setMimeType(format); + return this; + } + + public Builder compressionQuality(Integer compressionQuality) { + this.options.setCompressionQuality(compressionQuality); + return this; + } + + public OutputOptions build() { + return this.options; + } + } + } + + public static final class Builder { + + private final VertexAiImagenImageOptions options; + + private Builder() { + this.options = new VertexAiImagenImageOptions(); + } + + public Builder from(VertexAiImagenImageOptions fromOptions) { + if (fromOptions.getN() != null) { + this.options.setN(fromOptions.getN()); + } + if (fromOptions.getModel() != null) { + this.options.setModel(fromOptions.getModel()); + } + if (fromOptions.getAspectRatio() != null) { + this.options.setAspectRatio(fromOptions.getAspectRatio()); + this.options.setSize(calculateSizeFromAspectRatio(fromOptions.getAspectRatio())); + } + if (fromOptions.getStyle() != null) { + this.options.setStyle(fromOptions.getStyle()); + } + if (fromOptions.getOutputOptions() != null) { + if (fromOptions.getResponseFormat() != null) { + this.options.outputOptions.setMimeType(fromOptions.getResponseFormat()); + } + if (fromOptions.getCompressionQuality() != null) { + this.options.outputOptions.setCompressionQuality(fromOptions.getCompressionQuality()); + } + } + if (fromOptions.getSeed() != null) { + this.options.setSeed(fromOptions.getSeed()); + } + if (fromOptions.getNegativePrompt() != null) { + this.options.setNegativePrompt(fromOptions.getNegativePrompt()); + } + if (fromOptions.getPersonGeneration() != null) { + this.options.setPersonGeneration(fromOptions.getPersonGeneration()); + } + if (fromOptions.getSafetySetting() != null) { + this.options.setSafetySetting(fromOptions.getSafetySetting()); + } + if (fromOptions.getAddWatermark() != null) { + this.options.setAddWatermark(fromOptions.getAddWatermark()); + } + if (fromOptions.getStorageUri() != null) { + this.options.setStorageUri(fromOptions.getStorageUri()); + } + + return this; + } + + public Builder N(Integer n) { + this.options.setN(n); + return this; + } + + public Builder model(String model) { + this.options.setModel(model); + return this; + } + + public Builder seed(Integer seed) { + this.options.setSeed(seed); + return this; + } + + public Builder negativePrompt(String negativePrompt) { + this.options.setNegativePrompt(negativePrompt); + return this; + } + + public Builder aspectRatio(String aspectRatio) { + this.options.setAspectRatio(aspectRatio); + this.options.setSize(calculateSizeFromAspectRatio(aspectRatio)); + return this; + } + + public Builder outputOptions(OutputOptions outputOptions) { + this.options.outputOptions = outputOptions; + return this; + } + + public Builder personGeneration(String personGeneration) { + this.options.setPersonGeneration(personGeneration); + return this; + } + + public Builder safetySetting(String safetySetting) { + this.options.setSafetySetting(safetySetting); + return this; + } + + public Builder addWatermark(Boolean addWatermark) { + this.options.setAddWatermark(addWatermark); + return this; + } + + public Builder storageUri(String storageUri) { + this.options.setStorageUri(storageUri); + return this; + } + + public Builder style(String style) { + this.options.setStyle(style); + return this; + } + + public VertexAiImagenImageOptions build() { + return this.options; + } + + } +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java new file mode 100644 index 0000000000..d4650105e7 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/VertexAiImagenUtils.java @@ -0,0 +1,230 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen; + +import java.util.Arrays; +import java.util.List; + +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; + +import org.springframework.util.Assert; + +/** + * Utility class for constructing parameter objects for Imagen on Vertex AI requests. + * + * @author Sami Marzouki + */ +public abstract class VertexAiImagenUtils { + + public static Value valueOf(boolean n) { + return Value.newBuilder().setBoolValue(n).build(); + } + + public static Value valueOf(String s) { + return Value.newBuilder().setStringValue(s).build(); + } + + public static Value valueOf(int n) { + return Value.newBuilder().setNumberValue(n).build(); + } + + public static Value valueOf(Struct struct) { + return Value.newBuilder().setStructValue(struct).build(); + } + + public static Value jsonToValue(String json) throws InvalidProtocolBufferException { + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + public static List calculateSizeFromAspectRatio(String aspectRatio) { + if (aspectRatio != null) { + return switch (aspectRatio) { + case "1:1" -> List.of(1024, 1024); + case "9:16" -> List.of(900, 1600); + case "16:9" -> List.of(1600, 900); + case "3:4" -> List.of(750, 1000); + case "4:3" -> List.of(1000, 750); + default -> throw new IllegalStateException("Unexpected value: " + aspectRatio + + " aspect ratio must be one of these values : ['1:1', '9:16', '16:9', '3:4', or '4:3']"); + }; + } + return Arrays.asList(1024, 1024); + } + + public static class ImageInstanceBuilder { + + public String prompt; + + public static ImageInstanceBuilder of(String prompt) { + Assert.hasText(prompt, "Prompt must not be empty"); + var builder = new ImageInstanceBuilder(); + builder.prompt = prompt; + return builder; + } + + public Struct build() { + Struct.Builder textBuilder = Struct.newBuilder(); + textBuilder.putFields("prompt", valueOf(this.prompt)); + return textBuilder.build(); + } + } + + public static class ImageParametersBuilder { + + public Integer sampleCount; + public Integer seed; + public String negativePrompt; + public String aspectRatio; + public Boolean addWatermark; + public String storageUri; + public String personGeneration; + public String safetySetting; + public Struct outputOptions; + + public static ImageParametersBuilder of() { + return new ImageParametersBuilder(); + } + + public ImageParametersBuilder sampleCount(Integer sampleCount) { + Assert.notNull(sampleCount, "Sample count must not be null"); + this.sampleCount = sampleCount; + return this; + } + + public ImageParametersBuilder seed(Integer seed) { + Assert.notNull(seed, "Seed must not be null"); + this.seed = seed; + return this; + } + + public ImageParametersBuilder negativePrompt(String negativePrompt) { + Assert.notNull(negativePrompt, "Negative prompt must not be null"); + this.negativePrompt = negativePrompt; + return this; + } + + public ImageParametersBuilder aspectRatio(String aspectRatio) { + Assert.notNull(aspectRatio, "Aspect ratio must not be null"); + this.aspectRatio = aspectRatio; + return this; + } + + public ImageParametersBuilder addWatermark(Boolean addWatermark) { + Assert.notNull(addWatermark, "Add watermark must not be null"); + this.addWatermark = addWatermark; + return this; + } + + public ImageParametersBuilder storageUri(String storageUri) { + Assert.notNull(storageUri, "Storage URI must not be null"); + this.storageUri = storageUri; + return this; + } + + public ImageParametersBuilder personGeneration(String personGeneration) { + Assert.notNull(personGeneration, "Person generation must not be null"); + this.personGeneration = personGeneration; + return this; + } + + public ImageParametersBuilder safetySetting(String safetySetting) { + Assert.notNull(safetySetting, "Safety setting must not be null"); + this.safetySetting = safetySetting; + return this; + } + + public ImageParametersBuilder outputOptions(Struct outputOptions) { + Assert.notNull(outputOptions, "Output options must not be null"); + this.outputOptions = outputOptions; + return this; + } + + public Struct build() { + Struct.Builder imageParametersBuilder = Struct.newBuilder(); + + if (this.sampleCount != null) { + imageParametersBuilder.putFields("sampleCount", valueOf(this.sampleCount)); + } + if (this.seed != null) { + imageParametersBuilder.putFields("seed", valueOf(this.seed)); + } + if (this.negativePrompt != null) { + imageParametersBuilder.putFields("negativePrompt", valueOf(this.negativePrompt)); + } + if (this.aspectRatio != null) { + imageParametersBuilder.putFields("aspectRatio", valueOf(this.aspectRatio)); + } + if (this.addWatermark != null) { + imageParametersBuilder.putFields("addWatermark", valueOf(this.addWatermark)); + } + if (this.storageUri != null) { + imageParametersBuilder.putFields("storageUri", valueOf(this.storageUri)); + } + if (this.personGeneration != null) { + imageParametersBuilder.putFields("personGeneration", valueOf(this.personGeneration)); + } + if (this.safetySetting != null) { + imageParametersBuilder.putFields("safetySetting", valueOf(this.safetySetting)); + } + if (this.outputOptions != null) { + imageParametersBuilder.putFields("outputOptions", Value.newBuilder().setStructValue(this.outputOptions).build()); + } + return imageParametersBuilder.build(); + } + + public static class OutputOptions { + public String mimeType; + public Integer compressionQuality; + + public static OutputOptions of() { + return new OutputOptions(); + } + + public OutputOptions mimeType(String mimeType) { + Assert.notNull(mimeType, "MIME type must not be null"); + this.mimeType = mimeType; + return this; + } + + public OutputOptions compressionQuality(Integer compressionQuality) { + Assert.notNull(compressionQuality, "Compression quality must not be null"); + this.compressionQuality = compressionQuality; + return this; + } + + public Struct build() { + Struct.Builder outputOptionsBuilder = Struct.newBuilder(); + + if (this.mimeType != null) { + outputOptionsBuilder.putFields("mimeType", valueOf(this.mimeType)); + } + if (this.compressionQuality != null) { + outputOptionsBuilder.putFields("compressionQuality", valueOf(this.compressionQuality)); + } + return outputOptionsBuilder.build(); + } + + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java new file mode 100644 index 0000000000..ee12bd1c97 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/main/java/org/springframework/ai/vertexai/imagen/metadata/VertexAiImagenImageGenerationMetadata.java @@ -0,0 +1,81 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vertexai.imagen.metadata; + +import java.util.Objects; + +import org.springframework.ai.image.ImageGenerationMetadata; + +/** + * VertexAiImagenImageGenerationMetadata is a class that defines the metadata for Imagen + * on Vertex AI. + * + * @author Sami Marzouki + */ +public class VertexAiImagenImageGenerationMetadata implements ImageGenerationMetadata { + + private final String prompt; + + private final String model; + + private final String mimeType; + + public VertexAiImagenImageGenerationMetadata(String revisedPrompt, String mimeType, String model) { + this.prompt = revisedPrompt; + this.model = model; + this.mimeType = mimeType; + } + + public String getPrompt() { + return prompt; + } + + public String getModel() { + return model; + } + + public String getMimeType() { + return mimeType; + } + + @Override + public String toString() { + return "VertexAiImagenImageGenerationMetadata{" + + "prompt='" + prompt + '\'' + + ", model='" + model + '\'' + + ", mimeType='" + mimeType + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + VertexAiImagenImageGenerationMetadata that = (VertexAiImagenImageGenerationMetadata) o; + return Objects.equals(prompt, that.prompt) + && Objects.equals(model, that.model) + && Objects.equals(mimeType, that.mimeType); + } + + @Override + public int hashCode() { + return Objects.hash(prompt, model, mimeType); + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java new file mode 100644 index 0000000000..4c84c25419 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/TestVertexAiImagenImageModel.java @@ -0,0 +1,76 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package imagen; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.retry.support.RetryTemplate; + +/** + * @author Sami Marzouki + */ +public class TestVertexAiImagenImageModel extends VertexAiImagenImageModel { + + private PredictionServiceClient mockPredictionServiceClient; + + private PredictRequest.Builder mockPredictRequestBuilder; + + public TestVertexAiImagenImageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageOptions defaultOptions, RetryTemplate retryTemplate) { + super(connectionDetails, defaultOptions, retryTemplate); + } + + public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) { + this.mockPredictionServiceClient = mockPredictionServiceClient; + } + + @Override + public PredictionServiceClient createPredictionServiceClient() { + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient; + } + return super.createPredictionServiceClient(); + } + + @Override + public PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) { + if (this.mockPredictionServiceClient != null) { + return this.mockPredictionServiceClient.predict(predictRequestBuilder.build()); + } + return super.getPredictResponse(client, predictRequestBuilder); + } + + public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) { + this.mockPredictRequestBuilder = mockPredictRequestBuilder; + } + + @Override + protected PredictRequest.Builder getPredictRequestBuilder(ImagePrompt imagePrompt, EndpointName endpointName, + VertexAiImagenImageOptions finalOptions) { + if (this.mockPredictRequestBuilder != null) { + return this.mockPredictRequestBuilder; + } + return super.getPredictRequestBuilder(imagePrompt, endpointName, finalOptions); + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java new file mode 100644 index 0000000000..15465acee7 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelIT.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.ai.vertexai.imagen.metadata.VertexAiImagenImageGenerationMetadata; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +/** + * @author Marzouki Sami + */ +@SpringBootTest(classes = VertexAiImagenImageModelIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenImageModelIT { + + @Autowired + protected VertexAiImagenImageModel imageModel; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = {"imagen-3.0-generate-001", "imagen-3.0-fast-generate-001", "imagen-3.0-capability-001", + "imagegeneration@006", "imagegeneration@005", "imagegeneration@002"}) + void defaultImageGenerator(String modelName) { + Assertions.assertThat(this.imageModel).isNotNull(); + + var options = VertexAiImagenImageOptions.builder().model(modelName).N(1).build(); + + ImageResponse imageResponse = this.imageModel + .call(new ImagePrompt("little kitten sitting on a purple cushion", options)); + + Assertions.assertThat(imageResponse.getResults()).hasSize(2); + Assertions.assertThat(imageResponse.getResults().get(0).getOutput().getB64Json()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getModel()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getPrompt()).isNotEmpty(); + Assertions.assertThat(((VertexAiImagenImageGenerationMetadata) imageResponse.getResults().get(0).getMetadata()).getMimeType()).isNotEmpty(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public VertexAiImagenConnectionDetails connectionDetails() { + return VertexAiImagenConnectionDetails.builder() + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); + } + + @Bean + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails) { + + VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiImagenImageModel(connectionDetails, options); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java new file mode 100644 index 0000000000..8aa3adfe1a --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageModelObservationIT.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseMetadata; +import org.springframework.ai.image.observation.DefaultImageModelObservationConvention; +import org.springframework.ai.image.observation.ImageModelObservationDocumentation; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModelName; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for observation instrumentation in {@link VertexAiImagenImageModel}. + * + * @author Sami Marzouki + */ +@SpringBootTest(classes = VertexAiImagenImageModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenImageModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + VertexAiImagenImageModel imageModel; + + @Test + void observationForImageOperation() { + var options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .N(1) + .build(); + + ImagePrompt imagePrompt = new ImagePrompt("Little kitten sitting on a purple cushion", options); + ImageResponse imageResponse = this.imageModel.call(imagePrompt); + assertThat(imageResponse.getResults()).isNotEmpty(); + + ImageResponseMetadata responseMetadata = imageResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(this.observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultImageModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("image " + VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.IMAGE.value()) + .hasLowCardinalityKeyValue(ImageModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER.asString(), + AiProvider.VERTEX_AI.value()) + .hasLowCardinalityKeyValue( + ImageModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL.asString(), + VertexAiImagenImageModelName.IMAGEN_2_V006.getValue()) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public VertexAiImagenConnectionDetails connectionDetails() { + return VertexAiImagenConnectionDetails.builder() + .projectId(System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID")) + .location(System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .build(); + } + + @Bean + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails, + ObservationRegistry observationRegistry) { + + VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + return new VertexAiImagenImageModel(connectionDetails, options, RetryUtils.DEFAULT_RETRY_TEMPLATE, + observationRegistry); + } + + } + +} diff --git a/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java new file mode 100644 index 0000000000..0086338969 --- /dev/null +++ b/models/spring-ai-vertex-ai-imagen/src/test/java/imagen/VertexAiImagenImageRetryTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package imagen; + +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * @author Sami Marzouki + */ +@ExtendWith(MockitoExtension.class) +public class VertexAiImagenImageRetryTests { + + private TestRetryListener retryListener; + + @Mock + private PredictionServiceClient mockPredictionServiceClient; + + @Mock + private VertexAiImagenConnectionDetails mockConnectionDetails; + + @Mock + private PredictRequest.Builder mockPredictRequestBuilder; + + private TestVertexAiImagenImageModel imageModel; + + @BeforeEach + public void setUp() { + RetryTemplate retryTemplate = RetryUtils.SHORT_RETRY_TEMPLATE; + this.retryListener = new TestRetryListener(); + retryTemplate.registerListener(this.retryListener); + + this.imageModel = new TestVertexAiImagenImageModel(this.mockConnectionDetails, + VertexAiImagenImageOptions.builder().build(), retryTemplate); + this.imageModel.setMockPredictionServiceClient(this.mockPredictionServiceClient); + this.imageModel.setMockPredictRequestBuilder(this.mockPredictRequestBuilder); + given(this.mockPredictRequestBuilder.build()).willReturn(PredictRequest.getDefaultInstance()); + } + + @Test + public void vertexAiImageGeneratorTransientError() { + // Set up the mock PredictResponse + PredictResponse mockResponse = PredictResponse.newBuilder() + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .build()) + .build()) + .addPredictions(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("mimeType", Value.newBuilder().setStringValue("image/png").build()) + .putFields("bytesBase64Encoded", Value.newBuilder().setStringValue("BASE64_IMG_BYTES").build()) + .build()) + .build()) + .build(); + + // Set up the mock PredictionServiceClient + given(this.mockPredictionServiceClient.predict(any())) + .willThrow(new TransientAiException("Transient Error 1")) + .willThrow(new TransientAiException("Transient Error 2")) + .willReturn(mockResponse); + + ImageResponse result = this.imageModel.call(new ImagePrompt("text1", null)); + + assertThat(result).isNotNull(); + assertThat(result.getResults()).hasSize(2); + assertThat(result.getResults().get(0).getOutput().getB64Json()).isEqualTo("BASE64_IMG_BYTES"); + assertThat(result.getResults().get(1).getOutput().getB64Json()).isEqualTo("BASE64_IMG_BYTES"); + assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2); + + verify(this.mockPredictRequestBuilder, times(3)).build(); + } + + @Test + public void vertexAiImageGeneratorNonTransientError() { + // Set up the mock PredictionServiceClient to throw a non-transient error + given(this.mockPredictionServiceClient.predict(any())).willThrow(new RuntimeException("Non Transient Error")); + + // Assert that a RuntimeException is thrown and not retried + assertThatThrownBy(() -> this.imageModel.call(new ImagePrompt("text1", null))) + .isInstanceOf(RuntimeException.class); + + // Verify that predict was called only once (no retries for non-transient errors) + verify(this.mockPredictionServiceClient, times(1)).predict(any()); + } + + private static class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + this.onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + this.onErrorRetryCount = context.getRetryCount(); + } + + } + +} diff --git a/pom.xml b/pom.xml index e0972958f6..5e7a79f27d 100644 --- a/pom.xml +++ b/pom.xml @@ -103,6 +103,7 @@ models/spring-ai-transformers models/spring-ai-vertex-ai-embedding models/spring-ai-vertex-ai-gemini + models/spring-ai-vertex-ai-imagen models/spring-ai-watsonx-ai models/spring-ai-zhipuai models/spring-ai-moonshot @@ -123,6 +124,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-transformers spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-embedding spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-gemini + spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-zhipuai spring-ai-spring-boot-starters/spring-ai-starter-moonshot @@ -665,6 +667,7 @@ org.springframework.ai.transformers/**/*IT.java org.springframework.ai.vertexai.embedding/**/*IT.java org.springframework.ai.vertexai.gemini/**/*IT.java + org.springframework.ai.vertexai.imagen/**/*IT.java org.springframework.ai.watsonx/**/*IT.java org.springframework.ai.zhipuai/**/*IT.java diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index 495d0279cd..1196ff5cec 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -146,6 +146,12 @@ ${project.version} + + org.springframework.ai + spring-ai-vertex-ai-imagen + ${project.version} + + org.springframework.ai spring-ai-mistral-ai @@ -491,6 +497,12 @@ ${project.version} + + org.springframework.ai + spring-ai-vertex-ai-imagen-spring-boot-starter + ${project.version} + + org.springframework.ai spring-ai-weaviate-store-spring-boot-starter diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 13796f3abc..303381b529 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -262,6 +262,14 @@ true + + + org.springframework.ai + spring-ai-vertex-ai-imagen + ${project.parent.version} + true + + org.springframework.ai diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java new file mode 100644 index 0000000000..810feecdc6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenAutoConfiguration.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import java.io.IOException; + +import com.google.cloud.vertexai.VertexAI; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.image.observation.ImageModelObservationConvention; +import org.springframework.ai.vertexai.imagen.VertexAiImagenConnectionDetails; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.ImportAutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * AutoConfiguration for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@AutoConfiguration(after = {SpringAiRetryAutoConfiguration.class}) +@ConditionalOnClass({VertexAI.class, VertexAiImagenImageModel.class}) +@EnableConfigurationProperties({VertexAiImagenImageProperties.class, VertexAiImagenConnectionProperties.class}) +@ImportAutoConfiguration(classes = {SpringAiRetryAutoConfiguration.class}) +public class VertexAiImagenAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public VertexAiImagenConnectionDetails connectionDetails( + VertexAiImagenConnectionProperties connectionProperties) throws IOException { + Assert.hasText(connectionProperties.getProjectId(), "Vertex AI project-id must be set!"); + Assert.hasText(connectionProperties.getLocation(), "Vertex AI location must be set!"); + + var connectionBuilder = VertexAiImagenConnectionDetails.builder() + .projectId(connectionProperties.getProjectId()) + .location(connectionProperties.getLocation()); + + if (StringUtils.hasText(connectionProperties.getApiEndpoint())) { + connectionBuilder.apiEndpoint(connectionProperties.getApiEndpoint()); + } + + return connectionBuilder.build(); + } + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = VertexAiImagenImageProperties.CONFIG_PREFIX, name = "enabled", + havingValue = "true", matchIfMissing = true) + public VertexAiImagenImageModel imageModel(VertexAiImagenConnectionDetails connectionDetails, + VertexAiImagenImageProperties properties, RetryTemplate retryTemplate, + ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var imageModel = new VertexAiImagenImageModel(connectionDetails, properties.getOptions(), + retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + + observationConvention.ifAvailable(imageModel::setObservationConvention); + + return imageModel; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java new file mode 100644 index 0000000000..fa86bf0090 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenConnectionProperties.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.core.io.Resource; + +/** + * Configuration properties for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@ConfigurationProperties(VertexAiImagenConnectionProperties.CONFIG_PREFIX) +public class VertexAiImagenConnectionProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.imagen"; + + /** + * Vertex AI Imagen project ID. + */ + private String projectId; + + /** + * Vertex AI Imagen location. + */ + private String location; + + /** + * URI to Vertex AI Imagen credentials (optional) + */ + private Resource credentialsUri; + + /** + * Vertex AI Imagen API endpoint. + */ + private String apiEndpoint; + + public String getProjectId() { + return this.projectId; + } + + public void setProjectId(String projectId) { + this.projectId = projectId; + } + + public String getLocation() { + return this.location; + } + + public void setLocation(String location) { + this.location = location; + } + + public Resource getCredentialsUri() { + return this.credentialsUri; + } + + public void setCredentialsUri(Resource credentialsUri) { + this.credentialsUri = credentialsUri; + } + + public String getApiEndpoint() { + return this.apiEndpoint; + } + + public void setApiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java new file mode 100644 index 0000000000..3cba5ac320 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenImageProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Configuration properties for Vertex AI Imagen. + * + * @author Sami Marzouki + */ +@ConfigurationProperties(VertexAiImagenImageProperties.CONFIG_PREFIX) +public class VertexAiImagenImageProperties { + + public static final String CONFIG_PREFIX = "spring.ai.vertex.ai.imagen.generator"; + + private boolean enabled = true; + + /** + * Vertex AI Imagen API options. + */ + private VertexAiImagenImageOptions options = VertexAiImagenImageOptions.builder() + .model(VertexAiImagenImageOptions.DEFAULT_MODEL_NAME) + .build(); + + public VertexAiImagenImageOptions getOptions() { + return this.options; + } + + public void setOptions(VertexAiImagenImageOptions options) { + this.options = options; + } + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java new file mode 100644 index 0000000000..e1211f851a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/imagen/VertexAiImagenModelAutoConfigurationIT.java @@ -0,0 +1,86 @@ +/* + * Copyright 2025-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.vertexai.imagen; + +import java.io.File; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.junit.jupiter.api.io.TempDir; + +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.vertexai.imagen.VertexAiImagenImageModel; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Sami Marzouki + */ +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_PROJECT_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "VERTEX_AI_IMAGEN_LOCATION", matches = ".*") +public class VertexAiImagenModelAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.vertex.ai.imagen.project-id=" + System.getenv("VERTEX_AI_IMAGEN_PROJECT_ID"), + "spring.ai.vertex.ai.imagen.location=" + System.getenv("VERTEX_AI_IMAGEN_LOCATION")) + .withConfiguration(AutoConfigurations.of(VertexAiImagenAutoConfiguration.class)); + + @TempDir + File tempDir; + + + @Test + public void imageGenerator() { + this.contextRunner.run(context -> { + var connectionProperties = context.getBean(VertexAiImagenConnectionProperties.class); + var imageProperties = context.getBean(VertexAiImagenImageProperties.class); + + assertThat(connectionProperties).isNotNull(); + assertThat(imageProperties.isEnabled()).isTrue(); + + VertexAiImagenImageModel imageModel = context.getBean(VertexAiImagenImageModel.class); + assertThat(imageModel).isInstanceOf(VertexAiImagenImageModel.class); + + ImageResponse imageResponse = imageModel.call(new ImagePrompt("Spring Framework, Spring AI")); + + assertThat(imageResponse.getResults().size()).isEqualTo(1); + assertThat(imageResponse.getResults().get(0).getOutput().getB64Json()).isNotEmpty(); + }); + } + + @Test + void imageGeneratorActivation() { + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.imagen.generator.enabled=false").run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isEmpty(); + }); + + this.contextRunner.withPropertyValues("spring.ai.vertex.ai.imagen.generator.enabled=true").run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isNotEmpty(); + }); + + this.contextRunner.run(context -> { + assertThat(context.getBeansOfType(VertexAiImagenImageProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(VertexAiImagenImageModel.class)).isNotEmpty(); + }); + + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml new file mode 100644 index 0000000000..c789f590a3 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-vertex-ai-imagen/pom.xml @@ -0,0 +1,61 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-M5 + ../../pom.xml + + + spring-ai-vertex-ai-imagen-spring-boot-starter + jar + Spring AI Starter - VertexAI Imagen + Spring AI Vertex Imagen AI Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-vertex-ai-imagen + ${project.parent.version} + + + +