Skip to content

Commit d9d9ce9

Browse files
fix embedding function with additional methods
1 parent d07f5f7 commit d9d9ce9

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,22 @@
11
import os
22
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
3+
from chromadb.api.types import Documents, EmbeddingFunction
4+
from typing import Dict, Any, Optional
35

46
class BasetenEmbeddingFunction(OpenAIEmbeddingFunction):
57

68
def __init__(
79
self,
8-
api_key: str,
10+
api_key: Optional[str],
911
api_base: str,
1012
api_key_env_var: str = "CHROMA_BASETEN_API_KEY",
1113
):
1214
"""
1315
Initialize the BasetenEmbeddingFunction.
1416
Args:
15-
api_key (str, required): The API key for your Baseten account
17+
api_key (str, optional): The API key for your Baseten account
1618
api_base (str, required): The Baseten URL of the deployment
19+
api_key_env_var (str, optional): The environment variable to use for the API key. Defaults to "CHROMA_BASETEN_API_KEY".
1720
"""
1821
try:
1922
import openai
@@ -23,9 +26,13 @@ def __init__(
2326
)
2427

2528
self.api_key_env_var = api_key_env_var
26-
self.api_key = api_key or os.getenv(api_key_env_var)
27-
if not self.api_key:
28-
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
29+
# Prioritize api_key argument, then environment variable
30+
resolved_api_key = api_key or os.getenv(api_key_env_var)
31+
if not resolved_api_key:
32+
raise ValueError(f"API key not provided and {api_key_env_var} environment variable is not set.")
33+
self.api_key = resolved_api_key
34+
if not api_base:
35+
raise ValueError("The api_base argument must be provided.")
2936
self.api_base = api_base
3037
self.model_name = "baseten-embedding-model"
3138
self.dimensions = None
@@ -34,3 +41,43 @@ def __init__(
3441
api_key=self.api_key,
3542
base_url=self.api_base
3643
)
44+
45+
@staticmethod
46+
def name() -> str:
47+
return "baseten"
48+
49+
def get_config(self) -> Dict[str, Any]:
    """
    Return the settings that describe this embedding function.

    Only the deployment URL and the name of the API-key environment
    variable are included; the value of ``self.api_key`` is not part of
    the returned dictionary.

    Returns:
        Dict[str, Any]: A dictionary with the keys ``"api_base"`` and
        ``"api_key_env_var"``.
    """
    config: Dict[str, Any] = {}
    config["api_base"] = self.api_base
    config["api_key_env_var"] = self.api_key_env_var
    return config
54+
55+
56+
@staticmethod
57+
def build_from_config(config: Dict[str, Any]) -> "BasetenEmbeddingFunction":
58+
"""
59+
Build the BasetenEmbeddingFunction from a configuration dictionary.
60+
61+
Args:
62+
config (Dict[str, Any]): A dictionary containing the configuration parameters.
63+
Expected keys: 'api_key', 'api_base', 'api_key_env_var'.
64+
65+
Returns:
66+
BasetenEmbeddingFunction: An instance of BasetenEmbeddingFunction.
67+
"""
68+
api_key_env_var = config.get("api_key_env_var")
69+
api_base = config.get("api_base")
70+
if api_key_env_var is None or api_base is None:
71+
raise ValueError("Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction.")
72+
73+
# Note: We rely on the __init__ method to handle potential missing api_key
74+
# by checking the environment variable if the config value is None.
75+
# However, api_base must be present either in config or have a default.
76+
if api_base is None:
77+
raise ValueError("Missing 'api_base' in configuration for BasetenEmbeddingFunction.")
78+
79+
return BasetenEmbeddingFunction(
80+
api_key=None, # Pass None if not in config, __init__ will check env var
81+
api_base=api_base,
82+
api_key_env_var=api_key_env_var,
83+
)

0 commit comments

Comments
 (0)