
Fix cross-scope function calls, the parameter-size limit, and the performance impact of large captures by using a functor. #1398

Open
@CaoZhongZ

Description


Consider the following case in a DeepSpeed kernel: a global function template with a parameter pack.

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
  int chunk_size, volatile int* noop_flag, T tl, U callable, ArgTypes... args) {
  callable(chunk_size, noop_flag, tl, args...);
}
...
  multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
      chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

The current translation would look like:

queue.parallel_for(...,
[=] (sycl::nd_item<3> item) {
  multi_tensor_apply_kernel(chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
});

If we transform with minimal lexical change, the noop_flag.DATA_PTR<int>() function call will cross scope from host to device, which it does not in CUDA. noop_flag happens to be a torch::Tensor object, which is not device-copyable. Even if we introduce a temporary variable and move the call out of the lambda, the object tl (a tensor list) pushes the capture size past the 2048-byte limit. A single simple case hits every limitation we have.

We suggest translating a CUDA global function (template) into an explicit functor (template) instead of using lambda capture. The kernel launch code would look like:

  if constexpr (sizeof multi_tensor_apply_kernel(
       chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...) < 2048 ) {
    queue.parallel_for(...,
      multi_tensor_apply_kernel(
         chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...));
  } else {
    auto capture = multi_tensor_apply_kernel(chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
    sycl::buffer params(const_cast<const decltype(capture) *>(&capture), sycl::range<1>(1));

    queue.submit([&] (sycl::handler &cgh) {
      auto device_params = params.template get_access<
        sycl::access_mode::read, sycl::target::constant_buffer>(cgh);
      cgh.parallel_for(...,
        [=] (sycl::nd_item<3> item) {
          device_params[0](item);
        });
    });
  }

All substitutions in the translation are localized. A manually migrated example can be found at:
https://github.com/CaoZhongZ/sycl_compiler_test/blob/global_call_migrate/deepspeed/global_call_migrate.cpp
