Commit da2c674

openai spec updates (#156)
* adding support for max_completion_tokens
* adding support for max_completion_token as class param
* adding format changes
* formatting changes for lint
* formatting changes for lint
* formatting changes for lint
* removing whitespace
1 parent df881ef

File tree: 2 files changed (+47 −3 lines)

libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py

Lines changed: 21 additions & 3 deletions
```diff
@@ -264,7 +264,9 @@ class ChatNVIDIA(BaseChatModel):
         None, description="Sampling temperature in [0, 1]"
     )
     max_tokens: Optional[int] = Field(
-        1024, description="Maximum # of tokens to generate"
+        1024,
+        description="Maximum # of tokens to generate",
+        alias="max_completion_tokens",
    )
    top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
    seed: Optional[int] = Field(None, description="The seed for deterministic results")
```
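The alias means callers can pass either spelling for the same field. A minimal, self-contained sketch of the mechanism — `Example` is a stand-in model, not the real `ChatNVIDIA`, and it assumes Pydantic v2 with a `populate_by_name=True` config, which the old `max_tokens` spelling needs in order to keep working:

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Example(BaseModel):
    # populate_by_name lets the field be set by its own name ("max_tokens")
    # as well as by its alias ("max_completion_tokens").
    model_config = ConfigDict(populate_by_name=True)

    max_tokens: Optional[int] = Field(
        1024,
        description="Maximum # of tokens to generate",
        alias="max_completion_tokens",
    )


assert Example().max_tokens == 1024                        # default
assert Example(max_completion_tokens=50).max_tokens == 50  # via alias
assert Example(max_tokens=50).max_tokens == 50             # via field name
```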
```diff
@@ -287,6 +289,8 @@ def __init__(self, **kwargs: Any):
                         Format for base URL is http://host:port
         temperature (float): Sampling temperature in [0, 1].
         max_tokens (int): Maximum number of tokens to generate.
+            Deprecated, use max_completion_tokens instead
+        max_completion_tokens (int): Maximum number of tokens to generate.
         top_p (float): Top-p for distribution sampling.
         seed (int): A seed for deterministic results.
         stop (list[str]): A list of cased stop words.
```
```diff
@@ -303,6 +307,16 @@ def __init__(self, **kwargs: Any):
                 model="meta-llama3-8b-instruct"
             )
         """
+        # Show deprecation warning if max_tokens was used
+        if "max_tokens" in kwargs:
+            warnings.warn(
+                "The 'max_tokens' parameter is deprecated and will be removed "
+                "in a future version. "
+                "Please use 'max_completion_tokens' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         super().__init__(**kwargs)
         # allow nvidia_base_url as an alternative for base_url
         base_url = kwargs.pop("nvidia_base_url", self.base_url)
```
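Because the check runs before `super().__init__`, the warning fires while the legacy keyword still works. A hedged, standalone illustration of the same pattern — the `OldAPI` class here is hypothetical, not the library's code:

```python
import warnings


class OldAPI:
    def __init__(self, **kwargs):
        # Warn, but still honor the legacy keyword for now.
        if "max_tokens" in kwargs:
            warnings.warn(
                "The 'max_tokens' parameter is deprecated and will be removed "
                "in a future version. "
                "Please use 'max_completion_tokens' instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
        self.max_tokens = kwargs.get(
            "max_completion_tokens", kwargs.get("max_tokens", 1024)
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldAPI(max_tokens=50)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```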
```diff
@@ -359,7 +373,11 @@ def _get_ls_params(
             ls_model_name=self.model or "UNKNOWN",
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
-            ls_max_tokens=params.get("max_tokens", self.max_tokens),
+            # TODO: remove max_tokens once all models support max_completion_tokens
+            ls_max_tokens=(
+                params.get("max_completion_tokens", self.max_tokens)
+                or params.get("max_tokens", self.max_tokens)
+            ),
             # mypy error: Extra keys ("ls_top_p", "ls_seed")
             # for TypedDict "LangSmithParams" [typeddict-item]
             # ls_top_p=params.get("top_p", self.top_p),
```
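The `or` chain prefers `max_completion_tokens` and only consults `max_tokens` when the first lookup resolves to a falsy value (`None` or `0`). A small sketch of the same resolution logic, extracted into a hypothetical helper purely for illustration:

```python
from typing import Optional


def resolve_max_tokens(params: dict, default: Optional[int]) -> Optional[int]:
    # Prefer the new key; fall back to the old one. The fallback only kicks
    # in when the first lookup is falsy (None or 0), mirroring the `or`
    # chain in _get_ls_params above.
    return params.get("max_completion_tokens", default) or params.get(
        "max_tokens", default
    )


assert resolve_max_tokens({"max_completion_tokens": 50}, 1024) == 50
assert resolve_max_tokens({"max_tokens": 50}, None) == 50
assert resolve_max_tokens({}, 1024) == 1024
```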
````diff
@@ -765,7 +783,7 @@ class Choices(enum.Enum):
     For Pydantic schema and Enum, the output will be None if the response is
     insufficient to construct the object or otherwise invalid. For instance,
     ```
-    llm = ChatNVIDIA(max_tokens=1)
+    llm = ChatNVIDIA(max_completion_tokens=1)
     structured_llm = llm.with_structured_output(Joke)
     print(structured_llm.invoke("Tell me a joke about NVIDIA"))
````

libs/ai-endpoints/tests/unit_tests/test_chat_models.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 """Test chat model integration."""
 
+import warnings
 
 import pytest
 from requests_mock import Mocker
@@ -45,3 +46,28 @@ def test_integration_initialization() -> None:
 def test_unavailable(empty_v1_models: None) -> None:
     with pytest.warns(UserWarning, match="Model not-a-real-model is unknown"):
         ChatNVIDIA(api_key="BOGUS", model="not-a-real-model")
+
+
+def test_max_tokens_deprecation_warning() -> None:
+    """Test that using max_tokens raises a deprecation warning."""
+    with pytest.warns(
+        DeprecationWarning,
+        match=(
+            "The 'max_tokens' parameter is deprecated and will be removed "
+            "in a future version"
+        ),
+    ):
+        ChatNVIDIA(model="meta/llama2-70b", max_tokens=50)
+
+
+def test_max_completion_tokens() -> None:
+    """Test that max_completion_tokens works without warning."""
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        llm = ChatNVIDIA(
+            model="meta/llama2-70b",
+            max_completion_tokens=50,
+            nvidia_api_key="nvapi-...",
+        )
+        assert len(w) == 0
+        assert llm.max_tokens == 50
```
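A note on the `pytest.warns(..., match=...)` pattern these tests rely on: `match` is treated as a regular expression and checked with `re.search`, so matching only the first part of the warning message is sufficient. A self-contained sketch — the `emit` function is hypothetical, not from the test suite:

```python
import warnings

import pytest


def emit() -> None:
    warnings.warn(
        "The 'max_tokens' parameter is deprecated and will be removed "
        "in a future version. Please use 'max_completion_tokens' instead.",
        DeprecationWarning,
        stacklevel=2,
    )


def test_match_is_a_regex_search() -> None:
    # re.search semantics: a substring of the message is enough to match.
    with pytest.warns(DeprecationWarning, match="parameter is deprecated"):
        emit()
```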
