Commit 8cf28e8

CI clean

Signed-off-by: Zongfei Jing <[email protected]>
1 parent 77133b0 commit 8cf28e8

File tree

6 files changed: +25 -15 lines changed


cpp/tensorrt_llm/thop/fp4Op.cpp

+5 -3

@@ -144,7 +144,7 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, tensor
 }

 torch::autograd::variable_list FloatToE2M1AndUFP8SFScale(
-    th::Tensor floatTensor, int64_t sfVecSize, int64_t sfType, bool isSfSwizzledLayout = true)
+    th::Tensor floatTensor, int64_t sfVecSize, int64_t sfType, torch::optional<bool> isSfSwizzledLayout)
 {
     CHECK_CPU_INPUT(floatTensor, th::kFloat32);
     auto inputShape = floatTensor.sizes();
@@ -160,8 +160,10 @@ torch::autograd::variable_list FloatToE2M1AndUFP8SFScale(
     int packedFp4HiddenDim = hiddenDim / 2;
     int groupsPerHiddenDim = hiddenDim / sfVecSize;

-    tensorrt_llm::FP4QuantizationSFLayout layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
-                                                                      : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
+    // Note: if isSfSwizzledLayout is provided, use its value; otherwise default to true.
+    tensorrt_llm::FP4QuantizationSFLayout layout = isSfSwizzledLayout.value_or(true)
+        ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
+        : tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

     for (size_t vIdx = 0; vIdx < static_cast<size_t>(inputShape[0]); ++vIdx)
     {
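The torch::optional change makes the layout flag optional at the op boundary. A minimal sketch of the resulting call-side behavior, assuming the op schema registers the parameter as an optional bool so Python may pass None; the tensor shape and sf_vec_size below are illustrative:

    import torch

    x = torch.randn(8, 64, dtype=torch.float32)  # CPU float32, per CHECK_CPU_INPUT

    # None falls through to value_or(true) on the C++ side -> SWIZZLED layout.
    swizzled = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(x, 16, 1, None)

    # An explicit False selects the LINEAR scale-factor layout instead.
    linear = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(x, 16, 1, False)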

tensorrt_llm/_torch/models/modeling_deepseekv3.py

+2 -10

@@ -32,7 +32,6 @@

 import torch
 import torch.nn.functional as F
-from examples.infinitebench import args
 import triton
 import triton.language as tl
 from torch import nn
@@ -269,13 +268,9 @@ def __init__(
         topk_group: int,
         routed_scaling_factor: float,
         dtype: Optional[torch.dtype] = None,
-<<<<<<< HEAD
         fuse_routing_kernel: bool = True,
         apply_routing: bool = False,
-=======
-        is_thop: bool = True,
         moe_backend: str = 'CUTLASS',
->>>>>>> 14626789cf (Add TRT-LLM Gen MOE to Deepseek)
     ):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
@@ -358,12 +353,9 @@ def __init__(self,
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
             dtype=dtype,
-<<<<<<< HEAD
             fuse_routing_kernel=True,
-            apply_routing=False)
-=======
+            apply_routing=False,
             moe_backend=model_config.moe_backend)
->>>>>>> 14626789cf (Add TRT-LLM Gen MOE to Deepseek)
         self.experts = FusedMoE(
             num_experts=num_experts,
             routing_method=self.gate.routing_method,
@@ -602,7 +594,7 @@ def forward(
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_attn_allreduce),
-            **args,
+            **kwargs,
         )

         # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
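The **args to **kwargs change in the last hunk fixes a real failure mode: `args` was the module accidentally imported from examples.infinitebench (removed in the first hunk), and ** unpacking requires a mapping. A self-contained illustration of the bug class; the names below are hypothetical, not from the repo:

    import types

    args = types.ModuleType("args")   # stand-in for the accidental module import
    kwargs = {"position_ids": None}   # an ordinary keyword dict

    def call(**kw):
        return kw

    call(**kwargs)  # fine
    try:
        call(**args)  # TypeError: argument after ** must be a mapping
    except TypeError as e:
        print(e)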

tensorrt_llm/models/modeling_utils.py

+8 -0

@@ -221,6 +221,14 @@ def _get_modelopt_kv_cache_dtype(self):
         return None

     def is_module_excluded_from_quantization(self, name: str) -> bool:
+        """Check if the module is excluded from quantization.
+
+        Args:
+            name (str): The name of the module.
+
+        Returns:
+            bool: True if the module is excluded from quantization, False otherwise.
+        """
         if self.exclude_modules is not None:
             for exclude_module in self.exclude_modules:
                 if fnmatch.fnmatchcase(name, exclude_module):
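The exclusion check relies on shell-style globbing via the standard-library fnmatch module. An illustrative standalone check; the module names and patterns here are hypothetical:

    import fnmatch

    exclude_modules = ["lm_head", "transformer.layers.*.mlp"]
    name = "transformer.layers.0.mlp"

    # fnmatchcase is case-sensitive and translates * to the regex .*,
    # so the wildcard also matches across the dots in a module path.
    excluded = any(fnmatch.fnmatchcase(name, p) for p in exclude_modules)
    print(excluded)  # True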

tests/unittest/_torch/test_fp4_gemm_quantize.py

+1 -0

@@ -15,6 +15,7 @@

 import unittest

+import pytest
 import torch
 from parameterized import parameterized
 from utils.util import skip_pre_blackwell_unittest, unittest_name_func

tests/unittest/api_stability/references/quant_config.yaml

+6 -0

@@ -38,4 +38,10 @@ methods:
   to_dict:
     parameters: {}
    return_annotation: dict
+  is_module_excluded_from_quantization:
+    parameters:
+      name:
+        annotation: str
+        default: inspect._empty
+    return_annotation: bool
 properties: {}

tests/unittest/trt/functional/test_fp4_gemm.py

+3 -2

@@ -28,9 +28,10 @@
 # ufp8_type: 0 for ue8m0, 1 for ue4m3
 def float_tensor_to_e2m1_and_ufp8_scale(float_tensor: torch.Tensor,
                                         sf_vec_size,
-                                        ufp8_type: int = 1):
+                                        ufp8_type: int = 1,
+                                        is_sf_swizzled_layout: bool = True):
     value_e2m1, scale_ufp8, rep_float = torch.ops.tensorrt_llm.float_to_e2m1_and_ufp8sf_scale(
-        float_tensor, sf_vec_size, ufp8_type)
+        float_tensor, sf_vec_size, ufp8_type, is_sf_swizzled_layout)
     return value_e2m1, scale_ufp8, rep_float
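With the wrapper now threading the layout flag through, a hypothetical parametrized test (not part of this commit) could cover both layouts; it assumes the wrapper above, a loaded tensorrt_llm op library, and illustrative shapes:

    import pytest
    import torch

    @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
    def test_float_to_e2m1_layouts(is_sf_swizzled_layout):
        x = torch.randn(8, 64, dtype=torch.float32)
        value_e2m1, scale_ufp8, rep_float = float_tensor_to_e2m1_and_ufp8_scale(
            x, sf_vec_size=16, is_sf_swizzled_layout=is_sf_swizzled_layout)
        # Packed FP4 halves the hidden dimension regardless of the SF layout.
        assert value_e2m1.shape[-1] == x.shape[-1] // 2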
