4
4
5
5
namespace GaussianSplatting . Runtime
6
6
{
7
- // GPU (uint key, uint payload) radix sort, originally based on code derived from AMD FidelityFX SDK:
8
- // Copyright © 2023 Advanced Micro Devices, Inc., MIT license
9
- // https://github.com/GPUOpen-Effects/FidelityFX-ParallelSort v1.1.1
7
+ // GPU (uint key, uint payload) 8 bit-LSD radix sort, using reduce-then-scan
8
+ // Copyright Thomas Smith 2023, MIT license
9
+ // https://github.com/b0nes164/GPUSorting
10
+
10
11
public class GpuSorting
11
12
{
12
- // These need to match constants in the compute shader
13
- const uint FFX_PARALLELSORT_ELEMENTS_PER_THREAD = 4 ;
14
- const uint FFX_PARALLELSORT_THREADGROUP_SIZE = 128 ;
15
- const int FFX_PARALLELSORT_SORT_BITS_PER_PASS = 4 ;
16
- const uint FFX_PARALLELSORT_SORT_BIN_COUNT = 1u << FFX_PARALLELSORT_SORT_BITS_PER_PASS ;
17
- // The maximum number of thread groups to run in parallel. Modifying this value can help or hurt GPU occupancy,
18
- // but is very hardware class specific
19
- const uint FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN = 800 ;
13
// Tuning constants for the 8-bit LSD radix sort.
// NOTE(review): presumably these must stay in sync with the matching constants in the
// compute shader - confirm against the shader source before changing any of them.

// Number of keys in one thread-block partition of the sort.
const uint DEVICE_RADIX_SORT_PARTITION_SIZE = 3840;

// Width of a single radix digit, in bits.
const uint DEVICE_RADIX_SORT_BITS = 8;

// Number of distinct digit values in our radix (256), derived from the digit width
// so the two values cannot drift apart.
const uint DEVICE_RADIX_SORT_RADIX = 1u << (int)DEVICE_RADIX_SORT_BITS;

// Number of sorting passes required to sort a 32-bit key (4): KEY_BITS / DEVICE_RADIX_SORT_BITS.
const uint DEVICE_RADIX_SORT_PASSES = 32u / DEVICE_RADIX_SORT_BITS;
20
24
21
25
public struct Args
22
26
{
@@ -29,51 +33,48 @@ public struct Args
29
33
30
34
// GPU scratch buffers needed by the radix sort, in addition to the caller's own key and
// payload buffers. Create with Load() and release with Dispose().
public struct SupportResources
{
    // Alternate key buffer; Dispatch uses it as the destination of a pass and swaps it
    // with the source key buffer afterwards.
    public GraphicsBuffer altBuffer;
    // Alternate payload buffer; Dispatch uses it as the payload destination of a pass.
    public GraphicsBuffer altPayloadBuffer;
    // Bound as "b_passHist" to the Upsweep, Scan and Downsweep kernels.
    public GraphicsBuffer passHistBuffer;
    // Bound as "b_globalHist" to the InitDeviceRadixSort, Upsweep and Downsweep kernels.
    public GraphicsBuffer globalHistBuffer;

    // Allocates all scratch buffers for sorting `count` (key, payload) pairs.
    // Every buffer is a structured buffer with a 4-byte element stride.
    // The caller owns the returned resources and must call Dispose() on them.
    public static SupportResources Load(uint count)
    {
        // threadBlocks * DEVICE_RADIX_SORT_RADIX: one radix-sized histogram per thread block.
        uint passHistCount = DivRoundUp(count, DEVICE_RADIX_SORT_PARTITION_SIZE) * DEVICE_RADIX_SORT_RADIX;
        // One radix-sized histogram per sorting pass.
        uint globalHistCount = DEVICE_RADIX_SORT_RADIX * DEVICE_RADIX_SORT_PASSES;

        var target = GraphicsBuffer.Target.Structured;
        var resources = new SupportResources
        {
            altBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAlt " },
            altPayloadBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAltPayload " },
            passHistBuffer = new GraphicsBuffer(target, (int)passHistCount, 4) { name = "DeviceRadixPassHistogram " },
            globalHistBuffer = new GraphicsBuffer(target, (int)globalHistCount, 4) { name = "DeviceRadixGlobalHistogram " },
        };
        return resources;
    }

    // Releases every buffer that was allocated and clears the references, so calling
    // Dispose() again on the same value is a no-op.
    public void Dispose()
    {
        altBuffer?.Dispose();
        altPayloadBuffer?.Dispose();
        passHistBuffer?.Dispose();
        globalHistBuffer?.Dispose();

        altBuffer = null;
        altPayloadBuffer = null;
        passHistBuffer = null;
        globalHistBuffer = null;
    }
}
70
71
71
72
// Compute shader that contains the radix sort kernels.
readonly ComputeShader m_CS;

// Kernel indices looked up by name in the constructor; they stay at -1 when no shader was given.
readonly int m_kernelInitDeviceRadixSort = -1;
readonly int m_kernelUpsweep = -1;
readonly int m_kernelScan = -1;
readonly int m_kernelDownsweep = -1;

// Set by the constructor: true only when every kernel index is non-negative and
// every kernel is reported as supported by the shader.
readonly bool m_Valid;

// Whether this sorter can be used (see m_Valid).
public bool Valid => m_Valid;
@@ -83,25 +84,22 @@ public GpuSorting(ComputeShader cs)
83
84
m_CS = cs ;
84
85
if ( cs )
85
86
{
86
- m_KernelReduce = cs . FindKernel ( "FfxParallelSortReduce" ) ;
87
- m_KernelScanAdd = cs . FindKernel ( "FfxParallelSortScanAdd" ) ;
88
- m_KernelScan = cs . FindKernel ( "FfxParallelSortScan" ) ;
89
- m_KernelScatter = cs . FindKernel ( "FfxParallelSortScatter" ) ;
90
- m_KernelSum = cs . FindKernel ( "FfxParallelSortCount" ) ;
87
+ m_kernelInitDeviceRadixSort = cs . FindKernel ( "InitDeviceRadixSort" ) ;
88
+ m_kernelUpsweep = cs . FindKernel ( "Upsweep" ) ;
89
+ m_kernelScan = cs . FindKernel ( "Scan" ) ;
90
+ m_kernelDownsweep = cs . FindKernel ( "Downsweep" ) ;
91
91
}
92
92
93
- m_Valid = m_KernelReduce >= 0 &&
94
- m_KernelScanAdd >= 0 &&
95
- m_KernelScan >= 0 &&
96
- m_KernelScatter >= 0 &&
97
- m_KernelSum >= 0 ;
93
+ m_Valid = m_kernelInitDeviceRadixSort >= 0 &&
94
+ m_kernelUpsweep >= 0 &&
95
+ m_kernelScan >= 0 &&
96
+ m_kernelDownsweep >= 0 ;
98
97
if ( m_Valid )
99
98
{
100
- if ( ! cs . IsSupported ( m_KernelReduce ) ||
101
- ! cs . IsSupported ( m_KernelScanAdd ) ||
102
- ! cs . IsSupported ( m_KernelScan ) ||
103
- ! cs . IsSupported ( m_KernelScatter ) ||
104
- ! cs . IsSupported ( m_KernelSum ) )
99
+ if ( ! cs . IsSupported ( m_kernelInitDeviceRadixSort ) ||
100
+ ! cs . IsSupported ( m_kernelUpsweep ) ||
101
+ ! cs . IsSupported ( m_kernelScan ) ||
102
+ ! cs . IsSupported ( m_kernelDownsweep ) )
105
103
{
106
104
m_Valid = false ;
107
105
}
@@ -110,16 +108,13 @@ public GpuSorting(ComputeShader cs)
110
108
111
109
// Divides x by y, rounding the result up to the next whole number.
// Written as quotient + remainder check instead of (x + y - 1) / y, because that sum
// wraps around in unchecked uint arithmetic when x is close to uint.MaxValue.
// A zero y still throws DivideByZeroException, as before.
static uint DivRoundUp(uint x, uint y) => x / y + (x % y != 0 ? 1u : 0u);
112
110
111
// Parameters of one sort dispatch. Dispatch fills these in and forwards them to the
// compute shader as the integer parameters "e_numKeys", "e_threadBlocks" and "e_radixShift".
// TODO(review): the original author asked whether the trailing 4-byte padding field
// can be removed without breaking anything - still unresolved.
struct SortConstants
{
    public uint numKeys;      // The number of keys to sort
    public uint radixShift;   // The radix shift value (bit offset of the digit) for the current pass
    public uint threadBlocks; // Number of thread blocks: numKeys / DEVICE_RADIX_SORT_PARTITION_SIZE, rounded up
    public uint padding0;     // Padding - unused
}
124
119
125
120
public void Dispatch ( CommandBuffer cmd , Args args )
@@ -128,78 +123,51 @@ public void Dispatch(CommandBuffer cmd, Args args)
128
123
129
124
GraphicsBuffer srcKeyBuffer = args . inputKeys ;
130
125
GraphicsBuffer srcPayloadBuffer = args . inputValues ;
131
- GraphicsBuffer dstKeyBuffer = args . resources . sortScratchBuffer ;
132
- GraphicsBuffer dstPayloadBuffer = args . resources . payloadScratchBuffer ;
126
+ GraphicsBuffer dstKeyBuffer = args . resources . altBuffer ;
127
+ GraphicsBuffer dstPayloadBuffer = args . resources . altPayloadBuffer ;
133
128
134
- // Initialize constants for the sort job
135
129
SortConstants constants = default ;
136
130
constants . numKeys = args . count ;
131
+ constants . threadBlocks = DivRoundUp ( args . count , DEVICE_RADIX_SORT_PARTITION_SIZE ) ;
137
132
138
- uint BlockSize = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE ;
139
- uint NumBlocks = DivRoundUp ( args . count , BlockSize ) ;
133
+ // Setup overall constants
134
+ cmd . SetComputeIntParam ( m_CS , "e_numKeys" , ( int ) constants . numKeys ) ;
135
+ cmd . SetComputeIntParam ( m_CS , "e_threadBlocks" , ( int ) constants . threadBlocks ) ;
140
136
141
- // Figure out data distribution
142
- uint numThreadGroupsToRun = FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN ;
143
- uint BlocksPerThreadGroup = ( NumBlocks / numThreadGroupsToRun ) ;
144
- constants . numThreadGroupsWithAdditionalBlocks = NumBlocks % numThreadGroupsToRun ;
137
+ //Set statically located buffers
138
+ //Upsweep
139
+ cmd . SetComputeBufferParam ( m_CS , m_kernelUpsweep , "b_passHist" , args . resources . passHistBuffer ) ;
140
+ cmd . SetComputeBufferParam ( m_CS , m_kernelUpsweep , "b_globalHist" , args . resources . globalHistBuffer ) ;
145
141
146
- if ( NumBlocks < numThreadGroupsToRun )
147
- {
148
- BlocksPerThreadGroup = 1 ;
149
- numThreadGroupsToRun = NumBlocks ;
150
- constants . numThreadGroupsWithAdditionalBlocks = 0 ;
151
- }
142
+ //Scan
143
+ cmd . SetComputeBufferParam ( m_CS , m_kernelScan , "b_passHist" , args . resources . passHistBuffer ) ;
152
144
153
- constants . numThreadGroups = numThreadGroupsToRun ;
154
- constants . numBlocksPerThreadGroup = BlocksPerThreadGroup ;
145
+ //Downsweep
146
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_passHist" , args . resources . passHistBuffer ) ;
147
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_globalHist" , args . resources . globalHistBuffer ) ;
155
148
156
- // Calculate the number of thread groups to run for reduction (each thread group can process BlockSize number of entries)
157
- uint numReducedThreadGroupsToRun = FFX_PARALLELSORT_SORT_BIN_COUNT * ( ( BlockSize > numThreadGroupsToRun ) ? 1 : ( numThreadGroupsToRun + BlockSize - 1 ) / BlockSize ) ;
158
- constants . numReduceThreadgroupPerBin = numReducedThreadGroupsToRun / FFX_PARALLELSORT_SORT_BIN_COUNT ;
159
- constants . numScanValues = numReducedThreadGroupsToRun ; // The number of reduce thread groups becomes our scan count (as each thread group writes out 1 value that needs scan prefix)
149
+ //Clear the global histogram
150
+ cmd . SetComputeBufferParam ( m_CS , m_kernelInitDeviceRadixSort , "b_globalHist" , args . resources . globalHistBuffer ) ;
151
+ cmd . DispatchCompute ( m_CS , m_kernelInitDeviceRadixSort , 1 , 1 , 1 ) ;
160
152
161
- // Setup overall constants
162
- cmd . SetComputeIntParam ( m_CS , "numKeys" , ( int ) constants . numKeys ) ;
163
- cmd . SetComputeIntParam ( m_CS , "numBlocksPerThreadGroup" , ( int ) constants . numBlocksPerThreadGroup ) ;
164
- cmd . SetComputeIntParam ( m_CS , "numThreadGroups" , ( int ) constants . numThreadGroups ) ;
165
- cmd . SetComputeIntParam ( m_CS , "numThreadGroupsWithAdditionalBlocks" , ( int ) constants . numThreadGroupsWithAdditionalBlocks ) ;
166
- cmd . SetComputeIntParam ( m_CS , "numReduceThreadgroupPerBin" , ( int ) constants . numReduceThreadgroupPerBin ) ;
167
- cmd . SetComputeIntParam ( m_CS , "numScanValues" , ( int ) constants . numScanValues ) ;
168
-
169
- // Execute the sort algorithm in 4-bit increments
170
- constants . shift = 0 ;
171
- for ( uint i = 0 ; constants . shift < 32 ; constants . shift += FFX_PARALLELSORT_SORT_BITS_PER_PASS , ++ i )
153
+ // Execute the sort algorithm in 8-bit increments
154
+ for ( constants . radixShift = 0 ; constants . radixShift < 32 ; constants . radixShift += DEVICE_RADIX_SORT_BITS )
172
155
{
173
- cmd . SetComputeIntParam ( m_CS , "shift" , ( int ) constants . shift ) ;
174
-
175
- // Sum
176
- cmd . SetComputeBufferParam ( m_CS , m_KernelSum , "rw_source_keys" , srcKeyBuffer ) ;
177
- cmd . SetComputeBufferParam ( m_CS , m_KernelSum , "rw_sum_table" , args . resources . scratchBuffer ) ;
178
- cmd . DispatchCompute ( m_CS , m_KernelSum , ( int ) numThreadGroupsToRun , 1 , 1 ) ;
156
+ cmd . SetComputeIntParam ( m_CS , "e_radixShift" , ( int ) constants . radixShift ) ;
179
157
180
- // Reduce
181
- cmd . SetComputeBufferParam ( m_CS , m_KernelReduce , "rw_sum_table" , args . resources . scratchBuffer ) ;
182
- cmd . SetComputeBufferParam ( m_CS , m_KernelReduce , "rw_reduce_table" , args . resources . reducedScratchBuffer ) ;
183
- cmd . DispatchCompute ( m_CS , m_KernelReduce , ( int ) numReducedThreadGroupsToRun , 1 , 1 ) ;
158
+ //Upsweep
159
+ cmd . SetComputeBufferParam ( m_CS , m_kernelUpsweep , "b_sort" , srcKeyBuffer ) ;
160
+ cmd . DispatchCompute ( m_CS , m_kernelUpsweep , ( int ) constants . threadBlocks , 1 , 1 ) ;
184
161
185
162
// Scan
186
- cmd . SetComputeBufferParam ( m_CS , m_KernelScan , "rw_scan_source" , args . resources . reducedScratchBuffer ) ;
187
- cmd . SetComputeBufferParam ( m_CS , m_KernelScan , "rw_scan_dest" , args . resources . reducedScratchBuffer ) ;
188
- cmd . DispatchCompute ( m_CS , m_KernelScan , 1 , 1 , 1 ) ;
189
-
190
- // Scan add
191
- cmd . SetComputeBufferParam ( m_CS , m_KernelScanAdd , "rw_scan_source" , args . resources . scratchBuffer ) ;
192
- cmd . SetComputeBufferParam ( m_CS , m_KernelScanAdd , "rw_scan_dest" , args . resources . scratchBuffer ) ;
193
- cmd . SetComputeBufferParam ( m_CS , m_KernelScanAdd , "rw_scan_scratch" , args . resources . reducedScratchBuffer ) ;
194
- cmd . DispatchCompute ( m_CS , m_KernelScanAdd , ( int ) numReducedThreadGroupsToRun , 1 , 1 ) ;
195
-
196
- // Scatter
197
- cmd . SetComputeBufferParam ( m_CS , m_KernelScatter , "rw_source_keys" , srcKeyBuffer ) ;
198
- cmd . SetComputeBufferParam ( m_CS , m_KernelScatter , "rw_dest_keys" , dstKeyBuffer ) ;
199
- cmd . SetComputeBufferParam ( m_CS , m_KernelScatter , "rw_sum_table" , args . resources . scratchBuffer ) ;
200
- cmd . SetComputeBufferParam ( m_CS , m_KernelScatter , "rw_source_payloads" , srcPayloadBuffer ) ;
201
- cmd . SetComputeBufferParam ( m_CS , m_KernelScatter , "rw_dest_payloads" , dstPayloadBuffer ) ;
202
- cmd . DispatchCompute ( m_CS , m_KernelScatter , ( int ) numThreadGroupsToRun , 1 , 1 ) ;
163
+ cmd . DispatchCompute ( m_CS , m_kernelScan , ( int ) DEVICE_RADIX_SORT_RADIX , 1 , 1 ) ;
164
+
165
+ // Downsweep
166
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_sort" , srcKeyBuffer ) ;
167
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_sortPayload" , srcPayloadBuffer ) ;
168
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_alt" , dstKeyBuffer ) ;
169
+ cmd . SetComputeBufferParam ( m_CS , m_kernelDownsweep , "b_altPayload" , dstPayloadBuffer ) ;
170
+ cmd . DispatchCompute ( m_CS , m_kernelDownsweep , ( int ) constants . threadBlocks , 1 , 1 ) ;
203
171
204
172
// Swap
205
173
( srcKeyBuffer , dstKeyBuffer ) = ( dstKeyBuffer , srcKeyBuffer ) ;
0 commit comments