
Commit f1632fb

Refactor
1 parent e9d3c43 commit f1632fb

File tree

1 file changed: 74 additions, 43 deletions

Sources/llama-cpp-swift/LLama.swift

Lines changed: 74 additions & 43 deletions
@@ -7,8 +7,6 @@ public actor LLama {
   private let logger = Logger.llama
   private let model: Model
   private let sampling: UnsafeMutablePointer<llama_sampler>
-  private var tokensList: [llama_token]
-  private var temporaryInvalidCChars: [CChar]
 
   // MARK: - Init & Teardown
 
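Note: the two buffers deleted here were actor-level state shared by every call; the rest of this commit threads them through as per-call locals instead. Actor isolation prevents data races, but the inference loop suspends at await points, so a second call could still interleave with the first and clobber shared buffers. A minimal sketch of the hazard and the fix, with llama_token reduced to a stand-in typealias so the sketch compiles on its own:

    // Stand-in for the llama.cpp token type; only here to make the sketch self-contained.
    typealias llama_token = Int32

    actor SharedStateGenerator {
      // Before: buffers shared across calls. Two overlapping infer() calls can
      // interleave at suspension points and trample each other's state.
      private var tokensList: [llama_token] = []
      private var temporaryInvalidCChars: [CChar] = []
    }

    actor LocalStateGenerator {
      // After: each call owns its buffers, so calls cannot interfere.
      func infer(prompt: String) async {
        var tokensList: [llama_token] = []
        var temporaryInvalidCChars: [CChar] = []
        _ = (tokensList, temporaryInvalidCChars)  // consumed by the real loop
      }
    }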
@@ -24,10 +22,6 @@ public actor LLama {
     llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.8))
     llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
     llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
-
-    // Initialize token lists
-    self.tokensList = []
-    self.temporaryInvalidCChars = []
   }
 
   // MARK: - Inference
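The context lines above are the tail of the sampler setup in init. For orientation, a sketch of how such a llama.cpp sampler chain is typically created and torn down; only the three chain_add calls appear in this diff, so the chain-creation and free calls are assumptions:

    // Sketch, assuming the llama.cpp C API is imported. Parameter values mirror the diff.
    let sampling = llama_sampler_chain_init(llama_sampler_chain_default_params())
    llama_sampler_chain_add(sampling, llama_sampler_init_temp(0.8))   // temperature-scale logits
    llama_sampler_chain_add(sampling, llama_sampler_init_softmax())   // normalize to probabilities
    llama_sampler_chain_add(sampling, llama_sampler_init_dist(1234))  // sample, fixed RNG seed
    // ... later, presumably in deinit:
    llama_sampler_free(sampling)  // frees the chain and the samplers it owns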
@@ -51,7 +45,7 @@ public actor LLama {
     }
   }
 
-  /// Performs the inference loop and yields generated tokens to the continuation.
+  /// Initiates the inference process and manages the lifecycle of variables.
   ///
   /// - Parameters:
   ///   - prompt: The input text prompt to generate completions for.
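Given the continuation-based design documented here, a caller consumes the generated tokens with for try await. A hypothetical call site follows; the public entry point's exact name and signature are not visible in this diff, so infer(prompt:maxTokens:) returning AsyncThrowingStream<String, Error> is an assumption:

    // Hypothetical usage; the method name is assumed, not this package's confirmed API.
    func printCompletion(llama: LLama) async throws {
      for try await piece in await llama.infer(prompt: "Hello", maxTokens: 128) {
        print(piece, terminator: "")  // tokens stream in as they are generated
      }
      // Cancelling the surrounding Task ends the stream early: the loop below
      // checks Task.isCancelled and calls continuation.finish().
    }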
@@ -67,57 +61,53 @@ public actor LLama {
     var nCur: Int32 = 0
     var nDecode: Int32 = 0
     var batch = llama_batch_init(512, 0, 1)
+    var temporaryInvalidCChars: [CChar] = []
     defer {
       llama_batch_free(batch)
     }
 
-    do {
-      try self.completionInit(text: prompt, batch: &batch, nLen: nLen, nCur: &nCur)
-    } catch {
-      throw error
-    }
-
-    while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
-      guard !Task.isCancelled else {
-        continuation.finish()
-        return
-      }
-      let newTokenStr = self.completionLoop(
-        batch: &batch,
-        isDone: &isDone,
-        nLen: nLen,
-        nCur: &nCur,
-        nDecode: &nDecode
-      )
-      continuation.yield(newTokenStr)
-    }
-    continuation.finish()
+    try self.initializeInference(
+      prompt: prompt,
+      batch: &batch,
+      nLen: nLen,
+      nCur: &nCur
+    )
+
+    try await self.runInferenceLoop(
+      batch: &batch,
+      temporaryInvalidCChars: &temporaryInvalidCChars,
+      isDone: &isDone,
+      nLen: nLen,
+      nCur: &nCur,
+      nDecode: &nDecode,
+      maxTokens: maxTokens,
+      continuation: continuation
+    )
   }
 
   // MARK: - Private Helpers
 
-  /// Initializes the completion process by tokenizing the input text and preparing the batch.
+  /// Initializes the inference process by tokenizing the input and preparing the batch.
   ///
   /// - Parameters:
-  ///   - text: The input text to tokenize.
+  ///   - prompt: The input text prompt.
   ///   - batch: The batch to initialize.
-  ///   - nLen: The maximum length of the sequence.
+  ///   - nLen: The maximum sequence length.
   ///   - nCur: The current position in the sequence.
   ///
-  /// - Throws: An `InferError` if the KV cache is too small or decoding fails.
-  private func completionInit(
-    text: String,
+  /// - Throws: An `InferError` if the KV cache is insufficient or decoding fails.
+  private func initializeInference(
+    prompt: String,
     batch: inout llama_batch,
     nLen: Int32,
     nCur: inout Int32
   ) throws {
-    logger.debug("Attempting to complete \"\(text)\"")
+    logger.debug("Attempting to complete \"\(prompt)\"")
 
-    tokensList = tokenize(text: text, add_bos: true)
-    temporaryInvalidCChars = []
+    let tokensList = tokenize(text: prompt, add_bos: true)
 
     let nCtx = llama_n_ctx(model.context)
-    let nKvReq = tokensList.count + Int(nLen) - tokensList.count
+    let nKvReq = tokensList.count + Int(nLen - Int32(tokensList.count))
 
     logger.debug("\nn_len = \(nLen), n_ctx = \(nCtx), n_kv_req = \(nKvReq)")
 
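On the rewritten nKvReq line: both the old and the new expressions reduce algebraically to Int(nLen); the new grouping reads as "prompt tokens plus the remaining generation budget", the same shape llama.cpp's simple example uses for n_kv_req. A quick numeric check of that identity, with illustrative values:

    // Both groupings of nKvReq equal nLen; the values here are assumptions.
    let nLen: Int32 = 1024                                      // maximum sequence length
    let promptCount = 12                                        // e.g. tokenize(...).count
    let oldForm = promptCount + Int(nLen) - promptCount         // before this commit
    let newForm = promptCount + Int(nLen - Int32(promptCount))  // after this commit
    assert(oldForm == newForm && newForm == Int(nLen))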
@@ -142,18 +132,59 @@ public actor LLama {
     nCur = batch.n_tokens
   }
 
-  /// Performs a single iteration of the completion loop, generating the next token.
+  /// Runs the main inference loop, generating tokens and yielding them to the continuation.
+  ///
+  /// - Parameters:
+  ///   - batch: The batch used for decoding.
+  ///   - temporaryInvalidCChars: Buffer for building partial UTF8 strings.
+  ///   - isDone: A flag indicating whether inference is complete.
+  ///   - nLen: The maximum sequence length.
+  ///   - nCur: The current position in the sequence.
+  ///   - nDecode: The number of tokens decoded so far.
+  ///   - maxTokens: The maximum number of tokens to generate.
+  ///   - continuation: The stream continuation to yield tokens to.
+  private func runInferenceLoop(
+    batch: inout llama_batch,
+    temporaryInvalidCChars: inout [CChar],
+    isDone: inout Bool,
+    nLen: Int32,
+    nCur: inout Int32,
+    nDecode: inout Int32,
+    maxTokens: Int32,
+    continuation: AsyncThrowingStream<String, Error>.Continuation
+  ) async throws {
+    while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
+      guard !Task.isCancelled else {
+        continuation.finish()
+        return
+      }
+      let newTokenStr = self.generateNextToken(
+        batch: &batch,
+        temporaryInvalidCChars: &temporaryInvalidCChars,
+        isDone: &isDone,
+        nLen: nLen,
+        nCur: &nCur,
+        nDecode: &nDecode
+      )
+      continuation.yield(newTokenStr)
+    }
+    continuation.finish()
+  }
+
+  /// Generates the next token and updates necessary states.
   ///
   /// - Parameters:
-  ///   - batch: The batch to use for decoding.
-  ///   - isDone: A flag indicating whether the generation is complete.
-  ///   - nLen: The maximum length of the sequence.
+  ///   - batch: The batch used for decoding.
+  ///   - temporaryInvalidCChars: Buffer for building partial UTF8 strings.
+  ///   - isDone: A flag indicating whether inference is complete.
+  ///   - nLen: The maximum sequence length.
  ///   - nCur: The current position in the sequence.
   ///   - nDecode: The number of tokens decoded so far.
   ///
   /// - Returns: The newly generated token as a string.
-  private func completionLoop(
+  private func generateNextToken(
     batch: inout llama_batch,
+    temporaryInvalidCChars: inout [CChar],
     isDone: inout Bool,
     nLen: Int32,
     nCur: inout Int32,
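The temporaryInvalidCChars buffer threaded through both helpers exists because a single generated token piece can end in the middle of a multi-byte UTF-8 character. A sketch of the accumulate-and-flush pattern it supports (the general technique, not this package's exact code):

    // Bytes are buffered until they form valid UTF-8, then flushed as a String.
    func flushPiece(_ newBytes: [CChar], buffer: inout [CChar]) -> String {
      buffer.append(contentsOf: newBytes)
      // validatingUTF8 expects a null-terminated C string and returns nil
      // while the buffer still ends mid-character.
      if let s = String(validatingUTF8: buffer + [0]) {
        buffer.removeAll()
        return s
      }
      return ""  // incomplete sequence; keep accumulating
    }

    var buf: [CChar] = []
    // "é" is 0xC3 0xA9 in UTF-8, split here across two token pieces:
    print(flushPiece([CChar(bitPattern: 0xC3)], buffer: &buf))  // "" (waiting)
    print(flushPiece([CChar(bitPattern: 0xA9)], buffer: &buf))  // "é"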
