Skip to content

Commit 2d8955d

Browse files
authored
Merge pull request #938 from asmirnov82/937_llama2_chat_session_example
Add LLama2 Chat Session example with a custom templator
2 parents ce8b05c + b3f420d commit 2d8955d

File tree

4 files changed

+146
-4
lines changed

4 files changed

+146
-4
lines changed

LLama.Examples/ExampleRunner.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ public class ExampleRunner
66
private static readonly Dictionary<string, Func<Task>> Examples = new()
77
{
88
{ "Chat Session: LLama3", LLama3ChatSession.Run },
9+
{ "Chat Session: LLama2", LLama2ChatSession.Run },
910
{ "Chat Session: History", ChatSessionWithHistory.Run },
1011
{ "Chat Session: Role names", ChatSessionWithRoleName.Run },
1112
{ "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;
using System.Text;

namespace LLama.Examples.Examples;

/// <summary>
/// This sample shows a simple chatbot.
/// It's configured to use a custom prompt template (as documented by llama.cpp) and supports
/// models such as LLama 2 and Mistral Instruct.
/// </summary>
public class LLama2ChatSession
{
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath)
        {
            Seed = 1337,
            GpuLayerCount = 10
        };

        using var model = LLamaWeights.LoadFromFile(parameters);
        using var context = model.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
        var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

        ChatSession session = new(executor, chatHistory);

        // add the custom Llama 2 prompt templater
        session.WithHistoryTransform(new Llama2HistoryTransformer());

        // strip the model-specific end-of-turn token (and the replacement char "�")
        // from the streamed output; fall back to "User:" when the model declares no token
        session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
            [model.Tokens.EndOfTurnToken ?? "User:", "�"],
            redundancyLength: 5));

        var inferenceParams = new InferenceParams
        {
            SamplingPipeline = new DefaultSamplingPipeline
            {
                Temperature = 0.6f
            },

            MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
            AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
        };

        Console.ForegroundColor = ConsoleColor.Yellow;
        Console.WriteLine("The chat session has started.");

        // show the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        Console.Write("User> ");
        var userInput = Console.ReadLine() ?? "";

        while (userInput != "exit")
        {
            Console.ForegroundColor = ConsoleColor.White;
            Console.Write("Assistant> ");

            // as each token (partial or whole word) is streamed back, print it to the console, stream to web client, etc
            await foreach (
                var text
                in session.ChatAsync(
                    new ChatHistory.Message(AuthorRole.User, userInput),
                    inferenceParams))
            {
                Console.ForegroundColor = ConsoleColor.White;
                Console.Write(text);
            }
            Console.WriteLine();

            Console.ForegroundColor = ConsoleColor.Green;
            Console.Write("User> ");
            userInput = Console.ReadLine() ?? "";
        }
    }

    /// <summary>
    /// Chat history transformer for the Llama 2 model family.
    /// Template reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
    /// </summary>
    public class Llama2HistoryTransformer : IHistoryTransform
    {
        public string Name => "Llama2";

        /// <inheritdoc/>
        public IHistoryTransform Clone()
        {
            return new Llama2HistoryTransformer();
        }

        /// <inheritdoc/>
        public string HistoryToText(ChatHistory history)
        {
            //More info on the Llama 2 template format: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
            //We don't have to insert the <BOS> token for the first message, as it's done automatically by LLamaSharp.InteractiveExecutor and LLama.cpp
            //See more in https://github.com/ggerganov/llama.cpp/pull/7107
            if (history.Messages.Count == 0)
                return string.Empty;

            var builder = new StringBuilder(64 * history.Messages.Count);

            int i = 0;

            // A leading system message is folded into the first [INST] block inside <<SYS>> markers,
            // together with the first user message (when present), per the Llama 2 template.
            if (history.Messages[i].AuthorRole == AuthorRole.System)
            {
                builder.Append("[INST] <<SYS>>\n").Append(history.Messages[i].Content.Trim()).Append("\n<</SYS>>\n");
                i++;

                if (i < history.Messages.Count)
                {
                    builder.Append(history.Messages[i].Content.Trim()).Append(" [/INST]");
                    i++;
                }
            }

            // Remaining turns alternate: user turns are wrapped in <s>[INST] ... [/INST]
            // (no <s> for the very first turn — see the <BOS> note above), assistant turns end with </s>.
            for (; i < history.Messages.Count; i++)
            {
                if (history.Messages[i].AuthorRole == AuthorRole.User)
                {
                    builder.Append(i == 0 ? "[INST] " : "<s>[INST] ").Append(history.Messages[i].Content.Trim()).Append(" [/INST]");
                }
                else
                {
                    builder.Append(' ').Append(history.Messages[i].Content.Trim()).Append(" </s>");
                }
            }

            return builder.ToString();
        }

        /// <inheritdoc/>
        public ChatHistory TextToHistory(AuthorRole role, string text)
        {
            // Wrap the generated text in a single-message history with the given role.
            return new ChatHistory([new ChatHistory.Message(role, text)]);
        }
    }
}

LLama.Examples/Examples/LLama3ChatSession.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace LLama.Examples.Examples;
77
/// <summary>
88
/// This sample shows a simple chatbot
99
/// It's configured to use the default prompt template as provided by llama.cpp and supports
10-
/// models such as llama3, llama2, phi3, qwen1.5, etc.
10+
/// models such as llama3, phi3, qwen1.5, etc.
1111
/// </summary>
1212
public class LLama3ChatSession
1313
{
@@ -35,7 +35,7 @@ public static async Task Run()
3535

3636
// Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
3737
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
38-
[model.Tokens.EndOfTurnToken!, "�"],
38+
[model.Tokens.EndOfTurnToken ?? "User:", "�"],
3939
redundancyLength: 5));
4040

4141
var inferenceParams = new InferenceParams
@@ -46,7 +46,7 @@ public static async Task Run()
4646
},
4747

4848
MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
49-
AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
49+
AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
5050
};
5151

5252
Console.ForegroundColor = ConsoleColor.Yellow;

LLama/ChatSession.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ public async IAsyncEnumerable<string> ChatAsync(
428428
if (state.IsPromptRun)
429429
{
430430
// If the session history was added as part of new chat session history,
431-
// convert the complete history includsing system message and manually added history
431+
// convert the complete history including system message and manually added history
432432
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
433433
prompt = HistoryTransform.HistoryToText(History);
434434
}

0 commit comments

Comments
 (0)