@@ -7,8 +7,6 @@ public actor LLama {
   private let logger = Logger.llama
   private let model: Model
   private let sampling: UnsafeMutablePointer<llama_sampler>
-  private var tokensList: [llama_token]
-  private var temporaryInvalidCChars: [CChar]

   // MARK: - Init & Teardown

@@ -24,10 +22,6 @@ public actor LLama {
     llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
     llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
     llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
-
-    // Initialize token lists
-    self.tokensList = []
-    self.temporaryInvalidCChars = []
   }

   // MARK: - Inference
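
A note on the sampler chain built a few lines up: the samplers run in the order they are added, so logits are first scaled by the temperature (0.8, which sharpens the distribution slightly), softmax converts them to probabilities, and the dist sampler draws a token using the fixed seed 1234. Below is a minimal standalone sketch of the same chain against llama.cpp's C API; `context` is an assumed valid `llama_context` pointer holding freshly decoded logits, not something this hunk shows:

// Build the same temp -> softmax -> dist chain and sample one token.
let chain = llama_sampler_chain_init(llama_sampler_chain_default_params())
llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8))
llama_sampler_chain_add(chain, llama_sampler_init_softmax())
llama_sampler_chain_add(chain, llama_sampler_init_dist(1234))
// Sample from the logits at the last position of the batch (index -1).
let token = llama_sampler_sample(chain, context, -1)
llama_sampler_free(chain)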
@@ -51,7 +45,7 @@ public actor LLama {
       }
     }

-  /// Performs the inference loop and yields generated tokens to the continuation.
+  /// Initiates the inference process and manages the lifecycle of variables.
   ///
   /// - Parameters:
   ///   - prompt: The input text prompt to generate completions for.
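
The refactor leaves the streaming contract unchanged: tokens are still delivered through an `AsyncThrowingStream<String, Error>`, and the loop finishes the stream early when the surrounding Task is cancelled. A minimal caller-side sketch follows; the `infer(prompt:maxTokens:)` entry point is an assumed public wrapper around this code path, not confirmed by the hunks in this diff:

// Hypothetical consumption of the streaming API; `llama` is an initialized LLama actor.
let stream = await llama.infer(prompt: "Write a haiku about actors.", maxTokens: 128)
do {
  for try await piece in stream {
    print(piece, terminator: "")  // pieces arrive as they are generated
  }
} catch {
  print("Inference failed: \(error)")  // e.g. an InferError thrown during initialization
}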
@@ -67,57 +61,53 @@ public actor LLama {
     var nCur: Int32 = 0
     var nDecode: Int32 = 0
     var batch = llama_batch_init(512, 0, 1)
+    var temporaryInvalidCChars: [CChar] = []
     defer {
       llama_batch_free(batch)
     }

-    do {
-      try self.completionInit(text: prompt, batch: &batch, nLen: nLen, nCur: &nCur)
-    } catch {
-      throw error
-    }
-
-    while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
-      guard !Task.isCancelled else {
-        continuation.finish()
-        return
-      }
-      let newTokenStr = self.completionLoop(
-        batch: &batch,
-        isDone: &isDone,
-        nLen: nLen,
-        nCur: &nCur,
-        nDecode: &nDecode
-      )
-      continuation.yield(newTokenStr)
-    }
-    continuation.finish()
+    try self.initializeInference(
+      prompt: prompt,
+      batch: &batch,
+      nLen: nLen,
+      nCur: &nCur
+    )
+
+    try await self.runInferenceLoop(
+      batch: &batch,
+      temporaryInvalidCChars: &temporaryInvalidCChars,
+      isDone: &isDone,
+      nLen: nLen,
+      nCur: &nCur,
+      nDecode: &nDecode,
+      maxTokens: maxTokens,
+      continuation: continuation
+    )
   }

   // MARK: - Private Helpers

-  /// Initializes the completion process by tokenizing the input text and preparing the batch.
+  /// Initializes the inference process by tokenizing the input and preparing the batch.
   ///
   /// - Parameters:
-  ///   - text: The input text to tokenize.
+  ///   - prompt: The input text prompt.
   ///   - batch: The batch to initialize.
-  ///   - nLen: The maximum length of the sequence.
+  ///   - nLen: The maximum sequence length.
   ///   - nCur: The current position in the sequence.
   ///
-  /// - Throws: An `InferError` if the KV cache is too small or decoding fails.
-  private func completionInit(
-    text: String,
+  /// - Throws: An `InferError` if the KV cache is insufficient or decoding fails.
+  private func initializeInference(
+    prompt: String,
     batch: inout llama_batch,
     nLen: Int32,
     nCur: inout Int32
   ) throws {
-    logger.debug("Attempting to complete \"\(text)\"")
+    logger.debug("Attempting to complete \"\(prompt)\"")

-    tokensList = tokenize(text: text, add_bos: true)
-    temporaryInvalidCChars = []
+    let tokensList = tokenize(text: prompt, add_bos: true)

     let nCtx = llama_n_ctx(model.context)
-    let nKvReq = tokensList.count + Int(nLen) - tokensList.count
+    let nKvReq = tokensList.count + Int(nLen - Int32(tokensList.count))

     logger.debug("\nn_len = \(nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")

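
On the `nKvReq` line: both the old and the new expression evaluate to `Int(nLen)`, since the prompt length is subtracted and then added back. The rewritten form mirrors the shape used in llama.cpp's simple example and makes the intent readable: the KV cache must hold the prompt tokens plus every token still to be generated. For a 12-token prompt with nLen = 512, that is 12 + (512 - 12) = 512 slots.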
@@ -142,18 +132,59 @@ public actor LLama {
     nCur = batch.n_tokens
   }

-  /// Performs a single iteration of the completion loop, generating the next token.
+  /// Runs the main inference loop, generating tokens and yielding them to the continuation.
+  ///
+  /// - Parameters:
+  ///   - batch: The batch used for decoding.
+  ///   - temporaryInvalidCChars: Buffer for building partial UTF8 strings.
+  ///   - isDone: A flag indicating whether inference is complete.
+  ///   - nLen: The maximum sequence length.
+  ///   - nCur: The current position in the sequence.
+  ///   - nDecode: The number of tokens decoded so far.
+  ///   - maxTokens: The maximum number of tokens to generate.
+  ///   - continuation: The stream continuation to yield tokens to.
+  private func runInferenceLoop(
+    batch: inout llama_batch,
+    temporaryInvalidCChars: inout [CChar],
+    isDone: inout Bool,
+    nLen: Int32,
+    nCur: inout Int32,
+    nDecode: inout Int32,
+    maxTokens: Int32,
+    continuation: AsyncThrowingStream<String, Error>.Continuation
+  ) async throws {
+    while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
+      guard !Task.isCancelled else {
+        continuation.finish()
+        return
+      }
+      let newTokenStr = self.generateNextToken(
+        batch: &batch,
+        temporaryInvalidCChars: &temporaryInvalidCChars,
+        isDone: &isDone,
+        nLen: nLen,
+        nCur: &nCur,
+        nDecode: &nDecode
+      )
+      continuation.yield(newTokenStr)
+    }
+    continuation.finish()
+  }
+
+  /// Generates the next token and updates necessary states.
   ///
   /// - Parameters:
-  ///   - batch: The batch to use for decoding.
-  ///   - isDone: A flag indicating whether the generation is complete.
-  ///   - nLen: The maximum length of the sequence.
+  ///   - batch: The batch used for decoding.
+  ///   - temporaryInvalidCChars: Buffer for building partial UTF8 strings.
+  ///   - isDone: A flag indicating whether inference is complete.
+  ///   - nLen: The maximum sequence length.
   ///   - nCur: The current position in the sequence.
   ///   - nDecode: The number of tokens decoded so far.
   ///
   /// - Returns: The newly generated token as a string.
-  private func completionLoop(
+  private func generateNextToken(
     batch: inout llama_batch,
+    temporaryInvalidCChars: inout [CChar],
     isDone: inout Bool,
     nLen: Int32,
     nCur: inout Int32,
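
The `temporaryInvalidCChars` buffer that now travels through these signatures exists because a generated token can end midway through a multi-byte UTF-8 character, so its bytes cannot always be converted to a String immediately. Here is a sketch of the buffering idea, similar in spirit to llama.cpp's Swift example code; the helper name and exact validation call are illustrative, not taken from this diff:

// Accumulate raw token bytes; emit text only once they form valid UTF-8.
func flushPiece(_ newTokenCChars: [CChar],
                buffer temporaryInvalidCChars: inout [CChar]) -> String {
  temporaryInvalidCChars.append(contentsOf: newTokenCChars)
  // Appending a NUL terminator lets the C-string validating initializer run.
  if let valid = String(validatingUTF8: temporaryInvalidCChars + [0]) {
    temporaryInvalidCChars.removeAll()
    return valid
  }
  return ""  // incomplete multi-byte sequence; keep buffering
}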