
Commit eeb605a

feat: Offloading Multimodal embedding table to CPU in Chunked Prefill Mode (#3380)
* Feat: Offload ptable to cpu if enable_chunk_context
* Feat: offload ptable to cpu for chunk context mode
* Fix and add comment
* Update Readme for multimodal and add a new param mm_embedding_offloading
* fix: Correct prompt table offloading condition in PromptTuningBuffers
* Clean up the code
* Add comments to explain copy from cpu <-> gpu using pinned memory
* Fix namings based on comments
* Fix format based on precommit
* Modify --mm_embedding_offloading flag

Signed-off-by: Kate Cheng <[email protected]>
Co-authored-by: Haohang Huang <[email protected]>
1 parent faef377 commit eeb605a

19 files changed: +622 -86 lines changed


cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 11 additions & 0 deletions
@@ -553,6 +553,14 @@ class GenericLlmRequest
         return mTokens.at(beam);
     }
 
+    /// @brief Get mutable reference to tokens for a specific beam
+    /// @param beam The beam index
+    /// @return Mutable reference to the tokens vector
+    [[nodiscard]] VecTokens& getTokensMutable(SizeType32 beam)
+    {
+        return mTokens.at(beam);
+    }
+
     /// @brief Get all tokens (input+output) for all beams
     /// @return A vector of vector of tokens.
     [[nodiscard]] BeamTokens const& getTokens() const

@@ -1772,6 +1780,9 @@ class GenericLlmRequest
 
     LlmRequestState mState{LlmRequestState::kCONTEXT_INIT};
 
+    // current position of the prompt tuning table (only used in chunked prefill mode)
+    SizeType32 mPtableCurrentPosition{0};
+
 protected:
     bool mIsStreaming;

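The new mPtableCurrentPosition field tracks how far into the prompt (multimodal embedding) table a request has progressed as its context is prefilled chunk by chunk. A minimal standalone sketch of that bookkeeping (the cursor struct and advance() helper below are illustrative only, not part of the commit):

```cpp
#include <cstdint>

using SizeType32 = std::int32_t;

// Illustrative only: mirrors the role of GenericLlmRequest::mPtableCurrentPosition.
struct ChunkedPtableCursor
{
    SizeType32 currentPosition{0};

    // After each context chunk, advance past the prompt-table rows that were
    // already copied to the GPU so the next chunk resumes at the right offset.
    void advance(SizeType32 rowsConsumedByChunk)
    {
        currentPosition += rowsConsumedByChunk;
    }
};
```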
New file (PromptTuningBuffers header)

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/batch_manager/common.h"
+#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/iTensor.h"
+#include "tensorrt_llm/runtime/modelConfig.h"
+#include "tensorrt_llm/runtime/promptTuningParams.h"
+#include "tensorrt_llm/runtime/worldConfig.h"
+
+namespace tensorrt_llm::batch_manager
+{
+
+class PromptTuningBuffers
+{
+
+public:
+    using SizeType32 = tensorrt_llm::runtime::SizeType32;
+    using ITensor = tensorrt_llm::runtime::ITensor;
+    using TensorPtr = runtime::ITensor::SharedPtr;
+
+    runtime::PromptTuningParams mPromptTuningParams;
+    SizeType32 mMaxPromptVocabSize;
+
+    PromptTuningBuffers(SizeType32 maxBatchSize, runtime::BufferManager const& manager,
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
+
+    PromptTuningBuffers(SizeType32 maxBatchSize, runtime::BufferManager const& manager,
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        bool promptTableOffloading);
+
+    void validate(std::optional<TensorPtr> const& optReqPromptEmbeddingTable,
+        std::optional<SizeType32> const& optReqPromptVocabSize);
+
+    void fill(RequestVector const& contextRequests, RequestVector const& genRequests,
+        runtime::BufferManager const& manager, bool packed);
+
+    /*
+     * The functions below are specific to chunked prefill mode.
+     *
+     * Chunked ptable with ping-pong buffer implementation
+     * ---------------------------------------------------
+     *
+     * Overview:
+     * The chunked ptable (prompt tuning table) system uses a ping-pong buffer mechanism to
+     * manage large embedding tables efficiently in chunked prefill mode. The table is loaded
+     * chunk by chunk from CPU to GPU memory, enabling support for tables that exceed
+     * available GPU memory.
+     *
+     * Key components:
+     * 1. Ping-pong buffers (mChunkPtableBuffers):
+     *    - Two alternating GPU buffers that store chunks of the embedding table
+     *    - While the current buffer is being processed by the model,
+     *      the next chunk can be asynchronously loaded into the other buffer
+     *    - Managed through mChunkPtableCurrentIndex (toggles between 0 and 1)
+     * 2. Start position tracking (mChunkPtableBufferStartPositions):
+     *    - Mainly used for multi-batch processing
+     *    - Maintains the starting position of each batch's data within each buffer
+     *    - Maintained separately for each ping-pong buffer
+     *
+     * Memory optimization:
+     * - Only two GPU buffers are maintained, regardless of total embedding table size
+     * - Each buffer is limited to contextChunkSize * hiddenSize elements
+     * - Chunk-based processing keeps GPU memory usage bounded
+     */
+
+    bool mPromptTableOffloading;
+
+    bool mChunkPtableInitialized{false};
+    std::optional<std::array<TensorPtr, 2>> mChunkPtableBuffers;
+    std::optional<std::vector<std::vector<SizeType32>>> mChunkPtableBufferStartPositions;
+    size_t mChunkPtableCurrentIndex{0};
+
+    void initializeChunkPtableBuffers(runtime::BufferManager const& manager, runtime::ModelConfig const& modelConfig,
+        SizeType32 contextChunkSize, std::shared_ptr<LlmRequest> const& llmReq);
+
+    void switchChunkPtableBuffer();
+
+    size_t getChunkPtableCurrentIndex();
+
+    [[nodiscard]] TensorPtr& getChunkPtableBuffer(size_t index);
+
+    [[nodiscard]] SizeType32 getChunkPtableBufferSliceSize(size_t index, size_t batchIdx);
+
+    [[nodiscard]] SizeType32 getChunkPtableBufferStartPosition(size_t index, size_t batchIdx);
+
+    void updateBufferStartPosition(size_t index, SizeType32 numRows);
+
+    void clearBufferStartPositions(size_t index);
+};
+
+} // namespace tensorrt_llm::batch_manager
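The comment block above sums up the mechanism: two GPU buffers of at most contextChunkSize * hiddenSize elements each, with the next chunk staged from CPU (via pinned memory, per the commit message) while the current one feeds the model. Below is a minimal sketch of that ping-pong pattern using raw CUDA runtime calls. It is an illustration of the idea only, not the buffers' actual implementation (which goes through BufferManager and ITensor); the float element type, the assumption that the host table is already pinned, and the runChunk() placeholder are all assumptions made for the sake of a self-contained example.

```cpp
// Sketch of the ping-pong chunk upload described above -- illustration only,
// not the repository's implementation. Assumptions: float elements, the host
// table already lives in pinned memory, and runChunk() stands in for running
// the model on one context chunk.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstddef>

void runChunk(float const* /*devicePtableSlice*/, std::size_t /*rows*/, cudaStream_t /*stream*/)
{
    // Placeholder for "run the context chunk with this slice of the prompt table".
}

void uploadPtableInChunks(float const* pinnedHostPtable, std::size_t totalRows, std::size_t hiddenSize,
    std::size_t chunkRows, cudaStream_t computeStream, cudaStream_t copyStream)
{
    std::size_t const chunkElems = chunkRows * hiddenSize;

    // Two GPU buffers: while buffers[cur] is consumed on computeStream,
    // the next chunk can be copied into buffers[1 - cur] on copyStream.
    float* buffers[2] = {nullptr, nullptr};
    cudaMalloc(&buffers[0], chunkElems * sizeof(float));
    cudaMalloc(&buffers[1], chunkElems * sizeof(float));

    cudaEvent_t copyDone[2], computeDone[2];
    for (int i = 0; i < 2; ++i)
    {
        cudaEventCreate(&copyDone[i]);
        cudaEventCreate(&computeDone[i]);
    }

    std::size_t cur = 0;
    for (std::size_t row = 0; row < totalRows; row += chunkRows)
    {
        std::size_t const rows = std::min(chunkRows, totalRows - row);

        // Do not overwrite a buffer until the compute that last read it has finished.
        cudaStreamWaitEvent(copyStream, computeDone[cur], 0);

        // Async H2D copy from pinned memory, so it can overlap with compute on the other buffer.
        cudaMemcpyAsync(buffers[cur], pinnedHostPtable + row * hiddenSize,
            rows * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, copyStream);
        cudaEventRecord(copyDone[cur], copyStream);

        // Compute waits only for its own buffer's copy, then processes the chunk.
        cudaStreamWaitEvent(computeStream, copyDone[cur], 0);
        runChunk(buffers[cur], rows, computeStream);
        cudaEventRecord(computeDone[cur], computeStream);

        cur = 1 - cur; // toggle buffers, analogous to switchChunkPtableBuffer()
    }

    cudaStreamSynchronize(computeStream);
    for (int i = 0; i < 2; ++i)
    {
        cudaEventDestroy(copyDone[i]);
        cudaEventDestroy(computeDone[i]);
    }
    cudaFree(buffers[0]);
    cudaFree(buffers[1]);
}
```

In the class above, switchChunkPtableBuffer() plays the role of the `cur = 1 - cur` toggle, while updateBufferStartPosition() and getChunkPtableBufferStartPosition() track where each batch's rows begin inside the active buffer, as described in the comment block.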

cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h

Lines changed: 6 additions & 4 deletions
@@ -135,6 +135,10 @@ class RuntimeBuffers
 
 public:
     TensorPtr sequenceLengthsDevice;
+    bool promptTableOffloading;
+
+    //! Prompt-Tuning
+    std::unique_ptr<PromptTuningBuffers> promptTuningBuffers;
 
 private:
     //! Runtime

@@ -148,9 +152,6 @@ class RuntimeBuffers
     //! Pipeline-Parallelism
     TensorPtr hiddenStates;
 
-    //! Prompt-Tuning
-    std::unique_ptr<PromptTuningBuffers> promptTuningBuffers;
-
     //! Mrope
     TensorPtr mropeRotaryCosSin;
     TensorPtr mropePositionDeltas;

@@ -259,7 +260,8 @@ class RuntimeBuffers
         runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
         runtime::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig,
         bool gatherGenerationLogits, std::optional<SizeType32> maxNumTokens = std::nullopt,
-        std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs = std::nullopt);
+        std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs = std::nullopt,
+        bool promptTableOffloading = false);
 
     RuntimeBuffers(RuntimeBuffers const& other) = delete;
     RuntimeBuffers& operator=(RuntimeBuffers const& other) = delete;

cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h

Lines changed: 4 additions & 1 deletion
@@ -53,7 +53,7 @@ class TrtGptModelOptionalParams
         std::optional<executor::GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
         bool isLeaderInOrchMode = false,
         std::optional<std::vector<executor::AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
-        bool gatherGenerationLogits = false)
+        bool gatherGenerationLogits = false, bool promptTableOffloading = false)
         : kvCacheConfig{std::move(kvCacheConfig)}
         , enableTrtOverlap{enableTrtOverlap}
         , deviceIds(std::move(deviceIds))

@@ -75,6 +75,7 @@ class TrtGptModelOptionalParams
         , isLeaderInOrchMode{isLeaderInOrchMode}
         , additionalModelOutputs{std::move(additionalModelOutputs)}
         , gatherGenerationLogits{gatherGenerationLogits}
+        , promptTableOffloading{promptTableOffloading}
     {
         if (guidedDecodingConfig)
         {

@@ -125,6 +126,8 @@ class TrtGptModelOptionalParams
     bool isLeaderInOrchMode;
     std::optional<std::vector<executor::AdditionalModelOutput>> additionalModelOutputs;
     bool gatherGenerationLogits;
+    // Whether to offload the prompt table to CPU and prefetch it to the GPU
+    bool promptTableOffloading;
 };
 
 } // namespace tensorrt_llm::batch_manager

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 7 additions & 1 deletion
@@ -1408,7 +1408,8 @@ class ExecutorConfig
         std::optional<SpeculativeDecodingConfig> specDecConfig = std::nullopt,
         std::optional<GuidedDecodingConfig> guidedDecodingConfig = std::nullopt,
         std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs = std::nullopt,
-        bool gatherGenerationLogits = false, bool useVariableBeamWidthSearch = false);
+        bool gatherGenerationLogits = false, bool useVariableBeamWidthSearch = false,
+        bool promptTableOffloading = false);
 
     [[nodiscard]] SizeType32 getMaxBeamWidth() const;
     [[nodiscard]] SchedulerConfig getSchedulerConfig() const;

@@ -1441,6 +1442,7 @@ class ExecutorConfig
     [[nodiscard]] std::optional<std::vector<AdditionalModelOutput>> getAdditionalModelOutputs() const;
     [[nodiscard]] bool getGatherGenerationLogits() const;
     [[nodiscard]] bool getUseVariableBeamWidthSearch() const;
+    [[nodiscard]] bool getPromptTableOffloading() const;
 
     void setMaxBeamWidth(SizeType32 maxBeamWidth);
     void setMaxBatchSize(SizeType32 maxBatchSize);

@@ -1468,6 +1470,7 @@ class ExecutorConfig
     void setAdditionalModelOutputs(std::vector<AdditionalModelOutput> const& additionalModelOutputs);
     void setGatherGenerationLogits(bool gatherGenerationLogits);
     void setUseVariableBeamWidthSearch(bool useVariableBeamWidthSearch);
+    void setPromptTableOffloading(bool promptTableOffloading);
 
 private:
     friend class Serialization;

@@ -1548,6 +1551,9 @@ class ExecutorConfig
 
     /// @brief Controls if Variable-Beam-Width-Search is enabled.
     bool mUseVariableBeamWidthSearch{false};
+
+    /// @brief Controls if prompt table offloading is enabled.
+    bool mPromptTableOffloading{false};
 };
 
 struct KVCacheCreatedData