Skip to content

Commit d5ce77a

Browse files
committed
Move cublas_handle into own header file
1 parent 1219fd6 commit d5ce77a

File tree

5 files changed

+64
-47
lines changed

5 files changed

+64
-47
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/***************************************************************************
2+
* Copyright (C) Codeplay Software Limited
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* For your convenience, a copy of the License has been included in this
10+
* repository.
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
**************************************************************************/
19+
#ifndef CUBLAS_HANDLE_HIPSYCL_HPP
20+
#define CUBLAS_HANDLE_HIPSYCL_HPP
21+
#include<atomic>
22+
#include<unordered_map>
23+
24+
namespace oneapi {
25+
namespace mkl {
26+
namespace blas {
27+
namespace cublas {
28+
29+
template<typename T>
30+
struct cublas_handle {
31+
using handle_container_t = std::unordered_map<T, std::atomic<cublasHandle_t> *>;
32+
handle_container_t cublas_handle_mapper_{};
33+
~cublas_handle() noexcept(false){
34+
for (auto &handle_pair : cublas_handle_mapper_) {
35+
cublasStatus_t err;
36+
if (handle_pair.second != nullptr) {
37+
auto handle = handle_pair.second->exchange(nullptr);
38+
if (handle != nullptr) {
39+
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
40+
handle = nullptr;
41+
}
42+
delete handle_pair.second;
43+
handle_pair.second = nullptr;
44+
}
45+
}
46+
cublas_handle_mapper_.clear();
47+
}
48+
};
49+
50+
51+
} // namespace cublas
52+
} // namespace blas
53+
} // namespace mkl
54+
} // namespace oneapi
55+
56+
#endif // CUBLAS_HANDLE_HIPSYCL_HPP

src/blas/backends/cublas/cublas_scope_handle.cpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,14 @@ namespace mkl {
2424
namespace blas {
2525
namespace cublas {
2626

27-
cublas_handle::~cublas_handle() noexcept(false) {
28-
for (auto &handle_pair : cublas_handle_mapper_) {
29-
cublasStatus_t err;
30-
if (handle_pair.second != nullptr) {
31-
auto handle = handle_pair.second->exchange(nullptr);
32-
if (handle != nullptr) {
33-
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
34-
handle = nullptr;
35-
}
36-
delete handle_pair.second;
37-
handle_pair.second = nullptr;
38-
}
39-
}
40-
cublas_handle_mapper_.clear();
41-
}
4227
/**
4328
* Inserts a new element in the map if its key is unique. This new element
4429
* is constructed in place using args as the arguments for the construction
4530
* of a value_type (which is an object of a pair type). The insertion only
4631
* takes place if no other element in the container has a key equivalent to
4732
* the one being emplaced (keys in a map container are unique).
4833
*/
49-
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
34+
thread_local cublas_handle<pi_context> CublasScopedContextHandler::handle_helper = cublas_handle<pi_context>{};
5035

5136
CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue,
5237
cl::sycl::interop_handler &ih)

src/blas/backends/cublas/cublas_scope_handle.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,13 @@
2727
#include <thread>
2828
#include <unordered_map>
2929
#include "cublas_helper.hpp"
30+
#include "cublas_handle.hpp"
31+
3032
namespace oneapi {
3133
namespace mkl {
3234
namespace blas {
3335
namespace cublas {
3436

35-
struct cublas_handle {
36-
using handle_container_t = std::unordered_map<pi_context, std::atomic<cublasHandle_t> *>;
37-
handle_container_t cublas_handle_mapper_{};
38-
~cublas_handle() noexcept(false);
39-
};
40-
4137
/**
4238
* @brief NVIDIA advise for handle creation:
4339
https://devtalk.nvidia.com/default/topic/838794/gpu-accelerated libraries/using-cublas-in-different-cuda-streams/
@@ -69,7 +65,7 @@ class CublasScopedContextHandler {
6965
cl::sycl::context placedContext_;
7066
bool needToRecover_;
7167
cl::sycl::interop_handler &ih;
72-
static thread_local cublas_handle handle_helper;
68+
static thread_local cublas_handle<pi_context> handle_helper;
7369
CUstream get_stream(const cl::sycl::queue &queue);
7470
cl::sycl::context get_context(const cl::sycl::queue &queue);
7571

src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,12 @@
11
#include "cublas_scope_handle_hipsycl.hpp"
2+
#include "cublas_handle.hpp"
23

34
namespace oneapi {
45
namespace mkl {
56
namespace blas {
67
namespace cublas {
78

8-
cublas_handle::~cublas_handle() noexcept(false) {
9-
for (auto &handle_pair : cublas_handle_mapper_) {
10-
cublasStatus_t err;
11-
if (handle_pair.second != nullptr) {
12-
auto handle = handle_pair.second->exchange(nullptr);
13-
if (handle != nullptr) {
14-
CUBLAS_ERROR_FUNC(cublasDestroy, err, handle);
15-
handle = nullptr;
16-
}
17-
delete handle_pair.second;
18-
handle_pair.second = nullptr;
19-
}
20-
}
21-
cublas_handle_mapper_.clear();
22-
}
23-
24-
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};
9+
thread_local cublas_handle<int> CublasScopedContextHandler::handle_helper = cublas_handle<int>{};
2510

2611
CublasScopedContextHandler::CublasScopedContextHandler(cl::sycl::queue queue,
2712
cl::sycl::interop_handle &ih)

src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,12 @@
2323
#include <thread>
2424
#include <unordered_map>
2525
#include "cublas_helper.hpp"
26+
#include "cublas_handle.hpp"
2627
namespace oneapi {
2728
namespace mkl {
2829
namespace blas {
2930
namespace cublas {
3031

31-
struct cublas_handle {
32-
using handle_container_t = std::unordered_map<int, std::atomic<cublasHandle_t> *>;
33-
handle_container_t cublas_handle_mapper_{};
34-
~cublas_handle() noexcept(false);
35-
};
36-
3732
/**
3833
* @brief NVIDIA advise for handle creation:
3934
https://devtalk.nvidia.com/default/topic/838794/gpu-accelerated libraries/using-cublas-in-different-cuda-streams/
@@ -61,7 +56,7 @@ the handle must be destroyed when the context goes out of scope. This will bind
6156

6257
class CublasScopedContextHandler {
6358
cl::sycl::interop_handle interop_h;
64-
static thread_local cublas_handle handle_helper;
59+
static thread_local cublas_handle<int> handle_helper;
6560
cl::sycl::context get_context(const cl::sycl::queue &queue);
6661
CUstream get_stream(const cl::sycl::queue &queue);
6762

0 commit comments

Comments
 (0)