 )
 def test_fp8_scaled_mm(output_dtype, m, k_n):
     # Skip specific problematic case
-    if m == 228 and k_n == (28672, 8192):
-        pytest.skip("Skipping problematic case with m=228, k=28672, n=8192")
 
     k, n = k_n
     torch.random.manual_seed(0)
@@ -50,6 +48,10 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
     w = torch.rand(shape_w, device="cuda").to(torch.float8_e4m3fn)
     scale_x = torch.rand(1, device="cuda")
     scale_w = torch.rand(1, device="cuda")
+    if (m == 12 or m == 228) and k_n == (28672, 8192):
+        from torch.profiler import ProfilerActivity, profile
+        p = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA
+                                ]).__enter__()
     output = torch.ops.trtllm.cublas_scaled_mm(
         x,
         w.t(),
@@ -60,7 +62,7 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
     )
     # set pytorch's cublas workspace size to 32MB to be aligned with trtllm
     old_env = os.environ.get("CUBLASLT_WORKSPACE_SIZE", "")
-    os.environ["CUBLASLT_WORKSPACE_SIZE"] = f"{32 * 1024 * 1024}"
+    os.environ["CUBLASLT_WORKSPACE_SIZE"] = f"{32 * 1024}"
     ref = torch._scaled_mm(
         x,
         w.t(),
@@ -69,6 +71,10 @@ def test_fp8_scaled_mm(output_dtype, m, k_n):
         scale_b=scale_w,
         use_fast_accum=True,
     )
+    if (m == 12 or m == 228) and k_n == (28672, 8192):
+        p.__exit__(None, None, None)
+        warn(p.key_averages().table(sort_by="self_cuda_time_total",
+                                    row_limit=-1))
     os.environ["CUBLASLT_WORKSPACE_SIZE"] = old_env
     np.testing.assert_allclose(ref.float().cpu(), output.float().cpu())
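For reference, the lines added above drive torch.profiler through explicit __enter__()/__exit__() so that the profiled region can span both the trtllm op and the torch._scaled_mm reference without re-indenting the test body. Below is a minimal standalone sketch of the same profiling pattern in the usual context-manager form; the shapes come from the problematic case in the test, and torch._scaled_mm stands in for trtllm's cublas_scaled_mm, which is not assumed to be available outside the test environment.

    # Minimal sketch of the profiling added in the diff, using the
    # context-manager form of torch.profiler. Assumptions: a CUDA build of
    # PyTorch with fp8 support; torch._scaled_mm stands in for the trtllm
    # cublas_scaled_mm op.
    import torch
    from warnings import warn
    from torch.profiler import ProfilerActivity, profile

    m, k, n = 228, 28672, 8192  # the problematic shape from the test
    x = torch.rand((m, k), device="cuda").to(torch.float8_e4m3fn)
    w = torch.rand((n, k), device="cuda").to(torch.float8_e4m3fn)
    scale_x = torch.rand(1, device="cuda")
    scale_w = torch.rand(1, device="cuda")

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
        out = torch._scaled_mm(x, w.t(),
                               scale_a=scale_x, scale_b=scale_w,
                               out_dtype=torch.bfloat16,
                               use_fast_accum=True)
        torch.cuda.synchronize()  # let the kernels finish inside the profiled region

    # Same reporting call as the diff: dump per-kernel CUDA time as a warning.
    warn(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

One difference worth noting: the context-manager form always stops the profiler when the block exits, whereas the explicit __exit__ call in the diff only runs if no exception is raised between the two GEMM calls.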