Skip to content

Commit b5fa883

Browse files
claudevdmClaude
and
Claude
authored
Add Vertex embeddings to RAG package. (#33593)
Co-authored-by: Claude <[email protected]>
1 parent 15f973f commit b5fa883

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
"""Tests for apache_beam.ml.rag.embeddings.huggingface."""
1818

19+
import shutil
1920
import tempfile
2021
import unittest
2122

@@ -73,6 +74,9 @@ def setUp(self):
7374
})
7475
]
7576

77+
def tearDown(self) -> None:
78+
shutil.rmtree(self.artifact_location)
79+
7680
def test_embedding_pipeline(self):
7781
expected = [
7882
Chunk(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
18+
# Vertex AI Python SDK is required for this module.
19+
# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long
20+
# to install Vertex AI Python SDK.
21+
22+
"""RAG-specific embedding implementations using Vertex AI models."""
23+
24+
from typing import Optional
25+
26+
from google.auth.credentials import Credentials
27+
28+
import apache_beam as beam
29+
from apache_beam.ml.inference.base import RunInference
30+
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
31+
from apache_beam.ml.rag.types import Chunk
32+
from apache_beam.ml.transforms.base import EmbeddingsManager
33+
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
34+
from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE
35+
from apache_beam.ml.transforms.embeddings.vertex_ai import _VertexAITextEmbeddingHandler
36+
37+
try:
38+
import vertexai
39+
except ImportError:
40+
vertexai = None
41+
42+
43+
class VertexAITextEmbeddings(EmbeddingsManager):
44+
def __init__(
45+
self,
46+
model_name: str,
47+
*,
48+
title: Optional[str] = None,
49+
task_type: str = DEFAULT_TASK_TYPE,
50+
project: Optional[str] = None,
51+
location: Optional[str] = None,
52+
credentials: Optional[Credentials] = None,
53+
**kwargs):
54+
"""Utilizes Vertex AI text embeddings for semantic search and RAG
55+
pipelines.
56+
57+
Args:
58+
model_name: Name of the Vertex AI text embedding model
59+
title: Optional title for the text content
60+
task_type: Task type for embeddings (default: RETRIEVAL_DOCUMENT)
61+
project: GCP project ID
62+
location: GCP location
63+
credentials: Optional GCP credentials
64+
**kwargs: Additional arguments passed to EmbeddingsManager including
65+
ModelHandler inference_args.
66+
"""
67+
if not vertexai:
68+
raise ImportError(
69+
"vertexai is required to use VertexAITextEmbeddings. "
70+
"Please install it with `pip install google-cloud-aiplatform`")
71+
72+
super().__init__(type_adapter=create_rag_adapter(), **kwargs)
73+
self.model_name = model_name
74+
self.title = title
75+
self.task_type = task_type
76+
self.project = project
77+
self.location = location
78+
self.credentials = credentials
79+
80+
def get_model_handler(self):
81+
"""Returns model handler configured with RAG adapter."""
82+
return _VertexAITextEmbeddingHandler(
83+
model_name=self.model_name,
84+
title=self.title,
85+
task_type=self.task_type,
86+
project=self.project,
87+
location=self.location,
88+
credentials=self.credentials,
89+
)
90+
91+
def get_ptransform_for_processing(
92+
self, **kwargs
93+
) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
94+
"""Returns PTransform that uses the RAG adapter."""
95+
return RunInference(
96+
model_handler=_TextEmbeddingHandler(self),
97+
inference_args=self.inference_args).with_output_types(Chunk)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""
18+
19+
import shutil
20+
import tempfile
21+
import unittest
22+
23+
import apache_beam as beam
24+
from apache_beam.ml.rag.types import Chunk
25+
from apache_beam.ml.rag.types import Content
26+
from apache_beam.ml.rag.types import Embedding
27+
from apache_beam.ml.transforms.base import MLTransform
28+
from apache_beam.testing.test_pipeline import TestPipeline
29+
from apache_beam.testing.util import assert_that
30+
from apache_beam.testing.util import equal_to
31+
32+
# pylint: disable=ungrouped-imports
33+
try:
34+
import vertexai # pylint: disable=unused-import
35+
from apache_beam.ml.rag.embeddings.vertex_ai import VertexAITextEmbeddings
36+
VERTEX_AI_AVAILABLE = True
37+
except ImportError:
38+
VERTEX_AI_AVAILABLE = False
39+
40+
41+
def chunk_approximately_equals(expected, actual):
42+
"""Compare embeddings allowing for numerical differences."""
43+
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
44+
return False
45+
46+
return (
47+
expected.id == actual.id and expected.metadata == actual.metadata and
48+
expected.content == actual.content and
49+
len(expected.embedding.dense_embedding) == len(
50+
actual.embedding.dense_embedding) and
51+
all(isinstance(x, float) for x in actual.embedding.dense_embedding))
52+
53+
54+
@unittest.skipIf(
55+
not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
56+
class VertexAITextEmbeddingsTest(unittest.TestCase):
57+
def setUp(self):
58+
self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_')
59+
self.test_chunks = [
60+
Chunk(
61+
content=Content(text="This is a test sentence."),
62+
id="1",
63+
metadata={
64+
"source": "test.txt", "language": "en"
65+
}),
66+
Chunk(
67+
content=Content(text="Another example."),
68+
id="2",
69+
metadata={
70+
"source": "test.txt", "language": "en"
71+
})
72+
]
73+
74+
def tearDown(self) -> None:
75+
shutil.rmtree(self.artifact_location)
76+
77+
def test_embedding_pipeline(self):
78+
# gecko@002 produces 768-dimensional embeddings
79+
expected = [
80+
Chunk(
81+
id="1",
82+
embedding=Embedding(dense_embedding=[0.0] * 768),
83+
metadata={
84+
"source": "test.txt", "language": "en"
85+
},
86+
content=Content(text="This is a test sentence.")),
87+
Chunk(
88+
id="2",
89+
embedding=Embedding(dense_embedding=[0.0] * 768),
90+
metadata={
91+
"source": "test.txt", "language": "en"
92+
},
93+
content=Content(text="Another example."))
94+
]
95+
96+
embedder = VertexAITextEmbeddings(model_name="textembedding-gecko@002")
97+
98+
with TestPipeline() as p:
99+
embeddings = (
100+
p
101+
| beam.Create(self.test_chunks)
102+
| MLTransform(write_artifact_location=self.artifact_location).
103+
with_transform(embedder))
104+
105+
assert_that(
106+
embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
107+
108+
109+
if __name__ == '__main__':
110+
unittest.main()

0 commit comments

Comments
 (0)