Skip to content

Commit b08ba99

Browse files
authored
Merge pull request #1150 from nipeone/feature-llamareranker
add LLamaReranker and tests
2 parents 01d5c36 + d838e1c commit b08ba99

File tree

6 files changed

+344
-6
lines changed

6 files changed

+344
-6
lines changed

LLama.Unittest/Constants.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ internal static class Constants
77
public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf";
88
public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf";
99
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
10+
public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";
1011

1112
public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
1213
public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";

LLama.Unittest/LLama.Unittest.csproj

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
<LocalFileName>smollm-360m-instruct-add-basics-q8_0.gguf</LocalFileName>
4747
</DownloadFileItem>
4848

49+
<DownloadFileItem Include="jina-reranker-v1-tiny-en-FP16.gguf">
50+
<SourceUrl>https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf</SourceUrl>
51+
<DestinationFolder>Models</DestinationFolder>
52+
<LocalFileName>jina-reranker-v1-tiny-en-FP16.gguf</LocalFileName>
53+
</DownloadFileItem>
54+
4955
<DownloadFileItem Include="llava-v1.6-mistral-7b">
5056
<SourceUrl>https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf</SourceUrl>
5157
<DestinationFolder>Models</DestinationFolder>
@@ -130,6 +136,9 @@
130136
<None Update="Models\Llama-3.2-1B-Instruct-Q4_0.gguf">
131137
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
132138
</None>
139+
<None Update="Models\jina-reranker-v1-tiny-en-FP16.gguf">
140+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
141+
</None>
133142
<None Update="Models\smollm-360m-instruct-add-basics-q8_0.gguf">
134143
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
135144
</None>

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using LLama.Common;
2+
using LLama.Extensions;
3+
using LLama.Native;
4+
using Microsoft.Extensions.AI;
5+
using System.Runtime.InteropServices;
6+
using Xunit.Abstractions;
7+
8+
namespace LLama.Unittest;
9+
10+
/// <summary>
/// Tests for <see cref="LLamaReranker"/>: scoring a query against candidate documents.
/// </summary>
public sealed class LLamaRerankerTests : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly LLamaReranker _reranker;

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 means "use the model's own training context size"
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        // NOTE(review): weights are disposed as soon as the constructor returns, while the
        // reranker's context lives on — assumes the context holds its own model handle reference.
        // TODO confirm against LLamaWeights.CreateContext semantics.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(weights, @params);
    }

    public void Dispose()
    {
        _reranker.Dispose();
    }

    [Fact]
    public async Task CompareRerankingScore()
    {
        // Arrange: one query and three candidate documents of increasing relevance
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ", "The giant panda (Ailuropoda melanoleuca)",
                "sometimes called a panda bear or simply panda",
                "is a bear species endemic to China.")
        };

        // Act: raw (un-normalized) relevance scores
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: false);

        // Assert.Equal gives a useful failure message (expected vs actual), unlike Assert.True(==)
        Assert.Equal(documents.Length, scores.Count);

        _testOutputHelper.WriteLine($"Rerank score 0: {scores[0]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 1: {scores[1]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 2: {scores[2]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ", "The giant panda (Ailuropoda melanoleuca)",
                "sometimes called a panda bear or simply panda",
                "is a bear species endemic to China.")
        };

        // Normalized scores lie in (0, 1) via sigmoid
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: true);

        Assert.NotNull(scores);
        Assert.Equal(documents.Length, scores.Count);

        // Index of the highest-scoring document
        int maxIndex = scores.Select((score, index) => (score, index))
                             .MaxBy(x => x.score)
                             .index;

        // The detailed panda description should out-rank the other two documents
        var maxScoreDocument = documents[maxIndex];
        Assert.Equal(documents[2], maxScoreDocument);
    }
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System.Text;
2+
using System.Xml.Linq;
3+
using LLama.Common;
4+
using LLama.Extensions;
5+
using Microsoft.Extensions.Logging;
6+
7+
8+
namespace LLama.Unittest.Native;
9+
10+
/// <summary>
/// Tests for the vocabulary accessors on the model handle (token-to-string translation).
/// Sealed for consistency with the other test classes in this project.
/// </summary>
public sealed class SafeLlamaModelHandleVocabularyTests : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 means "use the model's own training context size"
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void GetLLamaTokenString()
    {
        // Arrange: begin/end-of-sequence tokens from the model's vocabulary
        var bos = _model.Vocab.BOS;
        var eos = _model.Vocab.EOS;

        // Act: translate tokens to their textual form, allowing special tokens
        var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
        var eosStr = _model.Vocab.LLamaTokenToString(eos, true);

        // Assert: expected special-token strings for this model's vocabulary
        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}

LLama/LLamaReranker.cs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using System.Xml.Linq;
9+
using LLama.Abstractions;
10+
using LLama.Exceptions;
11+
using LLama.Native;
12+
using Microsoft.Extensions.Logging;
13+
14+
namespace LLama;
15+
16+
/// <summary>
/// Get rank scores between a prompt and a set of documents.
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create the context from</param>
    /// <param name="params">Context parameters. Must use <see cref="LLamaPoolingType.Rank"/> and equal batch/ubatch sizes.</param>
    /// <param name="logger">Optional logger</param>
    /// <exception cref="ArgumentException">Batch size does not equal ubatch size</exception>
    /// <exception cref="NotSupportedException">Encoder-decoder model, or pooling type is not Rank</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        // Non-causal models require the whole batch to be processed in a single ubatch
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);

        // Rank scores are read back through the embeddings API, so embeddings must be enabled
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking. Documents are packed
    /// into as few model executions as possible (one sequence per document per batch).
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="documents">Candidate documents; null entries are treated as empty strings</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1) with a sigmoid</param>
    /// <param name="cancellationToken"></param>
    /// <returns>One score per document, in the same order as <paramref name="documents"/></returns>
    /// <exception cref="RuntimeError">The native encode/decode call failed</exception>
    /// <exception cref="NotSupportedException">Unsupported model architecture</exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        List<float> scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();

        // Index of the first document in the current batch; used to rebase sequence ids to 0
        var firstDocInBatch = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the batch before it would overflow the context window.
            // NOTE(review): a single input+document pair longer than the context size is
            // still added unconditionally — TODO confirm intended behaviour for that case.
            if (batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                firstDocInBatch = idx;
            }

            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - firstDocInBatch), true);
        }

        // Flush any remaining documents
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for a single input/document pair by reranking.
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="document">The candidate document</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1) with a sigmoid</param>
    /// <param name="cancellationToken"></param>
    /// <returns>The relevance score and the total number of tokens submitted to the model</returns>
    /// <exception cref="RuntimeError">The native encode/decode call failed</exception>
    /// <exception cref="NotSupportedException">Unsupported model architecture</exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [.. inputTokens, .. docTokens];
        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        await ExecuteBatchAsync(batch, cancellationToken);

        // With Rank pooling the per-sequence "embedding" holds the score in its first element
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Run the model over a packed batch and read back one score per sequence.
    /// </summary>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // The sequence id of the last logit position tells us how many sequences are packed
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqNum = lastSeqId.Value + 1;
        List<float> scores = new List<float>(seqNum);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        await ExecuteBatchAsync(batch, cancellationToken);

        for (var seq = 0; seq < seqNum; seq++)
        {
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Run the batch through the model, using encode or decode depending on the model
    /// architecture. Shared by the single-pair and batched scoring paths.
    /// </summary>
    /// <exception cref="RuntimeError">The native encode/decode call failed</exception>
    /// <exception cref="NotSupportedException">Unsupported model architecture</exception>
    private async Task ExecuteBatchAsync(LLamaBatch batch, CancellationToken cancellationToken)
    {
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }
    }

    // Static: uses no instance state (CA1822). Maps a raw score into (0, 1).
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}

LLama/Native/SafeLlamaModelHandle.cs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model)
651651
_model = model;
652652
}
653653

654-
private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
654+
private static LLamaToken? Normalize(LLamaToken token)
655+
{
656+
return token == -1 ? null : token;
657+
}
658+
659+
/// <summary>
660+
/// Translate LLamaToken to String
661+
/// </summary>
662+
/// <param name="token"></param>
663+
/// <param name="isSpecialToken"></param>
664+
/// <returns></returns>
665+
public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
655666
{
656667
if (!token.HasValue)
657668
return null;
@@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
676687
return Encoding.UTF8.GetStringFromSpan(slice);
677688
}
678689

679-
private static LLamaToken? Normalize(LLamaToken token)
680-
{
681-
return token == -1 ? null : token;
682-
}
683-
684690
/// <summary>
685691
/// Total number of tokens in this vocabulary
686692
/// </summary>

0 commit comments

Comments
 (0)