
Commit 7d823ea

Add comments to explain CPU <-> GPU copies using pinned memory
Signed-off-by: Kate Cheng <[email protected]>
1 parent 8398ea3 commit 7d823ea

File tree

5 files changed: +60 -40 lines changed

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

+1
@@ -302,6 +302,7 @@ void LlmRequest::movePromptEmbeddingTableToGpu(runtime::BufferManager const& man
     {
         return;
     }
+
     TensorPtr gpuPromptEmbeddingTable = manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
     mPromptEmbeddingTable = gpuPromptEmbeddingTable;
 }

cpp/tensorrt_llm/batch_manager/promptTuningBuffers.cpp

+12-19
@@ -34,7 +34,6 @@ PromptTuningBuffers::PromptTuningBuffers(SizeType32 maxBatchSize, runtime::Buffe
     // vocabSize and mMaxPromptVocabSize
     mPromptTuningParams.vocabSize = manager.gpu(runtime::ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
     mMaxPromptVocabSize = maxPromptEmbeddingTableSize / maxBatchSize;
-    // optionalParams.enableChunkedContext || modelConfig.getContextFMHA()

     auto promptVocabSizeHost
         = runtime::BufferManager::pinned(runtime::ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
@@ -143,23 +142,18 @@ void PromptTuningBuffers::fill(RequestVector const& contextRequests, RequestVect

         std::optional<TensorPtr> optReqPromptEmbeddingTable = std::nullopt;
         std::optional<SizeType32> optReqPromptVocabSize = std::nullopt;
-        // If context chunk mode, the context chunk size would be less than the total number of tokens in the
-        // request This if statement is to check if the context chunk mode is enabled
-        if (batchIdx < numContextRequests)
+
+        if (mPromptTableOffloading)
         {
-            if (mPromptTableOffloading)
-            {
-                optReqPromptEmbeddingTable = getChunkPtableBuffer(getChunkPtableCurrentIndex());
-                optReqPromptVocabSize = getChunkPtableBufferSliceSize(getChunkPtableCurrentIndex(), batchIdx);
-            }
-            else
-            {
-                optReqPromptEmbeddingTable = llmReq->getPromptEmbeddingTable();
-                optReqPromptVocabSize = llmReq->getPromptVocabSize();
-            }
+            optReqPromptEmbeddingTable = getChunkPtableBuffer(getChunkPtableCurrentIndex());
+            optReqPromptVocabSize = getChunkPtableBufferSliceSize(getChunkPtableCurrentIndex(), batchIdx);
+        }
+        else
+        {
+            optReqPromptEmbeddingTable = llmReq->getPromptEmbeddingTable();
+            optReqPromptVocabSize = llmReq->getPromptVocabSize();
         }
-        // auto optReqPromptEmbeddingTable = llmReq->getPromptEmbeddingTable();
-        // auto const optReqPromptVocabSize = llmReq->getPromptVocabSize();
+
         mPromptTuningParams.promptTuningEnabled.push_back(optReqPromptEmbeddingTable.has_value());

         // If context request & has embedding table, validate it
@@ -174,9 +168,8 @@ void PromptTuningBuffers::fill(RequestVector const& contextRequests, RequestVect
                // The size depends on optReqPromptVocabSize which stores how many fake prompts are in the chunk
                auto slicedPtable = runtime::ITensor::slice(
                    optReqPromptEmbeddingTable.value(), 0, optReqPromptVocabSize.value());
-                // Add leading dimension 1 for batch
-                slicedPtable->unsqueeze(0); // Call unsqueeze() as member function
-                optReqPromptEmbeddingTable = std::move(slicedPtable); // Move ownership of the unique_ptr
+                slicedPtable->unsqueeze(0);
+                optReqPromptEmbeddingTable = std::move(slicedPtable);
            }
            else
            {
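
For reference, here is a minimal PyTorch sketch of the slice-and-unsqueeze step in the hunk above (tensor names and sizes are illustrative placeholders, not the TensorRT-LLM ITensor API): keep only the rows holding this chunk's fake-prompt embeddings, then add a leading batch dimension.

    import torch

    # Illustrative shapes: a staged chunk of the prompt table and the number of
    # fake-prompt rows this request actually uses (optReqPromptVocabSize above).
    hidden_size = 4096
    chunk_ptable = torch.empty(64, hidden_size)
    prompt_vocab_size = 40

    # Mirrors ITensor::slice(table, 0, promptVocabSize) followed by unsqueeze(0):
    # take the first prompt_vocab_size rows and prepend a batch dimension of 1.
    sliced_ptable = chunk_ptable[:prompt_vocab_size].unsqueeze(0)
    assert sliced_ptable.shape == (1, prompt_vocab_size, hidden_size)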

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

+38-19
@@ -2615,6 +2615,36 @@ SizeType32 TrtGptModelInflightBatching::getMaxCapacityBatchSize(SizeType32 input
     return mKvCacheManager->getMaxCapacityBatchSize(inputLength, outputLength);
 }

+/*
+ * Manages prefetching of prompt table chunks using a double-buffer strategy
+ *
+ * Function Flow:
+ * 1. First Chunk Processing (isBeforePrepareBuffers == true):
+ *    - Uses blocking prefetch on main runtime stream
+ *    - Ensures initial data is ready before computation starts
+ *
+ * 2. Subsequent Chunks (isBeforePrepareBuffers == false):
+ *    - Uses non-blocking prefetch on separate copy stream
+ *    - Overlaps data transfer with computation
+ *
+ * Synchronization:
+ * - First prefetch: No wait needed (fresh start)
+ * - Later prefetches: Wait for previous copy to complete
+ * - Uses mPtableCopyDoneEvent to track completion
+ *
+ * Key Functions:
+ * 1. prefetchNextPromptTableChunk:
+ *    - Calls the correct function based on position in code (before or after prepareBuffers())
+ *    - Waits for previous copy to complete if not the first chunk
+ *
+ * 2. remapInputTokensForPromptTable:
+ *    - Identifies tokens that need prompt table embeddings (tokens that are greater than vocabSize)
+ *    - Remaps IDs to match chunked prompt table layout
+ *
+ * 3. copyPromptTableToGpuInChunk:
+ *    - Handles actual transfer from CPU pinned memory to GPU
+ *    - Uses appropriate buffer manager based on isBeforePrepareBuffers
+ */
 void TrtGptModelInflightBatching::prefetchNextPromptTableChunk(
     RequestVector const& contextRequests, bool isBeforePrepareBuffers, SizeType32 bufferId)
 {
@@ -2663,7 +2693,6 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
     auto& inputTokensMutable = llmReq->getTokensMutable(0);
     auto vocabSize = mModelConfig.getVocabSize();

-    // For first chunk's initialization
     if (isBeforePrepareBuffers)
     {
         promptTuningBuffers->initializeChunkPtableBuffers(
@@ -2698,15 +2727,10 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
         beginPos = llmReq->getContextCurrentPosition();
     }

-    // Bounds check
-    if (beginPos + processChunkSize > inputTokensMutable.size())
-    {
-        TLLM_THROW("Invalid chunk access: beginPos(%zu) + processChunkSize(%zu) > totalSize(%zu)", beginPos,
-            processChunkSize, inputTokensMutable.size());
-        return;
-    }
+    TLLM_CHECK_WITH_INFO(beginPos + processChunkSize <= inputTokensMutable.size(),
+        "Invalid chunk access: beginPos(%zu) + processChunkSize(%zu) > totalSize(%zu)", beginPos, processChunkSize,
+        inputTokensMutable.size());

-    // Process tokens
     auto inputTokensChunk = inputTokensMutable.begin() + beginPos;
     std::vector<SizeType32> outOfVocabTokens;
     SizeType32 ptableTokenId = vocabSize;
@@ -2724,7 +2748,7 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(

 void TrtGptModelInflightBatching::copyPromptTableToGpuInChunk(std::shared_ptr<LlmRequest> const& llmReq,
     std::vector<int32_t> const& outOfVocabTokens, bool isBeforePrepareBuffers, SizeType32 bufferId,
-    SizeType32 contextId) // Add parameter to choose which buffer to use
+    SizeType32 contextId)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
     NVTX3_SCOPED_RANGE_WITH_NAME(range, "copyPromptTableToGpuInChunk");
@@ -2756,25 +2780,20 @@ void TrtGptModelInflightBatching::copyPromptTableToGpuInChunk(std::shared_ptr<Ll
     auto table1D = runtime::ITensor::view(
         promptTable.value(), runtime::ITensor::makeShape({static_cast<int64_t>(totalElements)}));

-    // Add bounds checking
-    if (srcOffset + sliceSize > totalElements)
-    {
-        printf("ERROR: Would access beyond buffer bounds!\n");
-        printf("Total elements: %zu, Trying to access up to: %zu\n", totalElements, srcOffset + (sliceSize));
-    }
+    TLLM_CHECK_WITH_INFO(srcOffset + sliceSize <= totalElements,
+        "Buffer bounds violation: Trying to access up to %zu elements but buffer only has %zu elements (offset: %zu, "
+        "slice size: %zu)",
+        srcOffset + sliceSize, totalElements, srcOffset, sliceSize);

-    // Convert UniquePtr to SharedPtr
     auto table1DShared = runtime::ITensor::SharedPtr(table1D.release());
     auto pTableView = runtime::ITensor::slice(table1DShared, srcOffset, sliceSize);

     auto gpuBufferSlice = runtime::ITensor::slice(gpuBuffer, dstOffset, numRows);

     currentBufferManager.copy(*pTableView, *gpuBufferSlice);

-    // Update buffer sizes
     promptTuningBuffers->updateBufferStartPosition(currentIndex, outOfVocabTokens.size());

-    // Update position for next chunk
     llmReq->mPtableCurrentPosition += outOfVocabTokens.size();

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
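
To make the comment block above concrete, here is a hedged PyTorch sketch of the same double-buffer prefetch flow (not the TensorRT-LLM C++ implementation; chunk sizes, buffer contents, and the compute step are placeholders). The first chunk is copied on the main stream before compute starts; each later chunk is copied on a dedicated copy stream and guarded by an event that plays the role of mPtableCopyDoneEvent.

    import torch

    assert torch.cuda.is_available()  # the sketch needs a CUDA device

    device = torch.device('cuda')
    main_stream = torch.cuda.current_stream(device)
    copy_stream = torch.cuda.Stream(device)   # non-default stream used only for H2D copies
    copy_done = torch.cuda.Event()            # stands in for mPtableCopyDoneEvent

    hidden_size = 4096
    # Prompt-table chunks staged in page-locked (pinned) host memory.
    host_chunks = [torch.randn(256, hidden_size).pin_memory() for _ in range(4)]
    # Two GPU buffers: one is consumed while the other is being filled.
    gpu_buffers = [torch.empty(256, hidden_size, device=device) for _ in range(2)]

    for i, chunk in enumerate(host_chunks):
        buf = gpu_buffers[i % 2]
        if i == 0:
            # First chunk (isBeforePrepareBuffers == true): blocking copy on the
            # main stream so the data is ready before computation starts.
            buf.copy_(chunk)
        else:
            # Later chunks: compute must wait until the async prefetch has landed.
            main_stream.wait_event(copy_done)

        if i + 1 < len(host_chunks):
            next_buf = gpu_buffers[(i + 1) % 2]
            with torch.cuda.stream(copy_stream):
                # Do not overwrite the other buffer while compute may still read it.
                copy_stream.wait_stream(main_stream)
                next_buf.copy_(host_chunks[i + 1], non_blocking=True)
                copy_done.record(copy_stream)

        _ = buf.sum()  # placeholder for the model step that consumes this chunk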

tensorrt_llm/runtime/model_runner_cpp.py

+3
@@ -834,6 +834,9 @@ def _prepare_ptuning_executor(self, batch_input_ids_list, prompt_table,
         prompt_tuning_configs = len(batch_input_ids_list) * [None]
         if prompt_table is not None:
             if mm_embedding_offloading:
+                # CUDA Stream Overlapping Requirements:
+                # 1. Both memory copy stream and kernel execution stream must be non-default streams
+                # 2. For host<->device transfers (H2D/D2H), host memory MUST be page-locked (pinned)
                 prompt_table_data = self._prepare_embedding_table(
                     prompt_table).pin_memory()
             else:
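
The comment added here lists the two preconditions for overlapping this copy with kernel execution. A minimal, hedged PyTorch sketch of the same idea (tensor shapes and variable names are made up, not ModelRunnerCpp internals):

    import torch

    assert torch.cuda.is_available()

    copy_stream = torch.cuda.Stream()                        # 1. non-default stream
    prompt_table_cpu = torch.randn(1024, 4096).pin_memory()  # 2. page-locked host memory

    with torch.cuda.stream(copy_stream):
        # With both requirements met, this H2D transfer can overlap kernels enqueued
        # on other streams; with pageable host memory it cannot overlap.
        prompt_table_gpu = prompt_table_cpu.to('cuda', non_blocking=True)

    # Make other work on the current stream wait until the transfer has finished.
    torch.cuda.current_stream().wait_stream(copy_stream)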

tensorrt_llm/runtime/multimodal_model_runner.py

+6-2
@@ -1379,11 +1379,12 @@ def get_visual_features(self, image, other_vision_inputs):
         image_embeds = visual_outputs[self.vision_output_names[0]]

         if self.args.mm_embedding_offloading:
-            # Allocate pinned memory with same shape and dtype
+            # CUDA Stream Overlapping Requirements:
+            # 1. Both memory copy stream and kernel execution stream must be non-default streams
+            # 2. For host<->device transfers (H2D/D2H), host memory MUST be page-locked (pinned)
             pinned_embeds = torch.empty_like(image_embeds,
                                              device='cpu',
                                              pin_memory=True)
-            # Copy directly from GPU to pinned memory
             pinned_embeds.copy_(image_embeds, non_blocking=True)
             image_embeds = pinned_embeds

@@ -1823,6 +1824,9 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths):
                 dtype=str_dtype_to_torch(self.model_config.dtype))
         else:
             if self.args.mm_embedding_offloading:
+                # CUDA Stream Overlapping Requirements:
+                # 1. Both memory copy stream and kernel execution stream must be non-default streams
+                # 2. For host<->device transfers (H2D/D2H), host memory MUST be page-locked (pinned)
                 prompt_table = prompt_table.pin_memory().to(
                     dtype=self.model.dtype)
             else:
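
The first hunk above offloads GPU visual embeddings into pinned host memory; below is a small hedged sketch of that device-to-host pattern in isolation (a placeholder tensor stands in for the real visual features). Note that a non-blocking D2H copy returns before the data has landed, so the CPU must synchronize before reading it.

    import torch

    assert torch.cuda.is_available()

    image_embeds = torch.randn(1, 576, 4096, device='cuda')  # placeholder visual features

    # A pinned destination with the same shape/dtype lets the D2H copy run asynchronously.
    pinned_embeds = torch.empty_like(image_embeds, device='cpu', pin_memory=True)
    pinned_embeds.copy_(image_embeds, non_blocking=True)

    # The copy may still be in flight; synchronize before the CPU reads the data.
    torch.cuda.current_stream().synchronize()
    image_embeds = pinned_embeds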
