+ import torch
+ from torch.profiler import profile, ProfilerActivity
+
+ device = "xpu"
+ backward = True
+
+ # Define shapes for scatter/gather testing
+ # (input_shape, index_shape, dim_to_scatter_gather)
+ shape_list = [
+     ((4096, 8192), (4096, 8192), 1),  # Simple 2D case
+     ((2, 4096, 320), (2, 4096, 1), 2),  # Scatter/Gather along the last dim
+     ((512, 3136, 128), (512, 1, 128), 1),  # Scatter/Gather along the middle dim
+     ((128, 49, 196, 1024), (128, 49, 196, 1), 3),  # 4D case, scatter/gather along the last dim
+ ]
+
+ for shape_config in shape_list:
+     input_shape, index_shape, dim_to_operate = shape_config
+
+     for dtype in [torch.bfloat16, torch.float16, torch.float32]:
+         # Generate input tensor
+         input_tensor = torch.randn(input_shape, device=device, dtype=dtype)
+
+         # Generate index tensor for gather/scatter
+         # Ensure indices are within valid bounds for the dimension
+         max_idx_val = input_tensor.shape[dim_to_operate]
+         index_tensor = torch.randint(0, max_idx_val, index_shape, device=device, dtype=torch.int64)
+
+         # Generate source tensor for scatter
+         # Its shape should match index_tensor in the dimension being scattered into,
+         # and input_tensor in other dimensions.
+         scatter_source_shape = list(input_tensor.shape)
+         for i, dim_size in enumerate(index_shape):
+             if i == dim_to_operate:
+                 scatter_source_shape[i] = dim_size
+         scatter_source = torch.randn(scatter_source_shape, device=device, dtype=dtype)
+
+         if backward:
+             input_tensor.requires_grad_(True)
+             scatter_source.requires_grad_(True)
+
+         # Warm-up phase
+         # Gather operation
+         gathered_output_warmup = torch.gather(input_tensor, dim_to_operate, index_tensor)
+         if backward:
+             gy_gather = torch.empty_like(gathered_output_warmup)
+             gathered_output_warmup.backward(gy_gather)
+
+         # Scatter operation (clone the input so the in-place scatter_ runs on a fresh tensor for profiling)
+         scattered_output_warmup = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
+         if backward:
+             gy_scatter = torch.empty_like(scattered_output_warmup)
+             scattered_output_warmup.backward(gy_scatter)
+
+         print("---")
+         print(
+             "Testing Scatter/Gather -- input shape:",
+             input_shape,
+             "; index shape:",
+             index_shape,
+             "; datatype:",
+             dtype,
+             "; dim:",
+             dim_to_operate,
+             "; backward:",
+             backward,
+         )
+         print("---")
+
+         # Profiling phase
+         with profile(
+             activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], record_shapes=True
+         ) as prof:
+             for i in range(20):
+                 # Gather operation
+                 gathered_output = torch.gather(input_tensor, dim_to_operate, index_tensor)
+                 if backward:
+                     gy_gather = torch.empty_like(gathered_output)
+                     gathered_output.backward(gy_gather)
+
+                 # Scatter operation
+                 # We clone input_tensor each time to avoid modifying the same tensor
+                 # across iterations, which could affect profiling if in-place ops are used.
+                 scattered_output = input_tensor.clone().scatter_(dim_to_operate, index_tensor, scatter_source)
+                 if backward:
+                     gy_scatter = torch.empty_like(scattered_output)
+                     scattered_output.backward(gy_scatter)
+
+         print(prof.key_averages().table(sort_by="xpu_time_total"))