Add LLama2 Chat Session example with a custom templator #938

Merged
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -6,6 +6,7 @@ public class ExampleRunner
private static readonly Dictionary<string, Func<Task>> Examples = new()
{
{ "Chat Session: LLama3", LLama3ChatSession.Run },
{ "Chat Session: LLama2", LLama2ChatSession.Run },
{ "Chat Session: History", ChatSessionWithHistory.Run },
{ "Chat Session: Role names", ChatSessionWithRoleName.Run },
{ "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
141 changes: 141 additions & 0 deletions LLama.Examples/Examples/LLama2ChatSession.cs
@@ -0,0 +1,141 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;
using System.Text;

namespace LLama.Examples.Examples;

/// <summary>
/// This sample shows a simple chatbot.
/// It's configured to use a custom prompt template (rather than the default template provided by llama.cpp) and supports
/// models such as LLama 2 and Mistral Instruct.
/// </summary>
public class LLama2ChatSession
{
public static async Task Run()
{
var modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

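// seed the session with the example conversation shipped in Assets/chat-with-bob.json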
var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
var chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);

// add the custom Llama 2 history transformer (templator)
session.WithHistoryTransform(new Llama2HistoryTransformer());

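// strip the model's end of turn token (or the "User:" fallback) and stray replacement characters from the streamed output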
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
[model.Tokens.EndOfTurnToken ?? "User:", "�"],
redundancyLength: 5));

var inferenceParams = new InferenceParams
{
SamplingPipeline = new DefaultSamplingPipeline
{
Temperature = 0.6f
},

MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
Console.Write("User> ");
var userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Assistant> ");

// as each token (partial or whole word) is streamed back, print it to the console, stream it to a web client, etc.
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
Console.WriteLine();

Console.ForegroundColor = ConsoleColor.Green;
Console.Write("User> ");
userInput = Console.ReadLine() ?? "";
}
}

/// <summary>
/// Chat History transformer for Llama 2 family.
/// https://huggingface.co/blog/llama2#how-to-prompt-llama-2
/// </summary>
public class Llama2HistoryTransformer : IHistoryTransform
{
public string Name => "Llama2";

/// <inheritdoc/>
public IHistoryTransform Clone()
{
return new Llama2HistoryTransformer();
}

/// <inheritdoc/>
public string HistoryToText(ChatHistory history)
{
// More info on the template format for Llama 2: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
// We don't have to insert the <BOS> token before the first message; LLamaSharp's InteractiveExecutor and llama.cpp do that automatically.
// See https://github.com/ggerganov/llama.cpp/pull/7107 for details.
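// For example, a history of [system, user, assistant, user] renders as:
//   [INST] <<SYS>>\n{system}\n<</SYS>>\n{user 1} [/INST] {assistant 1} </s><s>[INST] {user 2} [/INST]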
if (history.Messages.Count == 0)
return string.Empty;

var builder = new StringBuilder(64 * history.Messages.Count);

int i = 0;
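// per the Llama 2 template, the system prompt and the first user message are folded into a single [INST] block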
if (history.Messages[i].AuthorRole == AuthorRole.System)
{
builder.Append("[INST] <<SYS>>\n").Append(history.Messages[0].Content.Trim()).Append("\n<</SYS>>\n");
i++;

if (history.Messages.Count > 1)
{
builder.Append(history.Messages[1].Content.Trim()).Append(" [/INST]");
i++;
}
}

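// remaining turns: each user message opens an <s>[INST] ... [/INST] block; each assistant reply is closed with </s>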
for (; i < history.Messages.Count; i++)
{
if (history.Messages[i].AuthorRole == AuthorRole.User)
{
builder.Append(i == 0 ? "[INST] " : "<s>[INST] ").Append(history.Messages[i].Content.Trim()).Append(" [/INST]");
}
else
{
builder.Append(' ').Append(history.Messages[i].Content.Trim()).Append(" </s>");
}
}

return builder.ToString();
}

/// <inheritdoc/>
public ChatHistory TextToHistory(AuthorRole role, string text)
{
return new ChatHistory([new ChatHistory.Message(role, text)]);
}
}
}
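
For reference, a minimal sketch (not part of the diff; it assumes LLamaSharp's ChatHistory.AddMessage(AuthorRole, string) API) showing how the transformer above could be exercised on its own to preview the rendered prompt:

using LLama.Common;
using LLama.Examples.Examples;

// build a tiny history and print the Llama 2 formatted prompt
var history = new ChatHistory();
history.AddMessage(AuthorRole.System, "You are a helpful assistant."); // assumed helper; adjust if your ChatHistory API differs
history.AddMessage(AuthorRole.User, "Hello!");

var prompt = new LLama2ChatSession.Llama2HistoryTransformer().HistoryToText(history);
Console.WriteLine(prompt);
// expected output (newlines shown literally):
// [INST] <<SYS>>
// You are a helpful assistant.
// <</SYS>>
// Hello! [/INST]
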
6 changes: 3 additions & 3 deletions LLama.Examples/Examples/LLama3ChatSession.cs
@@ -7,7 +7,7 @@ namespace LLama.Examples.Examples;
/// <summary>
/// This sample shows a simple chatbot
/// It's configured to use the default prompt template as provided by llama.cpp and supports
-/// models such as llama3, llama2, phi3, qwen1.5, etc.
+/// models such as llama3, phi3, qwen1.5, etc.
/// </summary>
public class LLama3ChatSession
{
@@ -35,7 +35,7 @@ public static async Task Run()

// Add a transformer to eliminate printing the end of turn tokens, llama 3 specifically has an odd LF that gets printed sometimes
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
-    [model.Tokens.EndOfTurnToken!, "�"],
+    [model.Tokens.EndOfTurnToken ?? "User:", "�"],
redundancyLength: 5));

var inferenceParams = new InferenceParams
@@ -46,7 +46,7 @@ public static async Task Run()
},

MaxTokens = -1, // keep generating tokens until the anti prompt is encountered
-    AntiPrompts = [model.Tokens.EndOfTurnToken!] // model specific end of turn string
+    AntiPrompts = [model.Tokens.EndOfTurnToken ?? "User:"] // model specific end of turn string (or default)
};

Console.ForegroundColor = ConsoleColor.Yellow;
2 changes: 1 addition & 1 deletion LLama/ChatSession.cs
@@ -428,7 +428,7 @@
if (state.IsPromptRun)
{
// If the session history was added as part of new chat session history,
-// convert the complete history includsing system message and manually added history
+// convert the complete history including system message and manually added history
// to a prompt that adhere to the prompt template specified in the HistoryTransform class implementation.
prompt = HistoryTransform.HistoryToText(History);
}
@@ -779,7 +779,7 @@

return new SessionState(
contextState,
executorState,

Check warning on line 782 in LLama/ChatSession.cs (GitHub Actions / Test: linux-release, windows-release, osx-release):
Possible null reference argument for parameter 'executorState' in 'SessionState.SessionState(State? contextState, ExecutorBaseState executorState, ChatHistory history, List<ITextTransform> inputTransformPipeline, ITextStreamTransform outputTransform, IHistoryTransform historyTransform)'.
history,
inputTransforms.ToList(),
outputTransform,