Skip to content

Commit 37b5b67

Browse files
committed
fix: fix cublas_scaled_mm
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 41ce544 commit 37b5b67

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tests/unittest/_torch/thop/test_scaled_mm.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@
3939
)
4040
def test_fp8_scaled_mm(output_dtype, m, k_n):
4141
# Skip specific problematic case
42-
if m == 228 and k_n == (28672, 8192):
43-
pytest.skip("Skipping problematic case with m=228, k=28672, n=8192")
4442

4543
k, n = k_n
4644
torch.random.manual_seed(0)
@@ -50,6 +48,10 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
5048
w = torch.rand(shape_w, device="cuda").to(torch.float8_e4m3fn)
5149
scale_x = torch.rand(1, device="cuda")
5250
scale_w = torch.rand(1, device="cuda")
51+
if (m == 12 or m == 228) and k_n == (28672, 8192):
52+
from torch.profiler import ProfilerActivity, profile
53+
p = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA
54+
]).__enter__()
5355
output = torch.ops.trtllm.cublas_scaled_mm(
5456
x,
5557
w.t(),
@@ -60,7 +62,7 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
6062
)
6163
# set pytorch's cublas workspace size to 32MB to be aligned with trtllm
6264
old_env = os.environ.get("CUBLASLT_WORKSPACE_SIZE", "")
63-
os.environ["CUBLASLT_WORKSPACE_SIZE"] = f"{32*1024*1024}"
65+
os.environ["CUBLASLT_WORKSPACE_SIZE"] = f"{32*1024}"
6466
ref = torch._scaled_mm(
6567
x,
6668
w.t(),
@@ -69,6 +71,10 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
6971
scale_b=scale_w,
7072
use_fast_accum=True,
7173
)
74+
if (m == 12 or m == 228) and k_n == (28672, 8192):
75+
p.__exit__(None, None, None)
76+
warn(p.key_averages().table(sort_by="self_cuda_time_total",
77+
row_limit=-1))
7278
os.environ["CUBLASLT_WORKSPACE_SIZE"] = old_env
7379
np.testing.assert_allclose(ref.float().cpu(), output.float().cpu())
7480

0 commit comments

Comments (0)