Fixed Spelling Mirostate -> Mirostat #69

Merged
merged 1 commit into from Aug 5, 2023
24 changes: 19 additions & 5 deletions LLama/Common/InferenceParams.cs
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Common
{
@@ -83,7 +82,7 @@ public class InferenceParams
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>
public MiroStateType Mirostat { get; set; } = MiroStateType.Disable;
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>
@@ -98,10 +97,25 @@ public class InferenceParams
public bool PenalizeNL { get; set; } = true;
}

public enum MiroStateType
/// <summary>
/// Type of "mirostat" sampling to use.
/// https://github.com/basusourya/mirostat
/// </summary>
public enum MirostatType
{
/// <summary>
/// Disable Mirostat sampling
/// </summary>
Disable = 0,
MiroState = 1,
MiroState2 = 2

/// <summary>
/// Original mirostat algorithm
/// </summary>
Mirostat = 1,

/// <summary>
/// Mirostat 2.0 algorithm
/// </summary>
Mirostat2 = 2
}
}
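
For context, a minimal, hypothetical call-site sketch showing the corrected enum name in use. It assumes the usual object-initializer style and uses only members that appear in this diff (InferenceParams.Mirostat and MirostatType.Mirostat2); everything else is illustrative.

using LLama.Common;

// Hypothetical usage sketch: request Mirostat 2.0 sampling instead of the
// default top-k / top-p pipeline. MirostatType.Disable (the default) leaves
// Mirostat sampling off entirely.
var inferenceParams = new InferenceParams
{
    Mirostat = MirostatType.Mirostat2
};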
6 changes: 3 additions & 3 deletions LLama/LLamaModel.cs
@@ -229,7 +229,7 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MiroStateType mirostat = MiroStateType.Disable,
public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f)
{
llama_token id = 0;
@@ -240,14 +240,14 @@ public llama_token Sample(LLamaTokenDataArray candidates, float temperature = 0.
}
else
{
if (mirostat == MiroStateType.MiroState)
if (mirostat == MirostatType.Mirostat)
{
float mirostat_mu = 2.0f * mirostatTau;
const int mirostat_m = 100;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu);
}
else if (mirostat == MiroStateType.MiroState2)
else if (mirostat == MirostatType.Mirostat2)
{
float mirostat_mu = 2.0f * mirostatTau;
SamplingApi.llama_sample_temperature(_ctx, candidates, temperature);
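
Likewise, a hedged sketch of calling Sample with the renamed type. The parameter names come from the signature above; model and candidates are assumptions standing in for an initialized LLamaModel and a populated LLamaTokenDataArray of next-token candidates.

// Hypothetical call site: sample the next token with Mirostat 2.0.
// mirostatTau (target entropy) and mirostatEta (learning rate) mirror
// the defaults in the signature above.
llama_token next = model.Sample(
    candidates,
    temperature: 0.8f,
    mirostat: MirostatType.Mirostat2,
    mirostatTau: 5.0f,
    mirostatEta: 0.1f);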