Skip to content

Commit 4196d0b

Browse files
committed
ut
1 parent 62e8dba commit 4196d0b

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

test/microbench/scatter.gather.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
from torch.profiler import profile, ProfilerActivity
3+
4+
device = "xpu"
5+
backward = True
6+
7+
# Define shapes for scatter/gather testing
8+
# (input_shape, index_shape, dim_to_scatter_gather)
9+
shape_list = [
10+
((4096, 8192), (4096, 8192), 1), # Simple 2D case
11+
((2, 4096, 320), (2, 4096, 1), 2), # Scatter/Gather along the last dim
12+
((512, 3136, 128), (512, 1, 128), 1), # Scatter/Gather along the middle dim
13+
((128, 49, 196, 1024), (128, 49, 196, 1), 3), # 4D case, scatter/gather last dim
14+
]
15+
16+
for shape_config in shape_list:
17+
input_shape, index_shape, dim_to_operate = shape_config
18+
19+
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
20+
# Generate input tensor
21+
input_tensor = torch.randn(input_shape, device=device, dtype=dtype)
22+
23+
# Generate index tensor for gather/scatter
24+
# Ensure indices are within valid bounds for the dimension
25+
max_idx_val = input_tensor.shape[dim_to_operate]
26+
index_tensor = torch.randint(0, max_idx_val, index_shape, device=device, dtype=torch.int64)
27+
28+
# Generate source tensor for scatter
29+
# Its shape should match index_tensor in the dimension being scattered into,
30+
# and input_tensor in other dimensions.
31+
scatter_source_shape = list(input_tensor.shape)
32+
for i, dim_size in enumerate(index_shape):
33+
if i == dim_to_operate:
34+
scatter_source_shape[i] = dim_size
35+
scatter_source = torch.randn(scatter_source_shape, device=device, dtype=dtype)
36+
37+
if backward:
38+
input_tensor.requires_grad_(True)
39+
scatter_source.requires_grad_(True)
40+
41+
# Warm-up phase
42+
# Gather operation
43+
gathered_output_warmup = torch.gather(input_tensor, dim_to_operate, index_tensor)
44+
if backward:
45+
gy_gather = torch.empty_like(gathered_output_warmup)
46+
gathered_output_warmup.backward(gy_gather)
47+
48+
# Scatter operation (using out-of-place scatter_ to ensure a fresh tensor for profiling)
49+
scattered_output_warmup = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
50+
if backward:
51+
gy_scatter = torch.empty_like(scattered_output_warmup)
52+
scattered_output_warmup.backward(gy_scatter)
53+
54+
print(
55+
"---"
56+
)
57+
print(
58+
"Testing Scatter/Gather -- input shape:",
59+
input_shape,
60+
"; index shape:",
61+
index_shape,
62+
"; datatype:",
63+
dtype,
64+
"; dim:",
65+
dim_to_operate,
66+
"; backward:",
67+
backward,
68+
)
69+
print(
70+
"---"
71+
)
72+
73+
# Profiling phase
74+
with profile(
75+
activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True
76+
) as prof:
77+
for i in range(20):
78+
# Gather operation
79+
gathered_output = torch.gather(input_tensor, dim_to_operate, index_tensor)
80+
if backward:
81+
gy_gather = torch.empty_like(gathered_output)
82+
gathered_output.backward(gy_gather)
83+
84+
# Scatter operation
85+
# We clone input_tensor each time to avoid modifying the same tensor
86+
# across iterations, which could affect profiling if in-place ops are used.
87+
scattered_output = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
88+
if backward:
89+
gy_scatter = torch.empty_like(scattered_output)
90+
scattered_output.backward(gy_scatter)
91+
92+
print(prof.key_averages().table(sort_by="xpu_time_total"))

0 commit comments

Comments
 (0)