support inference for pad_token & chatglm chat #157

Merged 4 commits on Feb 19, 2024

4 changes: 2 additions & 2 deletions lagent/llms/__init__.py
@@ -1,12 +1,12 @@
 from .base_api import BaseAPIModel
 from .base_llm import BaseModel
-from .huggingface import HFTransformer, HFTransformerCasualLM
+from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
 from .lmdepoly_wrapper import LMDeployClient, LMDeployPipeline, LMDeployServer
 from .meta_template import INTERNLM2_META
 from .openai import GPTAPI
 
 __all__ = [
     'BaseModel', 'BaseAPIModel', 'GPTAPI', 'LMDeployClient',
     'LMDeployPipeline', 'LMDeployServer', 'HFTransformer',
-    'HFTransformerCasualLM', 'INTERNLM2_META'
+    'HFTransformerCasualLM', 'INTERNLM2_META', 'HFTransformerChat'
 ]
3 changes: 2 additions & 1 deletion lagent/llms/base_api.py
@@ -118,7 +118,8 @@ def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
         return res
 
     def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
-        merged_prompt = self.roles[self.roles[role_prompt['role']]]
+        # merged_prompt = self.roles[self.roles[role_prompt['role']]]
Member review comment: delete

+        merged_prompt = self.roles[role_prompt['role']]
         if merged_prompt.get('fallback_role'):
             merged_prompt = self.roles[self.roles[
                 merged_prompt['fallback_role']]]
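The one-line fix above works because self.roles maps a role name to its template config, so a single lookup already returns the config dict; indexing self.roles a second time would pass that dict back in as a key. A minimal standalone sketch with a made-up roles mapping (not lagent's real meta template):

# Made-up mapping for illustration only.
roles = {
    'user': {'begin': '<|User|>:', 'end': '\n'},
    'assistant': {'begin': '<|Bot|>:', 'end': '\n'},
}
role_prompt = {'role': 'user', 'content': 'hi'}

merged_prompt = roles[role_prompt['role']]   # single lookup -> the 'user' config
# roles[roles[role_prompt['role']]] would use that dict as a key and raise
# TypeError: unhashable type: 'dict'.
print(merged_prompt['begin'])                # <|User|>: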
8 changes: 6 additions & 2 deletions lagent/llms/base_llm.py
@@ -121,7 +121,8 @@ def __init__(self,
                  top_k: float = None,
                  temperature: float = 0.8,
                  repetition_penalty: float = 1.0,
-                 stop_words: Union[List[str], str] = None):
+                 stop_words: Union[List[str], str] = None,
+                 stop_words_id: Union[List[int], int] = None):
         self.path = path
         self.tokenizer_only = tokenizer_only
         # meta template
@@ -132,13 +133,16 @@ def __init__(self,
 
         if isinstance(stop_words, str):
             stop_words = [stop_words]
+        if isinstance(stop_words_id, int):
+            stop_words_id = [stop_words_id]
         self.gen_params = dict(
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             top_k=top_k,
             temperature=temperature,
             repetition_penalty=repetition_penalty,
-            stop_words=stop_words)
+            stop_words=stop_words,
+            stop_words_id=stop_words_id)
Member review comment: move to huggingface


     def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
         """Generate results given a str (or list of) inputs.
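A hedged usage sketch of the new stop_words_id parameter. It assumes, as with the other generation options, that the HF wrappers forward extra keyword arguments down to BaseModel.__init__; the checkpoint path and token id below are placeholders rather than values taken from this PR:

from lagent.llms import HFTransformerCasualLM

# Placeholder checkpoint and stop-token id; use ones that match your model.
llm = HFTransformerCasualLM(
    path='THUDM/chatglm3-6b',
    stop_words_id=2,                        # a single int is normalized to a list
)
print(llm.gen_params['stop_words_id'])      # [2]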
59 changes: 56 additions & 3 deletions lagent/llms/huggingface.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Optional, Union
 
 from lagent.schema import ModelStatusCode
+from .base_api import APITemplateParser
 from .base_llm import BaseModel
 
 logger = logging.getLogger(__name__)
@@ -57,7 +58,9 @@ def __init__(self,
         self.prefix_allowed_tokens_fn = None
 
         stop_words_id = []
-        if self.gen_params.get('stop_words'):
+        if self.gen_params.get('stop_words_id'):
+            stop_words_id = self.gen_params.get('stop_words_id')
+        elif self.gen_params.get('stop_words'):
             for sw in self.gen_params.get('stop_words'):
                 stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
         self.additional_eos_token_id = stop_words_id
@@ -69,9 +72,28 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
             tokenizer_path if tokenizer_path else path,
             trust_remote_code=True,
             **tokenizer_kwargs)
+
         if self.tokenizer.pad_token_id is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+            if self.tokenizer.eos_token is not None:
+                logger.warning(
+                    f'Using eos_token_id {self.tokenizer.eos_token} '
+                    'as pad_token_id.')
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                from transformers.generation import GenerationConfig
+                self.gcfg = GenerationConfig.from_pretrained(path)
+
+                if self.gcfg.pad_token_id is not None:
+                    logger.warning(
+                        f'Using pad_token_id {self.gcfg.pad_token_id} '
+                        'as pad_token_id.')
+                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
+                else:
+                    raise ValueError(
+                        'pad_token_id is not set for this tokenizer. Try to '
+                        'set pad_token_id via passing '
+                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
 
     def _load_model(self, path: str, model_kwargs: dict):
         import torch
         from transformers import AutoModel
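The fallback order above (reuse eos_token when present, otherwise consult the checkpoint's GenerationConfig, otherwise fail loudly) can be shown in isolation. A sketch with a hypothetical helper name, not part of lagent:

from transformers import AutoTokenizer
from transformers.generation import GenerationConfig


def resolve_pad_token(path: str):
    """Hypothetical helper mirroring the pad_token fallback above."""
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    if tokenizer.pad_token_id is not None:
        return tokenizer                           # already usable for padded batches
    if tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token  # common HF workaround
        return tokenizer
    gcfg = GenerationConfig.from_pretrained(path)  # ChatGLM-style repos may set it here
    if gcfg.pad_token_id is not None:
        tokenizer.pad_token_id = gcfg.pad_token_id
        return tokenizer
    raise ValueError('pad_token_id is not set; pass one explicitly.')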
@@ -127,7 +149,6 @@ def stream_generate(
         if isinstance(inputs, str):
             inputs = [inputs]
             batched = False
-        # import pdb; pdb.set_trace()
         inputs = self.tokenizer(
             inputs, padding=True, return_tensors='pt', return_length=True)
         input_length = inputs['length']
@@ -148,6 +169,11 @@
             generation_config.bos_token_id,
             generation_config.eos_token_id,
         )
+        if eos_token_id is None:
+            if self.gcfg.eos_token_id is not None:
+                eos_token_id = self.gcfg.eos_token_id
+            else:
+                eos_token_id = []
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         if self.additional_eos_token_id is not None:
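The same GenerationConfig fallback, applied to eos_token_id as stream_generate now does. A standalone sketch with a hypothetical function name; generation_config stands for whatever config the generate call already built:

from transformers.generation import GenerationConfig


def resolve_eos_token_ids(generation_config, path, additional_ids=None):
    """Hypothetical helper mirroring the eos_token_id fallback above."""
    eos_token_id = generation_config.eos_token_id
    if eos_token_id is None:
        gcfg = GenerationConfig.from_pretrained(path)
        eos_token_id = gcfg.eos_token_id if gcfg.eos_token_id is not None else []
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_ids:
        eos_token_id = list(eos_token_id) + list(additional_ids)
    return eos_token_id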
@@ -267,3 +293,30 @@ def _load_model(self, path: str, model_kwargs: dict):
         self.model = AutoModelForCausalLM.from_pretrained(
             path, trust_remote_code=True, **model_kwargs)
         self.model.eval()
+
+
+class HFTransformerChat(HFTransformerCasualLM):
+    def __init__(self,
+                 template_parser=APITemplateParser,
+                 **kwargs):
+        super().__init__(template_parser=template_parser, **kwargs)
+
+    def chat(self, inputs: List[dict], do_sample: bool = True, **kwargs):
+        """Return the chat completion for the given input messages.
+
+        Args:
+            inputs (List[dict]): input messages to be completed.
+            do_sample (bool): do sampling if enabled.
+        Returns:
+            the text/chat completion
+        """
+        prompt = self.template_parser(inputs)
+        query = prompt[-1]['content']
+        history = prompt[:-1]
+        try:
+            response, history = self.model.chat(self.tokenizer,
+                                                query,
+                                                history=history)
+        except Exception:
+            response = ""
+        return response
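Finally, a hedged end-to-end sketch of the new wrapper. The checkpoint name and message are illustrative, and chat() relies on the loaded model exposing a ChatGLM-style .chat(tokenizer, query, history=...) method:

from lagent.llms import HFTransformerChat

# Illustrative checkpoint; any model whose remote code provides `.chat()` works.
llm = HFTransformerChat(path='THUDM/chatglm3-6b')

messages = [{'role': 'user', 'content': 'Briefly introduce the Eiffel Tower.'}]
print(llm.chat(messages))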