@@ -264,7 +264,9 @@ class ChatNVIDIA(BaseChatModel):
         None, description="Sampling temperature in [0, 1]"
     )
     max_tokens: Optional[int] = Field(
-        1024, description="Maximum # of tokens to generate"
+        1024,
+        description="Maximum # of tokens to generate",
+        alias="max_completion_tokens",
     )
     top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
     seed: Optional[int] = Field(None, description="The seed for deterministic results")
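Note on the hunk above: the Pydantic `alias` is what lets callers pass `max_completion_tokens` while the attribute keeps its old name. A minimal sketch of that behavior, assuming Pydantic v2 with `populate_by_name` enabled (as LangChain's models configure it); `Demo` is a stand-in class, not part of this change:

```
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Demo(BaseModel):
    # populate_by_name allows both the field name and its alias at init time
    model_config = ConfigDict(populate_by_name=True)

    max_tokens: Optional[int] = Field(
        1024,
        description="Maximum # of tokens to generate",
        alias="max_completion_tokens",
    )


print(Demo(max_completion_tokens=256).max_tokens)  # 256, set via the alias
print(Demo(max_tokens=128).max_tokens)             # 128, set via the field name
```

With that config both spellings construct the model, so existing `max_tokens` callers keep working while the new name is accepted.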
@@ -287,6 +289,8 @@ def __init__(self, **kwargs: Any):
                 Format for base URL is http://host:port
             temperature (float): Sampling temperature in [0, 1].
             max_tokens (int): Maximum number of tokens to generate.
+                Deprecated, use max_completion_tokens instead.
+            max_completion_tokens (int): Maximum number of tokens to generate.
             top_p (float): Top-p for distribution sampling.
             seed (int): A seed for deterministic results.
             stop (list[str]): A list of cased stop words.
@@ -303,6 +307,16 @@ def __init__(self, **kwargs: Any):
                 model="meta-llama3-8b-instruct"
             )
         """
+        # Show deprecation warning if max_tokens was used
+        if "max_tokens" in kwargs:
+            warnings.warn(
+                "The 'max_tokens' parameter is deprecated and will be removed "
+                "in a future version. "
+                "Please use 'max_completion_tokens' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         super().__init__(**kwargs)
         # allow nvidia_base_url as an alternative for base_url
         base_url = kwargs.pop("nvidia_base_url", self.base_url)
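The warning added above fires once at construction, and `stacklevel=2` attributes it to the caller's line rather than to `__init__` itself. A self-contained sketch of the same pattern; the helper name here is illustrative, not part of this change:

```
import warnings


def _check_deprecated(**kwargs):
    # mirrors the __init__ check above: warn, but keep accepting max_tokens
    if "max_tokens" in kwargs:
        warnings.warn(
            "The 'max_tokens' parameter is deprecated and will be removed "
            "in a future version. Please use 'max_completion_tokens' instead.",
            DeprecationWarning,
            stacklevel=2,
        )


# the warning is reported at this call site; run with `python -W default`
# if your environment filters DeprecationWarning
_check_deprecated(max_tokens=512)
```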
@@ -359,7 +373,11 @@ def _get_ls_params(
             ls_model_name=self.model or "UNKNOWN",
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
-            ls_max_tokens=params.get("max_tokens", self.max_tokens),
+            # TODO: remove max_tokens once all models support max_completion_tokens
+            ls_max_tokens=(
+                params.get("max_completion_tokens", self.max_tokens)
+                or params.get("max_tokens", self.max_tokens)
+            ),
             # mypy error: Extra keys ("ls_top_p", "ls_seed")
             # for TypedDict "LangSmithParams" [typeddict-item]
             # ls_top_p=params.get("top_p", self.top_p),
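One subtlety in the `ls_max_tokens` fallback above: `dict.get` returns an explicitly stored `None`, so the `or` is what actually lets a missing or `None` `max_completion_tokens` fall through to `max_tokens`. A small sketch with invented `params` values:

```
self_max_tokens = 1024
params = {"max_completion_tokens": None, "max_tokens": 300}

# .get() returns the stored None here, so `or` falls through to max_tokens
ls_max_tokens = (
    params.get("max_completion_tokens", self_max_tokens)
    or params.get("max_tokens", self_max_tokens)
)
print(ls_max_tokens)  # 300

# caveat: any falsy value (e.g. 0) would also fall through, though a
# limit of 0 is not meaningful for token generation
```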
@@ -765,7 +783,7 @@ class Choices(enum.Enum):
         For Pydantic schema and Enum, the output will be None if the response is
         insufficient to construct the object or otherwise invalid. For instance,
         ```
-        llm = ChatNVIDIA(max_tokens=1)
+        llm = ChatNVIDIA(max_completion_tokens=1)
         structured_llm = llm.with_structured_output(Joke)
         print(structured_llm.invoke("Tell me a joke about NVIDIA"))
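A hedged usage note on the docstring example above, reusing `Joke` and `structured_llm` from it: since a single completion token cannot carry valid JSON, the call returns `None`, and callers should guard for that (the retry shown is illustrative):

```
result = structured_llm.invoke("Tell me a joke about NVIDIA")
if result is None:
    # response was too short to parse; retry with a larger limit
    llm = ChatNVIDIA(max_completion_tokens=512)
    result = llm.with_structured_output(Joke).invoke(
        "Tell me a joke about NVIDIA"
    )
```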