
Commit 3c52ac0

feat: allocate minimal blocks per window size (#3028)
* implement variable window attention by breaking the block manager into window block managers per window size
* revert isCyclic to be true if the min attention window is reached, not per window size
* add explanatory comment to mCyclicThreshold
* load correct gemma config
* don't shadow inputLength in addSequence - it should remain the function-scope input length between window size loop iterations
* fix KVCacheManagerVariableWindowAttentionWithReuseTest for multiple window block managers
* if TYPE_CHECKING
* set temp_attention_window_inputs to None explicitly
* pass dtype as well
* test_gemma variable sliding window attention
* allot a fraction of primary/secondary blocks to each window size heap, depending on the window size's total contribution to the KV cache size (i.e., including all layers)
* remove "|| mEnableBlockReuse", which erroneously triggered beam search code for cyclic variable attention window code
* turn off request delaying for MaxUtil
* improve comments
* compute windowSizesTotalSum using std::accumulate
* fix error handling of forwardAsync - the catch-all cleanup code that runs terminateRequest can also fail and must be caught
* fix comments
* remove assert that kills disagg tests, since it isn't necessary
* fix corrupted expression: 'isNewTask && (peftCacheManager ?' -> '(isNewTask && peftCacheManager) ?', which caused incorrect boolean evaluation; main is correct
* add Gemma3 to SUPPORTED_HF_ARCHITECTURES
* support Gemma3
* finally fix test_gemma - always spread at least {} into generate_summary_cmd, never None
* fix kvfactor field for deepseek
* fix comment
* fix gemma-3 entries in testlist to include vswa
* only quantize gemma2 VSWA; remove misleading comment; fix test_gemma
* in sendRequestInfo, fromOldAllocatedBlockIds -> fromNewlyAllocatedBlockIds, like in main
* fix: disable KV cache reuse if using attention sink (#3021)
  * disable KV cache reuse if using attention sink
  * disable KV cache reuse if sink bubble
  * add comment

Signed-off-by: Netanel Haber <[email protected]>
Signed-off-by: Robin Kobus <[email protected]>
Co-authored-by: Robin Kobus <[email protected]>
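
The per-window allotment described above ("allot a fraction of primary/secondaryBlocks to different window size heaps, depending on the window size's total contribution to the kvcache size") boils down to splitting one block budget proportionally to each window's share of the total KV-cache footprint. A minimal sketch of that split follows; the names (allotBlocksPerWindow, layersPerWindow) are illustrative and are not the actual kvCacheManager.h API.

#include <cstdint>
#include <map>

using SizeType32 = std::int32_t;

// Split a fixed budget of primary blocks across window sizes, proportionally to
// each window's contribution to the KV cache: windowSize * number of layers using it.
std::map<SizeType32, SizeType32> allotBlocksPerWindow(
    SizeType32 totalPrimaryBlocks, std::map<SizeType32, SizeType32> const& layersPerWindow)
{
    std::int64_t totalContribution = 0;
    for (auto const& [windowSize, numLayers] : layersPerWindow)
    {
        totalContribution += static_cast<std::int64_t>(windowSize) * numLayers;
    }
    std::map<SizeType32, SizeType32> blocksPerWindow;
    for (auto const& [windowSize, numLayers] : layersPerWindow)
    {
        auto const contribution = static_cast<std::int64_t>(windowSize) * numLayers;
        blocksPerWindow[windowSize]
            = static_cast<SizeType32>(totalPrimaryBlocks * contribution / totalContribution);
    }
    return blocksPerWindow;
}

For example, with a 1024-token window on 10 layers and a 4096-token window on 2 layers, the 1024 heap would receive roughly 10240/18432 of the budget and the 4096 heap the remaining 8192/18432.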
1 parent 1c6f3de · commit 3c52ac0

24 files changed: +1975 −1045 lines

cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h

-6

@@ -97,12 +97,6 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
         RequestList const& activeRequests) const;
 
 private:
-    /// @return {fitsKvCache, fitsPeft}
-    std::pair<bool, bool> trySchedulingRequestMaxUtilization(kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
-        OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
-        RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
-        std::unordered_set<uint64_t>& seenTaskIds) const;
-
     SizeType32 mMaxNumRequests;
     /// @brief Boolean that indicates if multiple micro batches might be in flight
     bool mManyMicroBatches;

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

+600 −173
Large diffs are not rendered by default.
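
The full kvCacheManager.h diff is not rendered here. For orientation, the commit message says the block manager was broken into per-window-size "window block managers"; a purely conceptual sketch of that shape is below. The class and member names (WindowBlockManager, mWindowBlockManagers, addWindow) are illustrative only and do not reproduce the real header; isVariableWindow() is the accessor the scheduler changes below rely on.

#include <cstdint>
#include <map>

using SizeType32 = std::int32_t;

// One block heap per distinct attention window size.
class WindowBlockManager
{
public:
    WindowBlockManager(SizeType32 windowSize, SizeType32 numPrimaryBlocks)
        : mWindowSize(windowSize)
        , mFreeBlocks(numPrimaryBlocks)
    {
    }

    SizeType32 windowSize() const { return mWindowSize; }
    SizeType32 numFreeBlocks() const { return mFreeBlocks; }

private:
    SizeType32 mWindowSize;
    SizeType32 mFreeBlocks;
};

class BlockManager
{
public:
    // Register a heap for a window size, sized by its allotted share of blocks.
    void addWindow(SizeType32 windowSize, SizeType32 numPrimaryBlocks)
    {
        mWindowBlockManagers.emplace(windowSize, WindowBlockManager(windowSize, numPrimaryBlocks));
    }

    // True when more than one attention window size is in use.
    bool isVariableWindow() const { return mWindowBlockManagers.size() > 1; }

private:
    std::map<SizeType32, WindowBlockManager> mWindowBlockManagers;
};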

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

+43 −13

@@ -31,24 +31,28 @@ class BlockRange
     {
     };
 
-    BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId, SizeType32 beam,
-        SizeType32 poolIdx = 0)
-        : mManager(&cacheManager)
-        , mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx))
-        , mBlockIds(cacheManager.getSequence(requestId).getCacheBlockIds().at(beam))
+    static BlockRange fromOldAllocatedBlockIds(BaseKVCacheManager const& cacheManager,
+        LlmRequest::RequestIdType requestId, SizeType32 beam = kFIRST_AND_ONLY_BEAM)
     {
+        assert(kFIRST_AND_ONLY_BEAM == beam);
+        auto const windowSize = firstWindowSize(cacheManager);
+        auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
+        return BlockRange(cacheManager, blockIds, requestId);
     }
 
-    BlockRange(BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, SizeType32 poolIdx = 0)
-        : mManager(&cacheManager)
-        , mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx))
-        , mBlockIds(std::move(blockIds))
+    static BlockRange fromNewlyAllocatedBlockIds(
+        BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
     {
+        auto const windowSize = firstWindowSize(cacheManager);
+        auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
+        return BlockRange(cacheManager, blockIds, requestId);
     }
 
-    BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds)
+    BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
         : mManager{nullptr}
         , mPool{std::move(pool)}
+        , mWindowSize{0}
+        , mRequestId{0}
         , mBlockIds{blockIds}
     {
         TLLM_CHECK(mPool);
@@ -84,25 +88,51 @@ class BlockRange
         auto& blockManager = mManager->getBlockManager();
         for (auto id : mBlockIds)
         {
-            blockHashes.emplace_back(blockManager.getBlockById(id)->getHash());
+            blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
         }
         return blockHashes;
     }
 
     void updatePoolIdx(SizeType32 poolIdx)
     {
-        if (mManager)
+        TLLM_CHECK(mManager);
+        mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
+        auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx);
+        if (newWindowSize != mWindowSize)
         {
-            mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
+            mWindowSize = newWindowSize;
+            mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
         }
     }
 
     friend class BlockIterator;
 
+private:
+    BlockRange(
+        BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId)
+        : mManager(&cacheManager)
+        , mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX))
+        , mWindowSize(firstWindowSize(cacheManager))
+        , mRequestId(requestId)
+        , mBlockIds(std::move(blockIds))
+    {
+    }
+
+    static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager)
+    {
+        constexpr SizeType32 FIRST_POOL_IDX = 0;
+        return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX);
+    }
+
 private:
     BaseKVCacheManager const* mManager;
     runtime::ITensor::SharedPtr mPool;
+    SizeType32 mWindowSize;
+    const LlmRequest::RequestIdType mRequestId;
     std::vector<SizeType32> mBlockIds;
+
+    static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
+    static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
 };
 
 class BlockIterator
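
The two new factory functions replace the old BlockRange constructors at the call sites in cacheFormatter.cpp and dataTransceiverImpl.cpp below. A short usage fragment, assuming a BaseKVCacheManager and request ID are already in scope (this is illustrative, not a standalone program):

// Blocks already allocated to the request (the sending, context side).
auto oldRange = BlockRange::fromOldAllocatedBlockIds(cacheManager, requestId);
// Blocks freshly allocated to receive transferred KV cache (the receiving side).
auto newRange = BlockRange::fromNewlyAllocatedBlockIds(cacheManager, requestId);
// Hashes are now looked up via the block manager for the range's current window size.
auto hashes = newRange.getBlockHashes();

Both factories start from the first pool's window size; updatePoolIdx(poolIdx) later re-fetches the block IDs whenever the selected pool's window size differs from the current one.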

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

+3 −4

@@ -49,7 +49,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
     constexpr SizeType32 beam{0};
     auto& blockManager = mCacheManager->getBlockManager();
     size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
-    auto blockRange = BlockRange(*mCacheManager, llmRequest.mRequestId, beam);
+    auto blockRange = BlockRange::fromOldAllocatedBlockIds(*mCacheManager, llmRequest.mRequestId, beam);
     if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
     {
         // handle block reuse, the prefix blocks are reused
@@ -109,7 +109,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
     }
     TLLM_CHECK(!inputKvCacheBlocks.empty());
     TLLM_CHECK(blockNum > 0);
-    int deviceId = mCacheManager->getBlockManager().getBufferManager().getStream().getDevice();
+    int deviceId = mCacheManager->getBlockManager().getStreamDevice();
 
     if (common::getEnvTryZCopyForKVCacheTransfer()
         && (destConfig.getParallelConfig().mPipelineParallelism
@@ -318,8 +318,7 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
         "Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
         llmRequest.getContextPhaseParams().value().getReqId());
     TLLM_CHECK(!connections.empty());
-    auto blockRange = BlockRange(*mCacheManager, mCacheManager->getNewlyAllocatedBlockIds(llmRequest.mRequestId));
-
+    auto blockRange = BlockRange::fromNewlyAllocatedBlockIds(*mCacheManager, llmRequest.mRequestId);
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
     std::vector<runtime::ITensor::SharedPtr> outputBuffers;
     auto const numPools = mCacheManager->getBlockManager().getNumPools();

cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp

+60 −51

@@ -199,25 +199,32 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
     RequestVector scheduledRequests;
 
     // Now check if we can add pending requests
-    auto const numFreeBlocks = kvCacheManager.getNumFreeBlocks();
-    auto const numFreeCrossBlocks = crossKvCacheManager ? crossKvCacheManager->getNumFreeBlocks() : 0;
     auto const maxPeftCachePages
         = peftCacheManager ? peftCacheManager->getMaxDevicePages() : std::numeric_limits<SizeType32>::max();
 
+    // The optimization of delaying requests won't work for variable window attention
+    bool skippingIsRelevant = (!kvCacheManager.getBlockManager().isVariableWindow())
+        && (!crossKvCacheManager || !crossKvCacheManager->getBlockManager().isVariableWindow());
+
     // Keep track of blocks contributed by requests in context phase
     std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
     std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
     if constexpr (!StaticBatchScheduling)
     {
-        std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
-            = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
+        if (skippingIsRelevant)
+        {
+            std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
+                = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
+        }
     }
 
     // If a request is already in progress, include it
     // If it's been allocated, it had resource to run to completion
     // Also keep track of blocks needed to drive all in-progress requests to completion
-    SizeType32 reservedBlocks{0};
-    SizeType32 reservedCrossBlocks{0};
+    auto reservedBlocks = kv_cache_manager::NoEvictScheduledBlocksManager(kvCacheManager);
+    auto reservedCrossBlocks = crossKvCacheManager
+        ? std::optional(kv_cache_manager::NoEvictScheduledBlocksManager(*crossKvCacheManager))
+        : std::nullopt;
     SizeType32 claimedPeftPages{0};
     std::unordered_set<uint64_t> uniqTaskIds{};
     RequestVector pendingRequests;
@@ -242,16 +249,16 @@
         else if (req->isGenerationInProgressState())
         {
             scheduledRequests.emplace_back(req);
-            reservedBlocks += kvCacheManager.getRemainingBlocksToCompletion(*req);
-
+            reservedBlocks.decrementReservedBlocks(*req);
+            if (reservedCrossBlocks)
+                reservedCrossBlocks->decrementReservedBlocks(*req);
             bool const reqHasLora = req->getLoraTaskId().has_value();
             bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
             if (isNewTask)
             {
                 claimedPeftPages += peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
                 uniqTaskIds.insert(req->getLoraTaskId().value());
             }
-            reservedCrossBlocks += crossKvCacheManager ? crossKvCacheManager->getRemainingBlocksToCompletion(*req) : 0;
         }
         else if (req->isDisaggGenerationInitState())
         {
@@ -268,8 +275,6 @@
     if (!StaticBatchScheduling || scheduledRequests.size() == 0)
     {
         // Now check if we can add pending requests
-        auto availableBlocks = numFreeBlocks - reservedBlocks;
-        auto availableCrossBlocks = numFreeCrossBlocks - reservedCrossBlocks;
         auto availablePeftPages = maxPeftCachePages - claimedPeftPages;
 
         // Loop over pending requests and add them if they can be scheduled
@@ -279,7 +284,7 @@
         for (auto const& req : requests)
         {
             // if context request can reuse blocks contributed by another context request, skip
-            if (!StaticBatchScheduling && !req->isDisaggGenerationInitState()
+            if (!StaticBatchScheduling && skippingIsRelevant && !req->isDisaggGenerationInitState()
                 && beneficialToSkip(req, kvCacheManager, crossKvCacheManager, newlyContributedContextBlocks,
                     newlyContributedCrossContextBlocks))
             {
@@ -292,27 +297,26 @@
             }
             else if (req->isContextInitState() || req->isDisaggGenerationInitState())
            {
-                auto const neededBlocks = kvCacheManager.getRemainingBlocksToCompletion(*req);
-                auto const neededCrossBlocks
-                    = crossKvCacheManager ? crossKvCacheManager->getRemainingBlocksToCompletion(*req) : 0;
-                bool const reqHasLora = req->getLoraTaskId().has_value();
-                bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
-                auto const neededPeftPages
-                    = (isNewTask && peftCacheManager) ? peftCacheManager->determineNumPages(req) : 0;
-
-                if (neededBlocks <= availableBlocks && neededCrossBlocks <= availableCrossBlocks
-                    && neededPeftPages <= availablePeftPages)
+                bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
+                bool enoughCrossBlocks
+                    = reservedCrossBlocks ? reservedCrossBlocks->enoughAvailableBlocks(*req) : true;
+                bool reqHasLora = req->getLoraTaskId().has_value();
+                bool isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
+                auto neededPeftPages = isNewTask && peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
+
+                if (enoughBlocks && enoughCrossBlocks && neededPeftPages <= availablePeftPages)
                 {
                     scheduledRequests.emplace_back(req);
-                    availableBlocks -= neededBlocks;
-                    availableCrossBlocks -= neededCrossBlocks;
+                    reservedBlocks.decrementReservedBlocks(*req);
+                    if (reservedCrossBlocks)
+                        reservedCrossBlocks->decrementReservedBlocks(*req);
                     availablePeftPages -= neededPeftPages;
                     if (isNewTask)
                     {
                         uniqTaskIds.insert(req->getLoraTaskId().value());
                     }
                 }
-                else if (neededBlocks > availableBlocks || neededCrossBlocks > availableCrossBlocks)
+                else if (!enoughBlocks || !enoughCrossBlocks)
                 {
                     // If one requests fails to be scheduled, break
                     break;
@@ -324,14 +328,25 @@
     return {std::move(scheduledRequests), RequestVector{}};
 }
 
+// TODO(nhaber): remove forward declare and just keep the function here, right before the merge. I put it below just so
+// the remote diff is easier to look at/rebase conflicts
+bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
+    RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
+    OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
+    std::unordered_set<uint64_t>& seenTaskIds);
+
 std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
     kv_cache_manager::BaseKVCacheManager& kvCacheManager, OptionalRef<BasePeftCacheManager const> peftCacheManager,
     RequestList const& activeRequests) const
 {
     kvCacheManager.startScheduling();
 
+    // The optimization of delaying requests won't work for variable window attention
+    bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();
+
     // Keep track of number of requests and block needed for the scheduled requests
-    SizeType32 numScheduledBlocks{0};
+    auto scheduledBlocksManager
+        = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);
     SizeType32 numScheduledPeftPages{0};
     std::unordered_set<uint64_t> seenTaskIds;
 
@@ -366,16 +381,17 @@
         }
 
         // if context request can reuse blocks contributed by another context request, skip
-        if (beneficialToSkip(
+        if (skippingIsRelevant
+            && beneficialToSkip(
                 req, kvCacheManager, std::nullopt, newlyContributedContextBlocks, newlyContributedCrossContextBlocks))
         {
             reqIt++;
             continue;
         }
 
-        auto const [fitsKvCache, fitsPeftCache] = trySchedulingRequestMaxUtilization(kvCacheManager, peftCacheManager,
-            req, scheduledRequests, numScheduledBlocks, numScheduledPeftPages, seenTaskIds);
-        if (fitsKvCache && fitsPeftCache)
+        bool const wasScheduled = trySchedulingRequestMaxUtilization(req, mMaxNumRequests, scheduledRequests,
+            scheduledBlocksManager, peftCacheManager, numScheduledPeftPages, seenTaskIds);
+        if (wasScheduled)
         {
             TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> start", req->mRequestId);
             reqIt++;
@@ -405,45 +421,38 @@
     return {std::move(scheduledRequests), std::move(pausedRequests)};
 }
 
-std::pair<bool, bool> MaxUtilizationScheduler::trySchedulingRequestMaxUtilization(
-    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
-    OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
-    RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
-    std::unordered_set<uint64_t>& seenTaskIds) const
+bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
+    RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
+    OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
+    std::unordered_set<uint64_t>& seenTaskIds)
 {
-    if (scheduledRequests.size() < static_cast<std::size_t>(mMaxNumRequests))
+    if (scheduledRequests.size() < static_cast<std::size_t>(maxNumRequests))
     {
-        SizeType32 numRequiredBlocks = kvCacheManager.getNeededBlocksOneStep(*req, mManyMicroBatches);
-        TLLM_LOG_DEBUG(
-            "MaxUtilizationScheduler: request ID %lu required blocks: %i", req->mRequestId, numRequiredBlocks);
-
-        bool const reqHasLora = req->getLoraTaskId().has_value();
-        bool const isNewTask = reqHasLora && !seenTaskIds.count(req->getLoraTaskId().value());
-        auto const numRequiredPeftPages
+        bool reqHasLora = req->getLoraTaskId().has_value();
+        bool isNewTask = reqHasLora && !seenTaskIds.count(req->getLoraTaskId().value());
+        SizeType32 numRequiredPeftPages
            = (isNewTask && peftCacheManager) ? peftCacheManager->determineNumPages(req) : 0;
         TLLM_LOG_DEBUG(
             "MaxUtilizationScheduler: request ID %lu required peft pages: %i", req->mRequestId, numRequiredPeftPages);
-        bool const fitsKvCache
-            = kvCacheManager.getBlockManager().schedulingHasFreeBlocks(numScheduledBlocks + numRequiredBlocks);
-        bool const fitsPeft
+        auto const scheduledBlocksIfFitsKvCache = blocksManager.prepareNewNumberOfBlocksIfWeEndUpScheduling(*req);
+        bool fitsPeft
            = (peftCacheManager ? numRequiredPeftPages + numScheduledPeftPages <= peftCacheManager->getMaxDevicePages()
                 : true);
 
-        if (fitsKvCache && fitsPeft)
+        if (scheduledBlocksIfFitsKvCache && fitsPeft)
        {
-            numScheduledBlocks += numRequiredBlocks;
-            TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled blocks: %i", numScheduledBlocks);
+            blocksManager.updateScheduledBlocks(scheduledBlocksIfFitsKvCache.value());
             numScheduledPeftPages += numRequiredPeftPages;
             TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled peft pages: %i", numRequiredPeftPages);
             scheduledRequests.emplace_back(req);
             if (isNewTask)
             {
                 seenTaskIds.insert(req->getLoraTaskId().value());
             }
+            return true;
        }
-        return std::make_pair(fitsKvCache, fitsPeft);
     }
-    return std::make_pair(false, false);
+    return false;
 }
 
 CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
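
The schedulers now delegate block accounting to NoEvictScheduledBlocksManager and MaxUtilizationScheduledBlocksManager, which live in kvCacheManager.h (diff not rendered above) and track availability per window size rather than with a single free-block counter. A hypothetical sketch of the no-evict variant follows; the members are illustrative, and a per-window "needed blocks" map stands in for the real LlmRequest-based interface.

#include <cstdint>
#include <map>

using SizeType32 = std::int32_t;

class NoEvictScheduledBlocksManagerSketch
{
public:
    // Seed the budget with the free block count of every window-size heap.
    explicit NoEvictScheduledBlocksManagerSketch(std::map<SizeType32, SizeType32> freeBlocksPerWindow)
        : mAvailableBlocksPerWindow(std::move(freeBlocksPerWindow))
    {
    }

    // Does the request fit in every window-size heap it touches?
    bool enoughAvailableBlocks(std::map<SizeType32, SizeType32> const& neededPerWindow) const
    {
        for (auto const& [windowSize, needed] : neededPerWindow)
        {
            if (needed > mAvailableBlocksPerWindow.at(windowSize))
            {
                return false;
            }
        }
        return true;
    }

    // Reserve the request's blocks-to-completion in every window-size heap.
    void decrementReservedBlocks(std::map<SizeType32, SizeType32> const& neededPerWindow)
    {
        for (auto const& [windowSize, needed] : neededPerWindow)
        {
            mAvailableBlocksPerWindow.at(windowSize) -= needed;
        }
    }

private:
    std::map<SizeType32, SizeType32> mAvailableBlocksPerWindow;
};

The max-utilization counterpart follows the same idea: as the diff above shows, prepareNewNumberOfBlocksIfWeEndUpScheduling returns a value only when the request fits, and updateScheduledBlocks commits that value once the request is actually scheduled.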

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

+2 −2

@@ -142,8 +142,8 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
     if (cacheFormatter != nullptr)
     {
         auto* cacheManager = cacheFormatter->getCacheManager();
-        auto blockRange = kv_cache_manager::BlockRange(
-            *cacheManager, cacheManager->getNewlyAllocatedBlockIds(llmRequest.mRequestId));
+        auto blockRange
+            = kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
         requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
     }
