
Commit e478d58

Add mod chat (#1154)
* Add mod chat
* code cleanup
* code review
1 parent f457b16 commit e478d58

File tree

3 files changed, +151 -0 lines changed


chat/README.md

+8
@@ -0,0 +1,8 @@
## Chat CLI

## Usage
./run.sh --gpu_id <GPU_ID> --model_path <MODEL_PATH>

--gpu_id: Specifies the GPU ID to use. This maps to the CUDA_VISIBLE_DEVICES environment variable.

--model_path: The file path to the chat model. This must point to a valid model directory or file.
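
For example, assuming the quantized model lives at a local path such as /models/my-gptq-model (a hypothetical path), a typical invocation would be:

./run.sh --gpu_id 0 --model_path /models/my-gptq-model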

chat/chat.py

+113
@@ -0,0 +1,113 @@
import argparse
import json
from datetime import datetime
from gptqmodel import GPTQModel
from colorama import Fore, init
init(autoreset=True)


USER_PROMPT = "User >>> \n"
ASSISTANT_PROMPT = "Assistant >>> \n"

KEY_USER = 'user'
KEY_ASSISTANT = 'assistant'

ASSISTANT_HELLO = 'How can I help you?'
EXIT_MESSAGE = 'Exiting the program.'

# Running conversation in chat-template format; starts with the system prompt.
MESSAGES = [
    {"role": "system", "content": "You are a helpful and harmless assistant. You should think step-by-step."}
]

DEBUG = False


def load_model(model_path):
    print(Fore.BLUE + f"Loading model from `{model_path}` ...\n")
    model = GPTQModel.load(model_path)
    return model


def chat_prompt_progress(user_input, tokenizer):
    # Append the user turn and render the full conversation through the chat template.
    user_message = {"role": KEY_USER, "content": user_input}
    MESSAGES.append(user_message)
    input_tensor = tokenizer.apply_chat_template(MESSAGES, add_generation_prompt=True, return_tensors="pt")
    if DEBUG:
        debug(tokenizer)
    return input_tensor


def debug(tokenizer):
    print("********* DEBUG START *********")
    print("********* Chat Template info *********")
    print(tokenizer.apply_chat_template(MESSAGES, return_dict=False, tokenize=False, add_generation_prompt=True))
    print("********* DEBUG END *********")


def get_user_input():
    user_input = input(Fore.GREEN + USER_PROMPT)
    return user_input


def print_model_message(message):
    print(Fore.CYAN + f"{ASSISTANT_PROMPT}{message}\n")


def save_chat_history(chat_history, save_path):
    # Write the accumulated user/assistant turns to a timestamped JSON file.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f"chat_history_{timestamp}.json"
    if save_path is not None:
        filename = f"{save_path}/chat_history_{timestamp}.json"
    with open(filename, 'w') as file:
        json.dump(chat_history, file, indent=4, ensure_ascii=False)
    print(Fore.YELLOW + f"Chat history saved to '{filename}'.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chat with a GPT model.")
    parser.add_argument('--model_path', type=str, help="Path to the model.")
    parser.add_argument('--save_chat_path', type=str, help="Path to save the chat history.")
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Print Debug Info')
    args = parser.parse_args()
    if args.model_path is None:
        raise ValueError("Model path is None, please set `--model_path`")
    DEBUG = args.debug

    model = load_model(args.model_path)

    print(Fore.CYAN + "Welcome to GPTQModel Chat Assistant!\n")
    print(Fore.YELLOW + "You can enter questions or commands as follows:\n")
    print(Fore.YELLOW + "1. Type your question for the model.\n")
    print(Fore.YELLOW + "2. Type 'exit' to quit the program.\n")
    print(Fore.YELLOW + "3. Type 'save' to save the chat history.\n")

    tokenizer = model.tokenizer
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    chat_history = []  # chat history

    print_model_message(ASSISTANT_HELLO)

    while True:
        user_input = get_user_input()

        if user_input.lower() == 'exit':
            print(Fore.RED + f"{EXIT_MESSAGE}\n")
            break
        elif user_input.lower() == 'save':
            save_chat_history(chat_history, args.save_chat_path)
        else:
            input_tensor = chat_prompt_progress(user_input, tokenizer)
            outputs = model.generate(
                input_ids=input_tensor.to(model.device),
                max_new_tokens=4096,
                pad_token_id=tokenizer.pad_token_id
            )
            # Decode only the newly generated tokens (everything past the prompt).
            assistant_response = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)

            MESSAGES.append({"role": KEY_ASSISTANT, "content": assistant_response})
            chat_history.append({KEY_USER: user_input, KEY_ASSISTANT: assistant_response})

            print_model_message(assistant_response)
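
chat.py can also be launched directly, bypassing run.sh; a minimal sketch of such an invocation, assuming a hypothetical model directory and history folder, that exercises the three flags the argument parser defines:

python chat.py --model_path /models/my-gptq-model --save_chat_path ./chat_logs --debug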

chat/run.sh

+30
@@ -0,0 +1,30 @@
#!/bin/bash

GPU_ID=0
MODEL_PATH=""

while [[ $# -gt 0 ]]; do
    case $1 in
        --gpu_id)
            GPU_ID="$2"
            shift
            shift
            ;;
        --model_path)
            MODEL_PATH="$2"
            shift
            shift
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

if [[ -z "$MODEL_PATH" ]]; then
    echo "--model_path is REQUIRED!"
    exit 1
fi

# Pin CUDA device enumeration to PCI bus order so GPU_ID maps to a stable, predictable device.
env CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES="$GPU_ID" python chat.py --model_path "$MODEL_PATH"
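
Because CUDA_DEVICE_ORDER=PCI_BUS_ID is set, --gpu_id follows PCI bus ordering, the same ordering nvidia-smi reports. For example, to chat on the second GPU (model path hypothetical):

./run.sh --gpu_id 1 --model_path /models/my-gptq-model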
