Skip to content

Commit 90ef521

Browse files
Fix: gen_config in LMDeployPipeline is now updated from the input gen_params (#151)
1 parent 6a54476 commit 90ef521

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

lagent/llms/lmdepoly_wrapper.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -238,14 +238,19 @@ def generate(self,
238238
Returns:
239239
(a list of/batched) text/chat completion
240240
"""
241+
from lmdeploy.messages import GenerationConfig
242+
241243
batched = True
242244
if isinstance(inputs, str):
243245
inputs = [inputs]
244246
batched = False
245247
prompt = inputs
246248
gen_params = self.update_gen_params(**kwargs)
249+
max_tokens = gen_params.pop('max_tokens')
250+
gen_config = GenerationConfig(**gen_params)
251+
gen_config.max_new_tokens = max_tokens
247252
response = self.model.batch_infer(
248-
prompt, do_preprocess=do_preprocess, **gen_params)
253+
prompt, gen_config=gen_config, do_preprocess=do_preprocess)
249254
response = [resp.text for resp in response]
250255
# remove stop_words
251256
response = filter_suffix(response, self.gen_params.get('stop_words'))

0 commit comments

Comments (0)