Skip to content

Commit 41eb7dc

Browse files
committed
Added 'strict' parameter for template retrieval that validates the retrieval of a valid template.
1 parent dcec12d commit 41eb7dc

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

LLama/LLamaTemplate.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,21 @@ public bool AddAssistant
105105
/// <summary>
106106
/// Construct a new template, using the default model template
107107
/// </summary>
108-
/// <param name="model"></param>
109-
/// <param name="name"></param>
110-
public LLamaTemplate(SafeLlamaModelHandle model, string? name = null)
111-
: this(model.GetTemplate(name))
108+
/// <param name="model">The native handle of the loaded model.</param>
109+
/// <param name="name">The name of the template, in case there are many or differently named. Set to 'null' for the default behaviour of finding an appropriate match.</param>
110+
/// <param name="strict">Setting this to true will cause the call to throw if no valid templates are found.</param>
111+
public LLamaTemplate(SafeLlamaModelHandle model, string? name = null, bool strict = true)
112+
: this(model.GetTemplate(name, strict))
112113
{
113114
}
114115

115116
/// <summary>
116117
/// Construct a new template, using the default model template
117118
/// </summary>
118-
/// <param name="weights"></param>
119-
public LLamaTemplate(LLamaWeights weights)
120-
: this(weights.NativeHandle)
119+
/// <param name="weights">The handle of the loaded model's weights.</param>
120+
/// <param name="strict">Setting this to true will cause the call to throw if no valid templates are found.</param>
121+
public LLamaTemplate(LLamaWeights weights, bool strict = true)
122+
: this(weights.NativeHandle, strict: strict)
121123
{
122124
}
123125

LLama/Native/SafeLlamaModelHandle.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -603,15 +603,24 @@ internal IReadOnlyDictionary<string, string> ReadMetadata()
603603
/// Get the default chat template. Returns nullptr if not available
604604
/// If name is NULL, returns the default chat template
605605
/// </summary>
606-
/// <param name="name"></param>
606+
/// <param name="name">The name of the template, in case there are many or differently named. Set to 'null' for the default behaviour of finding an appropriate match.</param>
607+
/// <param name="strict">Setting this to true will cause the call to throw if no valid templates are found.</param>
607608
/// <returns></returns>
608-
public string? GetTemplate(string? name = null)
609+
public string? GetTemplate(string? name = null, bool strict = true)
609610
{
610611
unsafe
611612
{
612613
var bytesPtr = llama_model_chat_template(this, name);
613614
if (bytesPtr == null)
614-
return null;
615+
{
616+
if (strict)
617+
throw new Exception($"Tried to retrieve template for '{name}' but no templates were found.\n" +
618+
$"This might mean that the model was exported incorrectly, or that this is a base model that contains no template.\n" +
619+
$"This exception can be disabled by passing 'strict=false' as a parameter when retrieving the template.");
620+
else
621+
return null;
622+
623+
}
615624

616625
// Find null terminator
617626
var spanBytes = new Span<byte>(bytesPtr, int.MaxValue);

0 commit comments

Comments
 (0)