Fixed Mirostate Sampling #72

Merged 2 commits on Aug 5, 2023

Changes from all commits
20 changes: 18 additions & 2 deletions LLama/LLamaExecutorBase.cs
@@ -66,6 +66,12 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
     /// The mode used by the executor.
     /// </summary>
     public LLamaModel Model => _model;
+
+    /// <summary>
+    /// Current "mu" value for mirostate sampling
+    /// </summary>
+    protected float MirostateMu { get; set; } = float.NaN;
+
     /// <summary>
     ///
     /// </summary>
@@ -78,8 +84,6 @@ protected StatefulExecutorBase(LLamaModel model, ILLamaLogger? logger = null)
     _pastTokensCount = 0;
     _consumedTokensCount = 0;
     _n_session_consumed = 0;
-    _embeds = new();
-    _embed_inps = new();
     _last_n_tokens = new FixedSizeQueue<llama_token>(_model.ContextSize).FillWith(0);
 }

@@ -359,24 +363,36 @@ public class ExecutorBaseState
 {
     [JsonPropertyName("n_past")]
     public int PastTokensCount { get; set; }
+
     [JsonPropertyName("n_consumed")]
     public int ConsumedTokensCount { get; set; }
+
     [JsonPropertyName("n_session_consumed")]
     public int ConsumedSessionCount { get; set; }
+
     [JsonPropertyName("n_matching_session_tokens")]
     public int MatchingSessionTokensCount { get; set; }
+
     [JsonPropertyName("path_session")]
     public string SessionFilePath { get; set; }
+
     [JsonPropertyName("embd")]
     public List<llama_token> Embeds { get; set; }
+
     [JsonPropertyName("embd_inps")]
     public List<llama_token> EmbedInps { get; set; }
+
     [JsonPropertyName("session_tokens")]
     public List<llama_token> SessionTokens { get; set; }
+
     [JsonPropertyName("last_n_tokens")]
     public llama_token[] LastTokens { get; set; }
+
     [JsonPropertyName("last_tokens_maximum_count")]
     public int LastTokensCapacity { get; set; }
+
+    [JsonPropertyName("mirostate_mu")]
+    public float MirostateMu { get; set; }
 }
 }
 }
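
Note on this file's changes: float.NaN serves as a "not initialized yet" sentinel for MirostateMu, and the new "mirostate_mu" field in ExecutorBaseState means a saved session resumes Mirostat sampling with its accumulated feedback state instead of starting over. A minimal sketch of the round-trip, assuming only the public GetStateData() method and the nested ExecutorBaseState type shown in this diff (the file path is a placeholder):

    // Minimal sketch, not library code: persist executor state, including the
    // sampler's mu, via System.Text.Json.
    using System.IO;
    using System.Text.Json;
    using LLama;

    static void SaveState(StatefulExecutorBase executor, string path)
    {
        var state = executor.GetStateData();  // now carries MirostateMu
        File.WriteAllText(path, JsonSerializer.Serialize(state));
    }

After deserializing the same JSON later, MirostateMu is float.NaN only if no token had been sampled before the save.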
13 changes: 9 additions & 4 deletions LLama/LLamaInstructExecutor.cs
@@ -4,7 +4,6 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;

@@ -20,6 +19,7 @@ public class InstructExecutor : StatefulExecutorBase
     string _instructionPrefix;
     llama_token[] _inp_pfx;
     llama_token[] _inp_sfx;
+
     /// <summary>
     ///
     /// </summary>
@@ -51,7 +51,8 @@ public override ExecutorBaseState GetStateData()
     PastTokensCount = _pastTokensCount,
     SessionFilePath = _pathSession,
     SessionTokens = _session_tokens,
-    LastTokensCapacity = _last_n_tokens.Capacity
+    LastTokensCapacity = _last_n_tokens.Capacity,
+    MirostateMu = MirostateMu
 };
 return state;
 }
@@ -214,8 +215,12 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat
 var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+var mu = MirostateMu;
+var id = _model.Sample(
+    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+);
+MirostateMu = mu;
 
 _last_n_tokens.Enqueue(id);
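Note: the three lines wrapped around Sample here (copy MirostateMu into a local, pass it by ref, write it back) are needed because C# cannot pass a property by ref. Stripped of the surrounding method, the pattern used by both this executor and InteractiveExecutor below is:

    // Copy-out / copy-in around the ref parameter; comments are editorial.
    var mu = MirostateMu;       // float.NaN until the first token is sampled
    var id = _model.Sample(
        tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
        inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
    );
    MirostateMu = mu;           // stored for the next InferInternal iteration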
16 changes: 9 additions & 7 deletions LLama/LLamaInteractExecutor.cs
@@ -4,12 +4,8 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using System.Threading;
-using System.Threading.Tasks;
 
 namespace LLama
 {
@@ -21,6 +17,7 @@ public class InteractiveExecutor : StatefulExecutorBase
 {
     bool _is_prompt_run = true;
     llama_token[] _llama_token_newline;
+
     /// <summary>
     ///
     /// </summary>
@@ -46,7 +43,8 @@ public override ExecutorBaseState GetStateData()
     PastTokensCount = _pastTokensCount,
     SessionFilePath = _pathSession,
     SessionTokens = _session_tokens,
-    LastTokensCapacity = _last_n_tokens.Capacity
+    LastTokensCapacity = _last_n_tokens.Capacity,
+    MirostateMu = MirostateMu
 };
 return state;
 }
@@ -204,8 +202,12 @@ protected override void InferInternal(InferenceParams inferenceParams, InferStat
 var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
+var mu = MirostateMu;
+var id = _model.Sample(
+    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP
+);
+MirostateMu = mu;
 
 _last_n_tokens.Enqueue(id);
14 changes: 8 additions & 6 deletions LLama/LLamaModel.cs
@@ -220,6 +220,7 @@ public void LoadState(State state)
 /// Perform the sampling. Please don't use it unless you fully know what it does.
 /// </summary>
 /// <param name="candidates"></param>
+/// <param name="mirostat_mu"></param>
 /// <param name="temperature"></param>
 /// <param name="mirostat"></param>
 /// <param name="mirostatTau"></param>
@@ -229,27 +230,28 @@ public void LoadState(State state)
 /// <param name="tfsZ"></param>
 /// <param name="typicalP"></param>
 /// <returns></returns>
-public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
-    float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
+public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MiroStatType mirostat = MiroStatType.Disable,
+    float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
 {
-    llama_token id = 0;
+    llama_token id;
     if (temperature <= 0)
     {
         // Greedy sampling
         id = SamplingApi.llama_sample_token_greedy(_ctx, candidates);
     }
     else
     {
-        if (mirostat == MirostatType.Mirostat)
+        if (float.IsNaN(mirostat_mu))
+            mirostat_mu = 2 * mirostatTau;
+
+        if (mirostat == MiroStatType.MiroStat)
         {
-            float mirostat_mu = 2.0f * mirostatTau;
             const int mirostat_m = 100;
             SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
             id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
         }
         else if (mirostat == MirostatType.Mirostat2)
         {
-            float mirostat_mu = 2.0f * mirostatTau;
             SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
             id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu);
         }
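This hunk is the actual bug fix: mirostat_mu used to be a method-local re-initialized to 2.0f * mirostatTau on every call, so the Mirostat feedback loop was reset after each token and never adapted. Now the caller owns mu, and Sample only seeds it with the paper's starting value of 2 * tau while it is still NaN. For intuition, the per-token update that llama_sample_token_mirostat / _v2 apply through the ref parameter is roughly the following sketch of the rule from the Mirostat paper (arXiv:2007.14966), not code from this repository:

    // Illustrative Mirostat feedback step (performed natively inside llama.cpp).
    // 'surprise' is -log2(p) of the token that was just sampled.
    static float UpdateMu(float mu, float surprise, float tau, float eta)
    {
        var error = surprise - tau;  // deviation from the target surprise tau
        return mu - eta * error;     // mu bounds candidate surprise on the next step
    }

Resetting mu every token threw this correction away immediately; carrying it across tokens, and across save/load via "mirostate_mu", is what lets the controller hold output surprise near tau.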
3 changes: 2 additions & 1 deletion LLama/LLamaStatelessExecutor.cs
@@ -57,6 +57,7 @@ public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams =
 lastTokens.AddRange(tokens);
 n_past += n_prompt_tokens;
 
+var mu = float.NaN;
 int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
 for(int i = 0; i < max_tokens; i++)
 {
@@ -70,7 +71,7 @@ public IEnumerable<string> Infer(string text, InferenceParams? inferenceParams =
 var tokenDataArray = _model.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
-var id = _model.Sample(tokenDataArray, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+var id = _model.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
     inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP);
 
 lastTokens.Add(id);
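Here mu is deliberately a local: the stateless executor adapts Mirostat across the tokens of one Infer call and starts fresh on the next. A usage sketch, assuming an already-constructed instance `executor` of the executor in this file; the property names (Mirostat, MirostatTau, MirostatEta) are the InferenceParams fields referenced throughout the diff, and the enum spelling follows the old context lines:

    // Sketch: enable Mirostat for a single stateless generation.
    var inferParams = new InferenceParams
    {
        Mirostat = MirostatType.Mirostat,  // Mirostat v1; Mirostat2 selects v2
        MirostatTau = 5.0f,                // target surprise (paper default)
        MirostatEta = 0.1f,                // learning rate of the mu feedback
    };
    foreach (var chunk in executor.Infer("Write a short poem about llamas.", inferParams))
        Console.Write(chunk);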