Skip to content

Commit a1febce

Browse files
authored
Merge pull request #82 from b0nes164/main
Add DeviceRadixSort: faster GPU sorting code
2 parents b38e133 + c303564 commit a1febce

File tree

5 files changed

+900
-129
lines changed

5 files changed

+900
-129
lines changed

package/Runtime/GpuSorting.cs

Lines changed: 89 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@
44

55
namespace GaussianSplatting.Runtime
66
{
7-
// GPU (uint key, uint payload) radix sort, originally based on code derived from AMD FidelityFX SDK:
8-
// Copyright © 2023 Advanced Micro Devices, Inc., MIT license
9-
// https://github.com/GPUOpen-Effects/FidelityFX-ParallelSort v1.1.1
7+
// GPU (uint key, uint payload) 8 bit-LSD radix sort, using reduce-then-scan
8+
// Copyright Thomas Smith 2023, MIT license
9+
// https://github.com/b0nes164/GPUSorting
10+
1011
public class GpuSorting
1112
{
12-
// These need to match constants in the compute shader
13-
const uint FFX_PARALLELSORT_ELEMENTS_PER_THREAD = 4;
14-
const uint FFX_PARALLELSORT_THREADGROUP_SIZE = 128;
15-
const int FFX_PARALLELSORT_SORT_BITS_PER_PASS = 4;
16-
const uint FFX_PARALLELSORT_SORT_BIN_COUNT = 1u << FFX_PARALLELSORT_SORT_BITS_PER_PASS;
17-
// The maximum number of thread groups to run in parallel. Modifying this value can help or hurt GPU occupancy,
18-
// but is very hardware class specific
19-
const uint FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN = 800;
13+
//The size of a threadblock partition in the sort
14+
const uint DEVICE_RADIX_SORT_PARTITION_SIZE = 3840;
15+
16+
//The size of our radix in bits
17+
const uint DEVICE_RADIX_SORT_BITS = 8;
18+
19+
//Number of digits in our radix, 1 << DEVICE_RADIX_SORT_BITS
20+
const uint DEVICE_RADIX_SORT_RADIX = 256;
21+
22+
//Number of sorting passes required to sort a 32bit key, KEY_BITS / DEVICE_RADIX_SORT_BITS
23+
const uint DEVICE_RADIX_SORT_PASSES = 4;
2024

2125
public struct Args
2226
{
@@ -29,51 +33,48 @@ public struct Args
2933

3034
public struct SupportResources
3135
{
32-
public GraphicsBuffer sortScratchBuffer;
33-
public GraphicsBuffer payloadScratchBuffer;
34-
public GraphicsBuffer scratchBuffer;
35-
public GraphicsBuffer reducedScratchBuffer;
36+
public GraphicsBuffer altBuffer;
37+
public GraphicsBuffer altPayloadBuffer;
38+
public GraphicsBuffer passHistBuffer;
39+
public GraphicsBuffer globalHistBuffer;
3640

3741
public static SupportResources Load(uint count)
3842
{
39-
uint BlockSize = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE;
40-
uint NumBlocks = DivRoundUp(count, BlockSize);
41-
uint NumReducedBlocks = DivRoundUp(NumBlocks, BlockSize);
42-
43-
uint scratchBufferSize = FFX_PARALLELSORT_SORT_BIN_COUNT * NumBlocks;
44-
uint reduceScratchBufferSize = FFX_PARALLELSORT_SORT_BIN_COUNT * NumReducedBlocks;
43+
//This is threadBlocks * DEVICE_RADIX_SORT_RADIX
44+
uint scratchBufferSize = DivRoundUp(count, DEVICE_RADIX_SORT_PARTITION_SIZE) * DEVICE_RADIX_SORT_RADIX;
45+
uint reducedScratchBufferSize = DEVICE_RADIX_SORT_RADIX * DEVICE_RADIX_SORT_PASSES;
4546

4647
var target = GraphicsBuffer.Target.Structured;
4748
var resources = new SupportResources
4849
{
49-
sortScratchBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "FfxSortSortScratch" },
50-
payloadScratchBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "FfxSortPayloadScratch" },
51-
scratchBuffer = new GraphicsBuffer(target, (int)scratchBufferSize, 4) { name = "FfxSortScratch" },
52-
reducedScratchBuffer = new GraphicsBuffer(target, (int)reduceScratchBufferSize, 4) { name = "FfxSortReducedScratch" },
50+
altBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAlt" },
51+
altPayloadBuffer = new GraphicsBuffer(target, (int)count, 4) { name = "DeviceRadixAltPayload" },
52+
passHistBuffer = new GraphicsBuffer(target, (int)scratchBufferSize, 4) { name = "DeviceRadixPassHistogram" },
53+
globalHistBuffer = new GraphicsBuffer(target, (int)reducedScratchBufferSize, 4) { name = "DeviceRadixGlobalHistogram" },
5354
};
5455
return resources;
5556
}
5657

5758
public void Dispose()
5859
{
59-
sortScratchBuffer?.Dispose();
60-
payloadScratchBuffer?.Dispose();
61-
scratchBuffer?.Dispose();
62-
reducedScratchBuffer?.Dispose();
63-
64-
sortScratchBuffer = null;
65-
payloadScratchBuffer = null;
66-
scratchBuffer = null;
67-
reducedScratchBuffer = null;
60+
altBuffer?.Dispose();
61+
altPayloadBuffer?.Dispose();
62+
passHistBuffer?.Dispose();
63+
globalHistBuffer?.Dispose();
64+
65+
altBuffer = null;
66+
altPayloadBuffer = null;
67+
passHistBuffer = null;
68+
globalHistBuffer = null;
6869
}
6970
}
7071

7172
readonly ComputeShader m_CS;
72-
readonly int m_KernelReduce = -1;
73-
readonly int m_KernelScanAdd = -1;
74-
readonly int m_KernelScan = -1;
75-
readonly int m_KernelScatter = -1;
76-
readonly int m_KernelSum = -1;
73+
readonly int m_kernelInitDeviceRadixSort = -1;
74+
readonly int m_kernelUpsweep = -1;
75+
readonly int m_kernelScan = -1;
76+
readonly int m_kernelDownsweep = -1;
77+
7778
readonly bool m_Valid;
7879

7980
public bool Valid => m_Valid;
@@ -83,25 +84,22 @@ public GpuSorting(ComputeShader cs)
8384
m_CS = cs;
8485
if (cs)
8586
{
86-
m_KernelReduce = cs.FindKernel("FfxParallelSortReduce");
87-
m_KernelScanAdd = cs.FindKernel("FfxParallelSortScanAdd");
88-
m_KernelScan = cs.FindKernel("FfxParallelSortScan");
89-
m_KernelScatter = cs.FindKernel("FfxParallelSortScatter");
90-
m_KernelSum = cs.FindKernel("FfxParallelSortCount");
87+
m_kernelInitDeviceRadixSort = cs.FindKernel("InitDeviceRadixSort");
88+
m_kernelUpsweep = cs.FindKernel("Upsweep");
89+
m_kernelScan = cs.FindKernel("Scan");
90+
m_kernelDownsweep = cs.FindKernel("Downsweep");
9191
}
9292

93-
m_Valid = m_KernelReduce >= 0 &&
94-
m_KernelScanAdd >= 0 &&
95-
m_KernelScan >= 0 &&
96-
m_KernelScatter >= 0 &&
97-
m_KernelSum >= 0;
93+
m_Valid = m_kernelInitDeviceRadixSort >= 0 &&
94+
m_kernelUpsweep >= 0 &&
95+
m_kernelScan >= 0 &&
96+
m_kernelDownsweep >= 0;
9897
if (m_Valid)
9998
{
100-
if (!cs.IsSupported(m_KernelReduce) ||
101-
!cs.IsSupported(m_KernelScanAdd) ||
102-
!cs.IsSupported(m_KernelScan) ||
103-
!cs.IsSupported(m_KernelScatter) ||
104-
!cs.IsSupported(m_KernelSum))
99+
if (!cs.IsSupported(m_kernelInitDeviceRadixSort) ||
100+
!cs.IsSupported(m_kernelUpsweep) ||
101+
!cs.IsSupported(m_kernelScan) ||
102+
!cs.IsSupported(m_kernelDownsweep))
105103
{
106104
m_Valid = false;
107105
}
@@ -110,16 +108,13 @@ public GpuSorting(ComputeShader cs)
110108

111109
static uint DivRoundUp(uint x, uint y) => (x + y - 1) / y;
112110

111+
//Can we remove the last 4 padding without breaking?
113112
struct SortConstants
114113
{
115-
public uint numKeys; // The number of keys to sort
116-
public uint numBlocksPerThreadGroup; // How many blocks of keys each thread group needs to process
117-
public uint numThreadGroups; // How many thread groups are being run concurrently for sort
118-
public uint numThreadGroupsWithAdditionalBlocks; // How many thread groups need to process additional block data
119-
public uint numReduceThreadgroupPerBin; // How many thread groups are summed together for each reduced bin entry
120-
public uint numScanValues; // How many values to perform scan prefix (+ add) on
121-
public uint shift; // What bits are being sorted (4 bit increments)
122-
public uint padding; // Padding - unused
114+
public uint numKeys; // The number of keys to sort
115+
public uint radixShift; // The radix shift value for the current pass
116+
public uint threadBlocks; // threadBlocks
117+
public uint padding0; // Padding - unused
123118
}
124119

125120
public void Dispatch(CommandBuffer cmd, Args args)
@@ -128,78 +123,51 @@ public void Dispatch(CommandBuffer cmd, Args args)
128123

129124
GraphicsBuffer srcKeyBuffer = args.inputKeys;
130125
GraphicsBuffer srcPayloadBuffer = args.inputValues;
131-
GraphicsBuffer dstKeyBuffer = args.resources.sortScratchBuffer;
132-
GraphicsBuffer dstPayloadBuffer = args.resources.payloadScratchBuffer;
126+
GraphicsBuffer dstKeyBuffer = args.resources.altBuffer;
127+
GraphicsBuffer dstPayloadBuffer = args.resources.altPayloadBuffer;
133128

134-
// Initialize constants for the sort job
135129
SortConstants constants = default;
136130
constants.numKeys = args.count;
131+
constants.threadBlocks = DivRoundUp(args.count, DEVICE_RADIX_SORT_PARTITION_SIZE);
137132

138-
uint BlockSize = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE;
139-
uint NumBlocks = DivRoundUp(args.count, BlockSize);
133+
// Setup overall constants
134+
cmd.SetComputeIntParam(m_CS, "e_numKeys", (int)constants.numKeys);
135+
cmd.SetComputeIntParam(m_CS, "e_threadBlocks", (int)constants.threadBlocks);
140136

141-
// Figure out data distribution
142-
uint numThreadGroupsToRun = FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN;
143-
uint BlocksPerThreadGroup = (NumBlocks / numThreadGroupsToRun);
144-
constants.numThreadGroupsWithAdditionalBlocks = NumBlocks % numThreadGroupsToRun;
137+
//Set statically located buffers
138+
//Upsweep
139+
cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_passHist", args.resources.passHistBuffer);
140+
cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_globalHist", args.resources.globalHistBuffer);
145141

146-
if (NumBlocks < numThreadGroupsToRun)
147-
{
148-
BlocksPerThreadGroup = 1;
149-
numThreadGroupsToRun = NumBlocks;
150-
constants.numThreadGroupsWithAdditionalBlocks = 0;
151-
}
142+
//Scan
143+
cmd.SetComputeBufferParam(m_CS, m_kernelScan, "b_passHist", args.resources.passHistBuffer);
152144

153-
constants.numThreadGroups = numThreadGroupsToRun;
154-
constants.numBlocksPerThreadGroup = BlocksPerThreadGroup;
145+
//Downsweep
146+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_passHist", args.resources.passHistBuffer);
147+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_globalHist", args.resources.globalHistBuffer);
155148

156-
// Calculate the number of thread groups to run for reduction (each thread group can process BlockSize number of entries)
157-
uint numReducedThreadGroupsToRun = FFX_PARALLELSORT_SORT_BIN_COUNT * ((BlockSize > numThreadGroupsToRun) ? 1 : (numThreadGroupsToRun + BlockSize - 1) / BlockSize);
158-
constants.numReduceThreadgroupPerBin = numReducedThreadGroupsToRun / FFX_PARALLELSORT_SORT_BIN_COUNT;
159-
constants.numScanValues = numReducedThreadGroupsToRun; // The number of reduce thread groups becomes our scan count (as each thread group writes out 1 value that needs scan prefix)
149+
//Clear the global histogram
150+
cmd.SetComputeBufferParam(m_CS, m_kernelInitDeviceRadixSort, "b_globalHist", args.resources.globalHistBuffer);
151+
cmd.DispatchCompute(m_CS, m_kernelInitDeviceRadixSort, 1, 1, 1);
160152

161-
// Setup overall constants
162-
cmd.SetComputeIntParam(m_CS, "numKeys", (int)constants.numKeys);
163-
cmd.SetComputeIntParam(m_CS, "numBlocksPerThreadGroup", (int)constants.numBlocksPerThreadGroup);
164-
cmd.SetComputeIntParam(m_CS, "numThreadGroups", (int)constants.numThreadGroups);
165-
cmd.SetComputeIntParam(m_CS, "numThreadGroupsWithAdditionalBlocks", (int)constants.numThreadGroupsWithAdditionalBlocks);
166-
cmd.SetComputeIntParam(m_CS, "numReduceThreadgroupPerBin", (int)constants.numReduceThreadgroupPerBin);
167-
cmd.SetComputeIntParam(m_CS, "numScanValues", (int)constants.numScanValues);
168-
169-
// Execute the sort algorithm in 4-bit increments
170-
constants.shift = 0;
171-
for (uint i = 0; constants.shift < 32; constants.shift += FFX_PARALLELSORT_SORT_BITS_PER_PASS, ++i)
153+
// Execute the sort algorithm in 8-bit increments
154+
for (constants.radixShift = 0; constants.radixShift < 32; constants.radixShift += DEVICE_RADIX_SORT_BITS)
172155
{
173-
cmd.SetComputeIntParam(m_CS, "shift", (int)constants.shift);
174-
175-
// Sum
176-
cmd.SetComputeBufferParam(m_CS, m_KernelSum, "rw_source_keys", srcKeyBuffer);
177-
cmd.SetComputeBufferParam(m_CS, m_KernelSum, "rw_sum_table", args.resources.scratchBuffer);
178-
cmd.DispatchCompute(m_CS, m_KernelSum, (int)numThreadGroupsToRun, 1, 1);
156+
cmd.SetComputeIntParam(m_CS, "e_radixShift", (int)constants.radixShift);
179157

180-
// Reduce
181-
cmd.SetComputeBufferParam(m_CS, m_KernelReduce, "rw_sum_table", args.resources.scratchBuffer);
182-
cmd.SetComputeBufferParam(m_CS, m_KernelReduce, "rw_reduce_table", args.resources.reducedScratchBuffer);
183-
cmd.DispatchCompute(m_CS, m_KernelReduce, (int)numReducedThreadGroupsToRun, 1, 1);
158+
//Upsweep
159+
cmd.SetComputeBufferParam(m_CS, m_kernelUpsweep, "b_sort", srcKeyBuffer);
160+
cmd.DispatchCompute(m_CS, m_kernelUpsweep, (int)constants.threadBlocks, 1, 1);
184161

185162
// Scan
186-
cmd.SetComputeBufferParam(m_CS, m_KernelScan, "rw_scan_source", args.resources.reducedScratchBuffer);
187-
cmd.SetComputeBufferParam(m_CS, m_KernelScan, "rw_scan_dest", args.resources.reducedScratchBuffer);
188-
cmd.DispatchCompute(m_CS, m_KernelScan, 1, 1, 1);
189-
190-
// Scan add
191-
cmd.SetComputeBufferParam(m_CS, m_KernelScanAdd, "rw_scan_source", args.resources.scratchBuffer);
192-
cmd.SetComputeBufferParam(m_CS, m_KernelScanAdd, "rw_scan_dest", args.resources.scratchBuffer);
193-
cmd.SetComputeBufferParam(m_CS, m_KernelScanAdd, "rw_scan_scratch", args.resources.reducedScratchBuffer);
194-
cmd.DispatchCompute(m_CS, m_KernelScanAdd, (int)numReducedThreadGroupsToRun, 1, 1);
195-
196-
// Scatter
197-
cmd.SetComputeBufferParam(m_CS, m_KernelScatter, "rw_source_keys", srcKeyBuffer);
198-
cmd.SetComputeBufferParam(m_CS, m_KernelScatter, "rw_dest_keys", dstKeyBuffer);
199-
cmd.SetComputeBufferParam(m_CS, m_KernelScatter, "rw_sum_table", args.resources.scratchBuffer);
200-
cmd.SetComputeBufferParam(m_CS, m_KernelScatter, "rw_source_payloads", srcPayloadBuffer);
201-
cmd.SetComputeBufferParam(m_CS, m_KernelScatter, "rw_dest_payloads", dstPayloadBuffer);
202-
cmd.DispatchCompute(m_CS, m_KernelScatter, (int)numThreadGroupsToRun, 1, 1);
163+
cmd.DispatchCompute(m_CS, m_kernelScan, (int)DEVICE_RADIX_SORT_RADIX, 1, 1);
164+
165+
// Downsweep
166+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_sort", srcKeyBuffer);
167+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_sortPayload", srcPayloadBuffer);
168+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_alt", dstKeyBuffer);
169+
cmd.SetComputeBufferParam(m_CS, m_kernelDownsweep, "b_altPayload", dstPayloadBuffer);
170+
cmd.DispatchCompute(m_CS, m_kernelDownsweep, (int)constants.threadBlocks, 1, 1);
203171

204172
// Swap
205173
(srcKeyBuffer, dstKeyBuffer) = (dstKeyBuffer, srcKeyBuffer);

0 commit comments

Comments
 (0)