Skip to content

Commit 8e36fe9

Browse files
authored
Transformers Backend: max_tokens adherence to OpenAI API (#2108)
Max-tokens adherence to the OpenAI API: improve adherence to the OpenAI API when `max_tokens` is omitted or equal to 0 in the request.
1 parent 0d8bf91 commit 8e36fe9

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

backend/python/transformers/transformers_server.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def LoadModel(self, request, context):
159159
quantization_config=quantization,
160160
device_map=device_map,
161161
torch_dtype=compute)
162+
if request.ContextSize > 0:
163+
self.max_tokens = request.ContextSize
164+
else:
165+
self.max_tokens = self.model.config.max_position_embeddings
166+
162167
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True)
163168
self.XPU = False
164169

@@ -217,10 +222,6 @@ async def _predict(self, request, context, streaming=False):
217222
if request.TopK == 0:
218223
request.TopK = 40
219224

220-
max_tokens = 200
221-
if request.Tokens > 0:
222-
max_tokens = request.Tokens
223-
224225
prompt = request.Prompt
225226
if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
226227
prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
@@ -232,6 +233,12 @@ async def _predict(self, request, context, streaming=False):
232233
eos_token_id.append(self.tokenizer.convert_tokens_to_ids(word))
233234

234235
inputs = self.tokenizer(prompt, return_tensors="pt")
236+
237+
if request.Tokens > 0:
238+
max_tokens = request.Tokens
239+
else:
240+
max_tokens = self.max_tokens - inputs["input_ids"].size()[inputs["input_ids"].dim()-1]
241+
235242
if self.CUDA:
236243
inputs = inputs.to("cuda")
237244
if XPU and self.OV == False:

0 commit comments

Comments
 (0)