1
1
import os
2
2
from chromadb .utils .embedding_functions .openai_embedding_function import OpenAIEmbeddingFunction
3
+ from chromadb .api .types import Documents , EmbeddingFunction
4
+ from typing import Dict , Any , Optional
3
5
4
6
class BasetenEmbeddingFunction (OpenAIEmbeddingFunction ):
5
7
6
8
def __init__ (
7
9
self ,
8
- api_key : str ,
10
+ api_key : Optional [ str ] ,
9
11
api_base : str ,
10
12
api_key_env_var : str = "CHROMA_BASETEN_API_KEY" ,
11
13
):
12
14
"""
13
15
Initialize the BasetenEmbeddingFunction.
14
16
Args:
15
- api_key (str, required ): The API key for your Baseten account
17
+ api_key (str, optional ): The API key for your Baseten account
16
18
api_base (str, required): The Baseten URL of the deployment
19
+ api_key_env_var (str, optional): The environment variable to use for the API key. Defaults to "CHROMA_BASETEN_API_KEY".
17
20
"""
18
21
try :
19
22
import openai
@@ -23,9 +26,13 @@ def __init__(
23
26
)
24
27
25
28
self .api_key_env_var = api_key_env_var
26
- self .api_key = api_key or os .getenv (api_key_env_var )
27
- if not self .api_key :
28
- raise ValueError (f"The { api_key_env_var } environment variable is not set." )
29
+ # Prioritize api_key argument, then environment variable
30
+ resolved_api_key = api_key or os .getenv (api_key_env_var )
31
+ if not resolved_api_key :
32
+ raise ValueError (f"API key not provided and { api_key_env_var } environment variable is not set." )
33
+ self .api_key = resolved_api_key
34
+ if not api_base :
35
+ raise ValueError ("The api_base argument must be provided." )
29
36
self .api_base = api_base
30
37
self .model_name = "baseten-embedding-model"
31
38
self .dimensions = None
@@ -34,3 +41,43 @@ def __init__(
34
41
api_key = self .api_key ,
35
42
base_url = self .api_base
36
43
)
44
+
45
+ @staticmethod
46
+ def name () -> str :
47
+ return "baseten"
48
+
49
+ def get_config (self ) -> Dict [str , Any ]:
50
+ return {
51
+ "api_base" : self .api_base ,
52
+ "api_key_env_var" : self .api_key_env_var
53
+ }
54
+
55
+
56
+ @staticmethod
57
+ def build_from_config (config : Dict [str , Any ]) -> "BasetenEmbeddingFunction" :
58
+ """
59
+ Build the BasetenEmbeddingFunction from a configuration dictionary.
60
+
61
+ Args:
62
+ config (Dict[str, Any]): A dictionary containing the configuration parameters.
63
+ Expected keys: 'api_key', 'api_base', 'api_key_env_var'.
64
+
65
+ Returns:
66
+ BasetenEmbeddingFunction: An instance of BasetenEmbeddingFunction.
67
+ """
68
+ api_key_env_var = config .get ("api_key_env_var" )
69
+ api_base = config .get ("api_base" )
70
+ if api_key_env_var is None or api_base is None :
71
+ raise ValueError ("Missing 'api_key_env_var' or 'api_base' in configuration for BasetenEmbeddingFunction." )
72
+
73
+ # Note: We rely on the __init__ method to handle potential missing api_key
74
+ # by checking the environment variable if the config value is None.
75
+ # However, api_base must be present either in config or have a default.
76
+ if api_base is None :
77
+ raise ValueError ("Missing 'api_base' in configuration for BasetenEmbeddingFunction." )
78
+
79
+ return BasetenEmbeddingFunction (
80
+ api_key = None , # Pass None if not in config, __init__ will check env var
81
+ api_base = api_base ,
82
+ api_key_env_var = api_key_env_var ,
83
+ )
0 commit comments