Description
Consider the following case in a DeepSpeed kernel: a global function template with a parameter pack.
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
    int chunk_size, volatile int* noop_flag, T tl, U callable, ArgTypes... args) {
  callable(chunk_size, noop_flag, tl, args...);
}
...
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
The current translation would look like:
queue.parallel_for(...,
    [=] (sycl::nd_item<3> item) {
      multi_tensor_apply_kernel(chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
    });
If we transform with minimal lexical change, the noop_flag.DATA_PTR<int>() function call crosses scope from host to device, which it does not do in CUDA. noop_flag happens to be a torch::tensor object, which is not device copyable. Even if we introduce a temporary variable and move the call out of the lambda, the object tl (a tensor list) makes the capture size exceed the 2048-byte limit. This simple case hits every limitation we have.
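For illustration, a minimal sketch of that attempted fix (the variable name noop_ptr and the elided nd_range are placeholders, not taken from the DeepSpeed source): hoisting the host-only call out of the lambda removes the host-to-device scope crossing, but the by-value capture of tl still pushes the capture size past the limit.
// Hypothetical minimal fix: hoist the host-only call out of the lambda.
int* noop_ptr = noop_flag.DATA_PTR<int>();  // executed on the host
queue.parallel_for(...,
    [=] (sycl::nd_item<3> item) {
      // tl is still captured by value; its size alone can exceed the
      // 2048-byte kernel-argument limit, so this variant still fails.
      multi_tensor_apply_kernel(chunk_size, noop_ptr, tl, callable, args...);
    });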
We suggest translating a CUDA global function (template) into an explicit functor (template) instead of relying on lambda capture. The kernel launching code would then look like:
if constexpr (sizeof multi_tensor_apply_kernel(
                  chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...) < 2048) {
  queue.parallel_for(...,
      multi_tensor_apply_kernel(
          chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...));
} else {
  auto capture = multi_tensor_apply_kernel(
      chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
  sycl::buffer params(const_cast<const decltype(capture) *>(&capture), sycl::range<1>(1));
  queue.submit([&] (sycl::handler &cgh) {
    auto device_params = params.template get_access<
        sycl::access_mode::read, sycl::target::constant_buffer>(cgh);
    cgh.parallel_for(...,
        [=] (sycl::nd_item<3> item) {
          device_params[0](item);
        });
  });
}
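For reference, a minimal sketch (our assumption, not DeepSpeed code) of the functor template the global function could be migrated into: the kernel parameters become data members, the parameter pack is held in a std::tuple, and a const operator()(sycl::nd_item<3>) replaces the kernel body. The linked manual example below shows the actual shape we have in mind.
// Sketch only; assumes #include <sycl/sycl.hpp> and #include <tuple>.
template <typename T, typename U, typename... ArgTypes>
class multi_tensor_apply_kernel {
public:
  // Constructor mirrors the original __global__ parameter list, so the call
  // sites above stay unchanged apart from dropping <<<...>>>.
  multi_tensor_apply_kernel(int chunk_size, volatile int* noop_flag, T tl,
                            U callable, ArgTypes... args)
      : chunk_size(chunk_size), noop_flag(noop_flag), tl(tl),
        callable(callable), args(args...) {}

  // SYCL kernel entry point, replacing the __global__ function body.
  // (A real migration may also need to forward `item` into `callable`.)
  void operator()(sycl::nd_item<3> item) const {
    auto c = callable;  // local copy so a non-const call operator also works
    std::apply([&](auto... a) { c(chunk_size, noop_flag, tl, a...); }, args);
  }

private:
  int chunk_size;
  volatile int* noop_flag;
  T tl;
  U callable;
  std::tuple<ArgTypes...> args;
};
Because the functor carries its own state, it can be passed directly to parallel_for when it is small enough, or placed in a buffer and dispatched through the constant_buffer accessor path shown above when it exceeds the argument-size limit.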
All substitutions in the translation are localized. A manual example can be found at:
https://github.com/CaoZhongZ/sycl_compiler_test/blob/global_call_migrate/deepspeed/global_call_migrate.cpp