@@ -2615,6 +2615,36 @@ SizeType32 TrtGptModelInflightBatching::getMaxCapacityBatchSize(SizeType32 input
2615
2615
return mKvCacheManager ->getMaxCapacityBatchSize (inputLength, outputLength);
2616
2616
}
2617
2617
2618
+ /*
2619
+ * Manages prefetching of prompt table chunks using a double-buffer strategy
2620
+ *
2621
+ * Function Flow:
2622
+ * 1. First Chunk Processing (isBeforePrepareBuffers == true):
2623
+ * - Uses blocking prefetch on main runtime stream
2624
+ * - Ensures initial data is ready before computation starts
2625
+ *
2626
+ * 2. Subsequent Chunks (isBeforePrepareBuffers == false):
2627
+ * - Uses non-blocking prefetch on separate copy stream
2628
+ * - Overlaps data transfer with computation
2629
+ *
2630
+ * Synchronization:
2631
+ * - First prefetch: No wait needed (fresh start)
2632
+ * - Later prefetches: Wait for previous copy to complete
2633
+ * - Uses mPtableCopyDoneEvent to track completion
2634
+ *
2635
+ * Key Functions:
2636
+ * 1. prefetchNextPromptTableChunk:
2637
+ * - Selects the blocking or non-blocking prefetch path depending on whether it is called before or after prepareBuffers()
2638
+ * - Waits for previous copy to complete if not the first chunk
2639
+ *
2640
+ * 2. remapInputTokensForPromptTable:
2641
+ * - Identifies tokens that need prompt table embeddings (token IDs greater than vocabSize)
2642
+ * - Remaps IDs to match chunked prompt table layout
2643
+ *
2644
+ * 3. copyPromptTableToGpuInChunk:
2645
+ * - Handles actual transfer from CPU pinned memory to GPU
2646
+ * - Uses appropriate buffer manager based on isBeforePrepareBuffers
2647
+ */
2618
2648
void TrtGptModelInflightBatching::prefetchNextPromptTableChunk (
2619
2649
RequestVector const & contextRequests, bool isBeforePrepareBuffers, SizeType32 bufferId)
2620
2650
{
@@ -2663,7 +2693,6 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
2663
2693
auto & inputTokensMutable = llmReq->getTokensMutable (0 );
2664
2694
auto vocabSize = mModelConfig .getVocabSize ();
2665
2695
2666
- // For first chunk's initialization
2667
2696
if (isBeforePrepareBuffers)
2668
2697
{
2669
2698
promptTuningBuffers->initializeChunkPtableBuffers (
@@ -2698,15 +2727,10 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
2698
2727
beginPos = llmReq->getContextCurrentPosition ();
2699
2728
}
2700
2729
2701
- // Bounds check
2702
- if (beginPos + processChunkSize > inputTokensMutable.size ())
2703
- {
2704
- TLLM_THROW (" Invalid chunk access: beginPos(%zu) + processChunkSize(%zu) > totalSize(%zu)" , beginPos,
2705
- processChunkSize, inputTokensMutable.size ());
2706
- return ;
2707
- }
2730
+ TLLM_CHECK_WITH_INFO (beginPos + processChunkSize <= inputTokensMutable.size (),
2731
+ " Invalid chunk access: beginPos(%zu) + processChunkSize(%zu) > totalSize(%zu)" , beginPos, processChunkSize,
2732
+ inputTokensMutable.size ());
2708
2733
2709
- // Process tokens
2710
2734
auto inputTokensChunk = inputTokensMutable.begin () + beginPos;
2711
2735
std::vector<SizeType32> outOfVocabTokens;
2712
2736
SizeType32 ptableTokenId = vocabSize;
@@ -2724,7 +2748,7 @@ void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
2724
2748
2725
2749
void TrtGptModelInflightBatching::copyPromptTableToGpuInChunk (std::shared_ptr<LlmRequest> const & llmReq,
2726
2750
std::vector<int32_t > const & outOfVocabTokens, bool isBeforePrepareBuffers, SizeType32 bufferId,
2727
- SizeType32 contextId) // Add parameter to choose which buffer to use
2751
+ SizeType32 contextId)
2728
2752
{
2729
2753
TLLM_LOG_TRACE (" %s start" , __PRETTY_FUNCTION__);
2730
2754
NVTX3_SCOPED_RANGE_WITH_NAME (range, " copyPromptTableToGpuInChunk" );
@@ -2756,25 +2780,20 @@ void TrtGptModelInflightBatching::copyPromptTableToGpuInChunk(std::shared_ptr<Ll
2756
2780
auto table1D = runtime::ITensor::view (
2757
2781
promptTable.value (), runtime::ITensor::makeShape ({static_cast <int64_t >(totalElements)}));
2758
2782
2759
- // Add bounds checking
2760
- if (srcOffset + sliceSize > totalElements)
2761
- {
2762
- printf (" ERROR: Would access beyond buffer bounds!\n " );
2763
- printf (" Total elements: %zu, Trying to access up to: %zu\n " , totalElements, srcOffset + (sliceSize));
2764
- }
2783
+ TLLM_CHECK_WITH_INFO (srcOffset + sliceSize <= totalElements,
2784
+ " Buffer bounds violation: Trying to access up to %zu elements but buffer only has %zu elements (offset: %zu, "
2785
+ " slice size: %zu)" ,
2786
+ srcOffset + sliceSize, totalElements, srcOffset, sliceSize);
2765
2787
2766
- // Convert UniquePtr to SharedPtr
2767
2788
auto table1DShared = runtime::ITensor::SharedPtr (table1D.release ());
2768
2789
auto pTableView = runtime::ITensor::slice (table1DShared, srcOffset, sliceSize);
2769
2790
2770
2791
auto gpuBufferSlice = runtime::ITensor::slice (gpuBuffer, dstOffset, numRows);
2771
2792
2772
2793
currentBufferManager.copy (*pTableView, *gpuBufferSlice);
2773
2794
2774
- // Update buffer sizes
2775
2795
promptTuningBuffers->updateBufferStartPosition (currentIndex, outOfVocabTokens.size ());
2776
2796
2777
- // Update position for next chunk
2778
2797
llmReq->mPtableCurrentPosition += outOfVocabTokens.size ();
2779
2798
2780
2799
TLLM_LOG_TRACE (" %s stop" , __PRETTY_FUNCTION__);
0 commit comments