@@ -2,15 +2,19 @@ import Foundation
 import Logging
 @preconcurrency import llama
 
+/// An actor that handles inference using the LLama language model.
 public actor LLama {
   private let logger = Logger.llama
   private let model: Model
   private let sampling: UnsafeMutablePointer<llama_sampler>
   private var tokensList: [llama_token]
   private var temporaryInvalidCChars: [CChar]
 
-  // MARK: - Init & teardown
+  // MARK: - Init & Teardown
 
+  /// Initializes a new instance of `LLama` with the specified model.
+  ///
+  /// - Parameter model: The language model to use for inference.
   public init(model: Model) {
     self.model = model
@@ -26,52 +30,81 @@ public actor LLama {
     self.temporaryInvalidCChars = []
   }
 
-  deinit {
-    // llama_sampler_free(sampling)
-  }
-
   // MARK: - Inference
 
+  /// Generates an asynchronous stream of tokens as strings based on the given prompt.
+  ///
+  /// - Parameters:
+  ///   - prompt: The input text prompt to generate completions for.
+  ///   - maxTokens: The maximum number of tokens to generate. Defaults to 128.
+  ///
+  /// - Returns: An `AsyncThrowingStream` emitting generated tokens as strings.
   public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
     return AsyncThrowingStream { continuation in
       Task {
-        var isDone = false
-        let nLen: Int32 = 1024
-        var nCur: Int32 = 0
-        var nDecode: Int32 = 0
-        var batch = llama_batch_init(512, 0, 1)
-        defer {
-          llama_batch_free(batch)
-        }
-
         do {
-          try self.completionInit(text: prompt, batch: &batch, nLen: nLen, nCur: &nCur)
+          try await self.infer(prompt: prompt, maxTokens: maxTokens, continuation: continuation)
         } catch {
           continuation.finish(throwing: error)
-          return
         }
+      }
+    }
+  }
 
-        while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
-          guard !Task.isCancelled else {
-            continuation.finish()
-            return
-          }
-          let newTokenStr = self.completionLoop(
-            batch: &batch,
-            isDone: &isDone,
-            nLen: nLen,
-            nCur: &nCur,
-            nDecode: &nDecode
-          )
-          continuation.yield(newTokenStr)
-        }
+  /// Performs the inference loop and yields generated tokens to the continuation.
+  ///
+  /// - Parameters:
+  ///   - prompt: The input text prompt to generate completions for.
+  ///   - maxTokens: The maximum number of tokens to generate.
+  ///   - continuation: The stream continuation to yield tokens to.
+  private func infer(
+    prompt: String,
+    maxTokens: Int32,
+    continuation: AsyncThrowingStream<String, Error>.Continuation
+  ) async throws {
+    var isDone = false
+    let nLen: Int32 = 1024
+    var nCur: Int32 = 0
+    var nDecode: Int32 = 0
+    var batch = llama_batch_init(512, 0, 1)
+    defer {
+      llama_batch_free(batch)
+    }
+
+    do {
+      try self.completionInit(text: prompt, batch: &batch, nLen: nLen, nCur: &nCur)
+    } catch {
+      throw error
+    }
+
+    while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
+      guard !Task.isCancelled else {
         continuation.finish()
+        return
       }
+      let newTokenStr = self.completionLoop(
+        batch: &batch,
+        isDone: &isDone,
+        nLen: nLen,
+        nCur: &nCur,
+        nDecode: &nDecode
+      )
+      continuation.yield(newTokenStr)
     }
+    continuation.finish()
   }
 
-  // MARK: - Private helpers
-
+  // MARK: - Private Helpers
+
+  /// Initializes the completion process by tokenizing the input text and preparing the batch.
+  ///
+  /// - Parameters:
+  ///   - text: The input text to tokenize.
+  ///   - batch: The batch to initialize.
+  ///   - nLen: The maximum length of the sequence.
+  ///   - nCur: The current position in the sequence.
+  ///
+  /// - Throws: An `InferError` if the KV cache is too small or decoding fails.
   private func completionInit(
     text: String,
     batch: inout llama_batch,
@@ -109,6 +142,16 @@ public actor LLama {
     nCur = batch.n_tokens
   }
 
+  /// Performs a single iteration of the completion loop, generating the next token.
+  ///
+  /// - Parameters:
+  ///   - batch: The batch to use for decoding.
+  ///   - isDone: A flag indicating whether the generation is complete.
+  ///   - nLen: The maximum length of the sequence.
+  ///   - nCur: The current position in the sequence.
+  ///   - nDecode: The number of tokens decoded so far.
+  ///
+  /// - Returns: The newly generated token as a string.
   private func completionLoop(
     batch: inout llama_batch,
     isDone: inout Bool,
@@ -154,6 +197,14 @@ public actor LLama {
     return newTokenStr
   }
 
+  /// Adds a token to the batch.
+  ///
+  /// - Parameters:
+  ///   - batch: The batch to add the token to.
+  ///   - id: The token ID to add.
+  ///   - pos: The position of the token in the sequence.
+  ///   - seq_ids: The sequence IDs associated with the token.
+  ///   - logits: A flag indicating whether to compute logits for this token.
   private func llamaBatchAdd(
     _ batch: inout llama_batch,
     _ id: llama_token,
@@ -172,6 +223,13 @@ public actor LLama {
     batch.n_tokens += 1
   }
 
+  /// Tokenizes the given text using the model's tokenizer.
+  ///
+  /// - Parameters:
+  ///   - text: The text to tokenize.
+  ///   - add_bos: A flag indicating whether to add a beginning-of-sequence token.
+  ///
+  /// - Returns: An array of token IDs.
   private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Data = text.utf8CString
     let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
@@ -187,6 +245,11 @@ public actor LLama {
     return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
   }
 
+  /// Converts a token ID to an array of CChars representing the token piece.
+  ///
+  /// - Parameter token: The token ID to convert.
+  ///
+  /// - Returns: An array of CChars representing the token piece.
   private func tokenToPieceArray(token: llama_token) -> [CChar] {
     var buffer = [CChar](repeating: 0, count: 8)
     var nTokens = llama_token_to_piece(model.model, token, &buffer, 8, 0, false)
@@ -200,6 +263,11 @@ public actor LLama {
     return Array(buffer.prefix(Int(nTokens)))
   }
 
+  /// Attempts to create a partial string from an array of CChars if the full string is invalid.
+  ///
+  /// - Parameter cchars: The array of CChars to attempt to convert.
+  ///
+  /// - Returns: A valid string if possible; otherwise, `nil`.
   private func attemptPartialString(from cchars: [CChar]) -> String? {
     for i in (1..<cchars.count).reversed() {
       let subArray = Array(cchars.prefix(i))
@@ -212,12 +280,16 @@
   }
 }
 
 extension llama_batch {
+  /// Clears the batch by resetting the token count.
   fileprivate mutating func clear() {
     n_tokens = 0
   }
 }
 
 extension String {
+  /// Initializes a string from a sequence of CChars, validating UTF8 encoding.
+  ///
+  /// - Parameter validatingUTF8: The array of CChars to initialize the string from.
   fileprivate init?(validatingUTF8 cchars: [CChar]) {
     if #available(macOS 15.0, *) {
       self.init(validating: cchars.map { UInt8(bitPattern: $0) }, as: UTF8.self)
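For context, a minimal consumer of the streaming API introduced above might look like the sketch below. The `LLama` initializer and the `infer(prompt:maxTokens:)` signature come from this diff; `loadModel()` is a hypothetical stand-in for however the package actually constructs a `Model`, and the prompt is illustrative.

// Minimal usage sketch. `loadModel()` is assumed, not part of this commit;
// everything else matches the API added above.
func generate() async throws {
  let model = try loadModel()  // hypothetical: obtain a loaded Model
  let llama = LLama(model: model)

  // `infer` is actor-isolated, so the call itself needs `await`; the
  // returned AsyncThrowingStream is then consumed with `for try await`.
  for try await token in await llama.infer(prompt: "Hello", maxTokens: 64) {
    print(token, terminator: "")
  }
}

Because the public `infer` now finishes the continuation on cancellation and after the loop, a caller that cancels the enclosing `Task` simply sees the stream end rather than an error.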