Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit dbf5d29

Browse files
hanzhi713alexm-redhat
authored andcommitted
Implement custom all reduce kernels (vllm-project#2192)
1 parent eb28215 commit dbf5d29

18 files changed

+1456
-66
lines changed

csrc/custom_all_reduce.cu

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
#include <ATen/cuda/Exceptions.h>
2+
#include <c10/cuda/CUDAGuard.h>
3+
#include <c10/cuda/CUDAStream.h>
4+
#include <torch/extension.h>
5+
6+
#include "custom_all_reduce.cuh"
7+
8+
// fake pointer type
9+
using fptr_t = uint64_t;
10+
static_assert(sizeof(void *) == sizeof(fptr_t));
11+
12+
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
13+
const std::vector<std::string> &handles,
14+
const std::vector<int64_t> &offsets, int rank,
15+
bool full_nvlink) {
16+
int world_size = offsets.size();
17+
if (world_size > 8)
18+
throw std::invalid_argument("world size > 8 is not supported");
19+
if (world_size % 2 != 0)
20+
throw std::invalid_argument("Odd num gpus is not supported for now");
21+
if (world_size != handles.size())
22+
throw std::invalid_argument(
23+
"handles length should equal to offsets length");
24+
if (rank < 0 || rank >= world_size)
25+
throw std::invalid_argument("invalid rank passed in");
26+
27+
cudaIpcMemHandle_t ipc_handles[8];
28+
for (int i = 0; i < world_size; i++) {
29+
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
30+
}
31+
return (fptr_t) new vllm::CustomAllreduce(
32+
reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
33+
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
34+
}
35+
36+
/**
37+
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
38+
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
39+
* because it allows transpose of contiguous slice (i.e. slicing the first
40+
* dimension). Currently, we require this because stride information is not
41+
* passed into the kernels and we treat input tensors as flat.
42+
*
43+
* Examples
44+
* A = torch.zeros(3, 3, 3)
45+
* 1. A: OK
46+
* 2. A[1:]: OK
47+
* 3. A.permute(2, 0, 1): OK
48+
* 4. A[1:].permute(2, 0, 1): OK
49+
* 5. A[None].expand(2, -1, -1, -1): Not OK
50+
* 6. A[:, 1:, 1:]: Not OK
51+
*/
52+
bool _is_weak_contiguous(torch::Tensor &t) {
53+
return t.is_contiguous() ||
54+
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
55+
t.numel() * t.element_size());
56+
}
57+
58+
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
59+
bool full_nvlink) {
60+
auto inp_size = inp.numel() * inp.element_size();
61+
// custom allreduce requires input byte size to be multiples of 16
62+
if (inp_size % 16 != 0) return false;
63+
if (!_is_weak_contiguous(inp)) return false;
64+
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
65+
// 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
66+
// <= 512k
67+
return world_size <= 4 && inp_size <= 512 * 1024;
68+
}
69+
70+
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
71+
cudaStream_t stream) {
72+
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
73+
TORCH_CHECK(_is_weak_contiguous(out));
74+
switch (out.scalar_type()) {
75+
case at::ScalarType::Float: {
76+
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
77+
reinterpret_cast<float *>(out.data_ptr()),
78+
out.numel());
79+
break;
80+
}
81+
case at::ScalarType::Half: {
82+
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
83+
reinterpret_cast<half *>(out.data_ptr()),
84+
out.numel());
85+
break;
86+
}
87+
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
88+
case at::ScalarType::BFloat16: {
89+
fa->allreduce<nv_bfloat16>(
90+
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
91+
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
92+
break;
93+
}
94+
#endif
95+
default:
96+
throw std::runtime_error(
97+
"custom allreduce only supports float32, float16 and bfloat16");
98+
}
99+
}
100+
101+
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
102+
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
103+
auto stream = c10::cuda::getCurrentCUDAStream().stream();
104+
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
105+
TORCH_CHECK_EQ(inp.numel(), out.numel());
106+
_all_reduce(_fa, inp, out, stream);
107+
}
108+
109+
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
110+
torch::Tensor &out) {
111+
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
112+
auto stream = c10::cuda::getCurrentCUDAStream().stream();
113+
114+
auto input_size = inp.numel() * inp.element_size();
115+
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
116+
TORCH_CHECK_EQ(inp.numel(), out.numel());
117+
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
118+
"registered buffer is too small to contain the input");
119+
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
120+
input_size, cudaMemcpyDeviceToDevice, stream));
121+
_all_reduce(_fa, reg_buffer, out, stream);
122+
}
123+
124+
void dispose(fptr_t _fa) {
125+
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
126+
delete fa;
127+
}
128+
129+
int meta_size() { return sizeof(vllm::Metadata); }
130+
131+
void register_buffer(fptr_t _fa, torch::Tensor &t,
132+
const std::vector<std::string> &handles,
133+
const std::vector<int64_t> &offsets) {
134+
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
135+
fa->register_buffer(handles, offsets, t.data_ptr());
136+
}
137+
138+
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
139+
fptr_t _fa) {
140+
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
141+
return fa->get_graph_buffer_ipc_meta();
142+
}
143+
144+
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
145+
const std::vector<std::vector<int64_t>> &offsets) {
146+
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
147+
fa->register_graph_buffers(handles, offsets);
148+
}

0 commit comments

Comments
 (0)