@@ -527,8 +527,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
     }
     // Get the output scale.
     // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
-    float outputScale
-        = SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
+    float outputScale = SFValue != 0 ? SFScaleVal * reciprocal_approximate_ftz(SFValue) : 0.0f;

     if (SFout)
     {
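Note that the two forms of outputScale are algebraically the same quantity, SFScaleVal / SFValue; the new expression simply drops one approximate reciprocal. A minimal host-side sanity check (a sketch only: plain C++, with exact division standing in for the reciprocal_approximate_ftz device intrinsic):

#include <cassert>
#include <cmath>
#include <cstdio>

// Stand-in for the device intrinsic; the real one is an approximate,
// flush-to-zero reciprocal, so results differ only by rounding.
static float reciprocal(float x) { return 1.0f / x; }

int main()
{
    float const SFScaleVal = 0.75f;
    float const testValues[] = {0.0f, 0.5f, 1.0f, 448.0f};
    for (float SFValue : testValues)
    {
        float oldScale = SFValue != 0 ? reciprocal(SFValue * reciprocal(SFScaleVal)) : 0.0f;
        float newScale = SFValue != 0 ? SFScaleVal * reciprocal(SFValue) : 0.0f;
        // Both forms reduce to SFScaleVal / SFValue (or 0 when SFValue == 0).
        assert(std::fabs(oldScale - newScale) <= 1e-6f * std::fabs(oldScale) + 1e-12f);
        std::printf("SFValue=%g old=%g new=%g\n", SFValue, oldScale, newScale);
    }
    return 0;
}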
@@ -557,6 +556,46 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
 #endif
 }

+inline __device__ int64_t get_sf_out_offset_128x4(
+    std::optional<int> batchIdx, int mIdx, int kIdx, std::optional<int> numRows, int numCols)
+{
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    // batched tensor
+    // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4 * innerKStride; // 4
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * innerMStride; // 16
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * outerMStride; // 512
+
+    // SF vector size 16. We round the "numCols" up to a multiple of 64.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int32_t mTileIdx = mIdx / (32 * 4);
+    int64_t mTileStride = numKTiles * kTileStride;
+
+    // Each SF block has 128 rows so pad rows to the multiple of 128.
+    int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
+    int64_t bTileStride = numMTiles * mTileStride;
+
+    // Compute the global offset.
+    int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
+        + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
+
+    return SFOffset;
+}
+
 template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
 __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, int colIdx,
     std::optional<int> numRows, int numCols, SFType* SFout, FP4QuantizationSFLayout layout)
@@ -576,40 +615,7 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
         int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
         int32_t mIdx = rowIdx;

-        // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        // batched tensor
-        // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-        // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-        int32_t innerKIdx = (kIdx % 4);
-        int64_t innerKStride = 1;
-
-        int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-        int64_t innerMStride = 4 * innerKStride; // 4
-
-        // M tile layout [32, 4] is column-major.
-        int32_t outerMIdx = (mIdx % 32);
-        int64_t outerMStride = 4 * innerMStride; // 16
-
-        int32_t kTileIdx = (kIdx / 4);
-        int64_t kTileStride = 32 * outerMStride; // 512
-
-        // SF vector size 16. We round the "numCols" up to a multiple of 64.
-        int factor = CVT_FP4_SF_VEC_SIZE * 4;
-        int32_t numKTiles = (numCols + factor - 1) / factor;
-        int32_t mTileIdx = mIdx / (32 * 4);
-        int64_t mTileStride = numKTiles * kTileStride;
-
-        // Each SF block has 128 rows so pad rows to the multiple of 128.
-        int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
-        int64_t bTileStride = numMTiles * mTileStride;
-
-        // Compute the global offset.
-        int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride
-            + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;
-
+        auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numCols);
         return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
     }
     else if (layout == FP4QuantizationSFLayout::LINEAR)
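To make the swizzled addressing concrete: get_sf_out_offset_128x4 maps an (mIdx, kIdx) pair into [numBTiles, numMTiles, numKTiles, 32, 4, 4] tiles, so within a 128-row tile consecutive rows are 16 bytes apart while rows 32 apart are only 4 bytes apart. Below is a host-side mirror of the same arithmetic (a sketch for illustration only; the SF vector size of 16 and the sample 256x1024 shape are assumptions, not taken from this diff):

#include <cstdint>
#include <cstdio>

int64_t sf_offset_128x4(int batchIdx, int mIdx, int kIdx, int numRows, int numCols)
{
    // Same tiling as the device function above: [numBTiles, numMTiles, numKTiles, 32, 4, 4].
    int32_t const innerKIdx = kIdx % 4;               // position within a group of 4 SFs along K
    int32_t const innerMIdx = (mIdx % (32 * 4)) / 32; // which of the 4 columns of the m-tile
    int32_t const outerMIdx = mIdx % 32;              // row within the column-major [32, 4] m-tile
    int32_t const kTileIdx = kIdx / 4;
    int32_t const mTileIdx = mIdx / (32 * 4);
    int32_t const numKTiles = (numCols + 63) / 64;    // columns padded to 4 SF vectors of 16 elements
    int32_t const numMTiles = (numRows + 127) / 128;  // rows padded to 128
    int64_t const kTileStride = 512;                  // 32 * 4 * 4
    int64_t const mTileStride = numKTiles * kTileStride;
    int64_t const bTileStride = numMTiles * mTileStride;
    return batchIdx * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * 16
        + innerMIdx * 4 + innerKIdx;
}

int main()
{
    // First scale factor of rows 0, 1, 31, 32 and 128 in a 256x1024 tensor.
    int const rows[] = {0, 1, 31, 32, 128};
    for (int m : rows)
    {
        std::printf("row %3d, SF 0 -> byte offset %lld\n", m,
            (long long) sf_offset_128x4(/*batchIdx=*/0, m, /*kIdx=*/0, 256, 1024));
    }
    return 0;
}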
@@ -819,5 +825,7 @@ cvt_fp8_to_fp4(
 #endif
 }

+__global__ void nvfp4_block_scale_interleave_kernel(
+    int numbatches, int numRows, int numCols, uint8_t const* SFIn, uint8_t* SFOutput);
 } // namespace kernels
 } // namespace tensorrt_llm
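Only the declaration of nvfp4_block_scale_interleave_kernel is added here; its definition lives elsewhere. For orientation, a hedged usage sketch follows: the wrapper function, the launch configuration, and the buffer semantics are assumptions for illustration, not taken from this header.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical wrapper (not part of the header above): launches the declared
// kernel with a guessed grid/block shape; the real launch parameters are chosen
// in the kernel's .cu implementation.
void launch_block_scale_interleave(
    int numBatches, int numRows, int numCols, uint8_t const* dSfIn, uint8_t* dSfOut, cudaStream_t stream)
{
    dim3 const block(256);
    dim3 const grid((numRows + block.x - 1) / block.x, numBatches);
    tensorrt_llm::kernels::nvfp4_block_scale_interleave_kernel<<<grid, block, 0, stream>>>(
        numBatches, numRows, numCols, dSfIn, dSfOut);
}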