Skip to content

Commit 1855208

Browse files
Adding synchronization mechanism for host_task and updates
1 parent 8a105a3 commit 1855208

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

cmake/FindLAPACKE.cmake

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020
include_guard()
2121

22-
find_library(LAPACKE64_file NAMES lapacke64.dll.lib lapacke64.lib lapacke64 HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
22+
find_library(LAPACKE64_file NAMES lapacke64.dll.lib lapacke64.lib lapacke HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
2323
find_package_handle_standard_args(LAPACKE REQUIRED_VARS LAPACKE64_file)
24-
find_library(LAPACK64_file NAMES lapack64.dll.lib lapack64.lib lapack64 HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
24+
find_library(LAPACK64_file NAMES lapack64.dll.lib lapack64.lib lapack HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
2525
find_package_handle_standard_args(LAPACKE REQUIRED_VARS LAPACK64_file)
26-
find_library(CBLAS64_file NAMES cblas64.dll.lib cblas64.lib cblas64 HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
26+
find_library(CBLAS64_file NAMES cblas64.dll.lib cblas64.lib cblas HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
2727
find_package_handle_standard_args(LAPACKE REQUIRED_VARS CBLAS64_file)
28-
find_library(BLAS64_file NAMES blas64.dll.lib blas64.lib blas64 HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
28+
find_library(BLAS64_file NAMES blas64.dll.lib blas64.lib blas HINTS ${REF_LAPACK_ROOT} PATH_SUFFIXES lib lib64)
2929
find_package_handle_standard_args(LAPACKE REQUIRED_VARS BLAS64_file)
3030

3131
get_filename_component(LAPACKE64_LIB_DIR ${LAPACKE64_file} DIRECTORY)

src/lapack/backends/rocsolver/rocsolver_lapack.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2447,7 +2447,7 @@ inline void gebrd_scratchpad_size(const char *func_name, Func func, sycl::queue
24472447
template <> \
24482448
std::int64_t gebrd_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t m, std::int64_t n, \
24492449
std::int64_t lda) { \
2450-
return 1; \
2450+
return 0; \
24512451
}
24522452

24532453
GEBRD_LAUNCHER_SCRATCH(float, rocsolverDnSgebrd_bufferSize)
@@ -2495,7 +2495,7 @@ inline void geqrf_scratchpad_size(const char *func_name, Func func, sycl::queue
24952495
template <> \
24962496
std::int64_t geqrf_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t m, std::int64_t n, \
24972497
std::int64_t lda) { \
2498-
return 1; \
2498+
return 0; \
24992499
}
25002500

25012501
GEQRF_LAUNCHER_SCRATCH(float, rocsolverDnSgeqrf_bufferSize)
@@ -2525,7 +2525,7 @@ inline void gesvd_scratchpad_size(const char *func_name, Func func, sycl::queue
25252525
std::int64_t gesvd_scratchpad_size<TYPE>( \
25262526
sycl::queue & queue, oneapi::mkl::jobsvd jobu, oneapi::mkl::jobsvd jobvt, std::int64_t m, \
25272527
std::int64_t n, std::int64_t lda, std::int64_t ldu, std::int64_t ldvt) { \
2528-
return 1; \
2528+
return 0; \
25292529
}
25302530

25312531
GESVD_LAUNCHER_SCRATCH(float, rocsolverDnSgesvd_bufferSize)
@@ -2552,7 +2552,7 @@ inline void getrf_scratchpad_size(const char *func_name, Func func, sycl::queue
25522552
template <> \
25532553
std::int64_t getrf_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t m, std::int64_t n, \
25542554
std::int64_t lda) { \
2555-
return 1; \
2555+
return 0; \
25562556
} // namespace rocsolver
25572557

25582558
GETRF_LAUNCHER_SCRATCH(float, rocsolverDnSgetrf_bufferSize)
@@ -2617,7 +2617,7 @@ inline void heevd_scratchpad_size(const char *func_name, Func func, sycl::queue
26172617
std::int64_t heevd_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::job jobz, \
26182618
oneapi::mkl::uplo uplo, std::int64_t n, \
26192619
std::int64_t lda) { \
2620-
return 1; \
2620+
return 0; \
26212621
} // namespace lapack
26222622

26232623
HEEVD_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCheevd_bufferSize)
@@ -2646,7 +2646,7 @@ inline void hegvd_scratchpad_size(const char *func_name, Func func, sycl::queue
26462646
std::int64_t hegvd_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t itype, \
26472647
oneapi::mkl::job jobz, oneapi::mkl::uplo uplo, \
26482648
std::int64_t n, std::int64_t lda, std::int64_t ldb) { \
2649-
return 1; \
2649+
return 0; \
26502650
} // namespace mkl
26512651

26522652
HEGVD_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnChegvd_bufferSize)
@@ -2672,7 +2672,7 @@ inline void hetrd_scratchpad_size(const char *func_name, Func func, sycl::queue
26722672
template <> \
26732673
std::int64_t hetrd_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
26742674
std::int64_t n, std::int64_t lda) { \
2675-
return 1; \
2675+
return 0; \
26762676
} // namespace oneapi
26772677

26782678
HETRD_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnChetrd_bufferSize)
@@ -2710,7 +2710,7 @@ inline void orgbr_scratchpad_size(const char *func_name, Func func, sycl::queue
27102710
std::int64_t orgbr_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::generate vec, \
27112711
std::int64_t m, std::int64_t n, std::int64_t k, \
27122712
std::int64_t lda) { \
2713-
return 1; \
2713+
return 0; \
27142714
}
27152715

27162716
ORGBR_LAUNCHER_SCRATCH(float, rocsolverDnSorgbr_bufferSize)
@@ -2736,7 +2736,7 @@ inline void orgtr_scratchpad_size(const char *func_name, Func func, sycl::queue
27362736
template <> \
27372737
std::int64_t orgtr_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
27382738
std::int64_t n, std::int64_t lda) { \
2739-
return 1; \
2739+
return 0; \
27402740
}
27412741

27422742
ORGTR_LAUNCHER_SCRATCH(float, rocsolverDnSorgtr_bufferSize)
@@ -2762,7 +2762,7 @@ inline void orgqr_scratchpad_size(const char *func_name, Func func, sycl::queue
27622762
template <> \
27632763
std::int64_t orgqr_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t m, std::int64_t n, \
27642764
std::int64_t k, std::int64_t lda) { \
2765-
return 1; \
2765+
return 0; \
27662766
}
27672767

27682768
ORGQR_LAUNCHER_SCRATCH(float, rocsolverDnSorgqr_bufferSize)
@@ -2806,7 +2806,7 @@ inline void ormqr_scratchpad_size(const char *func_name, Func func, sycl::queue
28062806
std::int64_t ormqr_scratchpad_size<TYPE>( \
28072807
sycl::queue & queue, oneapi::mkl::side side, oneapi::mkl::transpose trans, std::int64_t m, \
28082808
std::int64_t n, std::int64_t k, std::int64_t lda, std::int64_t ldc) { \
2809-
return 1; \
2809+
return 0; \
28102810
}
28112811

28122812
ORMQRF_LAUNCHER_SCRATCH(float, rocsolverDnSormqr_bufferSize)
@@ -2836,7 +2836,7 @@ inline void ormtr_scratchpad_size(const char *func_name, Func func, sycl::queue
28362836
oneapi::mkl::uplo uplo, oneapi::mkl::transpose trans, \
28372837
std::int64_t m, std::int64_t n, std::int64_t lda, \
28382838
std::int64_t ldc) { \
2839-
return 1; \
2839+
return 0; \
28402840
}
28412841

28422842
ORMTR_LAUNCHER_SCRATCH(float, rocsolverDnSormtr_bufferSize)
@@ -2862,7 +2862,7 @@ inline void potrf_scratchpad_size(const char *func_name, Func func, sycl::queue
28622862
template <> \
28632863
std::int64_t potrf_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
28642864
std::int64_t n, std::int64_t lda) { \
2865-
return 1; \
2865+
return 0; \
28662866
}
28672867

28682868
POTRF_LAUNCHER_SCRATCH(float, rocsolverDnSpotrf_bufferSize)
@@ -2906,7 +2906,7 @@ inline void potri_scratchpad_size(const char *func_name, Func func, sycl::queue
29062906
template <> \
29072907
std::int64_t potri_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
29082908
std::int64_t n, std::int64_t lda) { \
2909-
return 1; \
2909+
return 0; \
29102910
}
29112911

29122912
POTRI_LAUNCHER_SCRATCH(float, rocsolverDnSpotri_bufferSize)
@@ -2933,7 +2933,7 @@ inline void sytrf_scratchpad_size(const char *func_name, Func func, sycl::queue
29332933
template <> \
29342934
std::int64_t sytrf_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
29352935
std::int64_t n, std::int64_t lda) { \
2936-
return 1; \
2936+
return 0; \
29372937
}
29382938

29392939
SYTRF_LAUNCHER_SCRATCH(float, rocsolverDnSsytrf_bufferSize)
@@ -2963,7 +2963,7 @@ inline void syevd_scratchpad_size(const char *func_name, Func func, sycl::queue
29632963
std::int64_t syevd_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::job jobz, \
29642964
oneapi::mkl::uplo uplo, std::int64_t n, \
29652965
std::int64_t lda) { \
2966-
return 1; \
2966+
return 0; \
29672967
}
29682968

29692969
SYEVD_LAUNCHER_SCRATCH(float, rocsolverDnSsyevd_bufferSize)
@@ -2992,7 +2992,7 @@ inline void sygvd_scratchpad_size(const char *func_name, Func func, sycl::queue
29922992
std::int64_t sygvd_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t itype, \
29932993
oneapi::mkl::job jobz, oneapi::mkl::uplo uplo, \
29942994
std::int64_t n, std::int64_t lda, std::int64_t ldb) { \
2995-
return 1; \
2995+
return 0; \
29962996
}
29972997

29982998
SYGVD_LAUNCHER_SCRATCH(float, rocsolverDnSsygvd_bufferSize)
@@ -3018,7 +3018,7 @@ inline void sytrd_scratchpad_size(const char *func_name, Func func, sycl::queue
30183018
template <> \
30193019
std::int64_t sytrd_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
30203020
std::int64_t n, std::int64_t lda) { \
3021-
return 1; \
3021+
return 0; \
30223022
}
30233023

30243024
SYTRD_LAUNCHER_SCRATCH(float, rocsolverDnSsytrd_bufferSize)
@@ -3076,7 +3076,7 @@ inline void ungbr_scratchpad_size(const char *func_name, Func func, sycl::queue
30763076
std::int64_t ungbr_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::generate vec, \
30773077
std::int64_t m, std::int64_t n, std::int64_t k, \
30783078
std::int64_t lda) { \
3079-
return 1; \
3079+
return 0; \
30803080
}
30813081

30823082
UNGBR_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCungbr_bufferSize)
@@ -3102,7 +3102,7 @@ inline void ungqr_scratchpad_size(const char *func_name, Func func, sycl::queue
31023102
template <> \
31033103
std::int64_t ungqr_scratchpad_size<TYPE>(sycl::queue & queue, std::int64_t m, std::int64_t n, \
31043104
std::int64_t k, std::int64_t lda) { \
3105-
return 1; \
3105+
return 0; \
31063106
}
31073107

31083108
UNGQR_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCungqr_bufferSize)
@@ -3128,7 +3128,7 @@ inline void ungtr_scratchpad_size(const char *func_name, Func func, sycl::queue
31283128
template <> \
31293129
std::int64_t ungtr_scratchpad_size<TYPE>(sycl::queue & queue, oneapi::mkl::uplo uplo, \
31303130
std::int64_t n, std::int64_t lda) { \
3131-
return 1; \
3131+
return 0; \
31323132
}
31333133

31343134
UNGTR_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCungtr_bufferSize)
@@ -3174,7 +3174,7 @@ inline void unmqr_scratchpad_size(const char *func_name, Func func, sycl::queue
31743174
std::int64_t unmqr_scratchpad_size<TYPE>( \
31753175
sycl::queue & queue, oneapi::mkl::side side, oneapi::mkl::transpose trans, std::int64_t m, \
31763176
std::int64_t n, std::int64_t k, std::int64_t lda, std::int64_t ldc) { \
3177-
return 1; \
3177+
return 0; \
31783178
}
31793179

31803180
UNMQR_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCunmqr_bufferSize)
@@ -3204,7 +3204,7 @@ inline void unmtr_scratchpad_size(const char *func_name, Func func, sycl::queue
32043204
oneapi::mkl::uplo uplo, oneapi::mkl::transpose trans, \
32053205
std::int64_t m, std::int64_t n, std::int64_t lda, \
32063206
std::int64_t ldc) { \
3207-
return 1; \
3207+
return 0; \
32083208
}
32093209

32103210
UNMTR_LAUNCHER_SCRATCH(std::complex<float>, rocsolverDnCunmtr_bufferSize)

src/lapack/backends/rocsolver/rocsolver_task.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,17 @@ namespace lapack {
3434
namespace rocsolver {
3535

3636
template <typename H, typename F>
37+
static inline void stream_wait(sycl::queue queue) {
38+
hipStream_t stream = sycl::get_native<sycl::backend::hip>(queue);
39+
hipStreamSynchronize(stream);
40+
}
41+
42+
3743
static inline void host_task_internal(H &cgh, sycl::queue queue, F f) {
38-
cgh.host_task([f, queue](cl::sycl::interop_handle ih) {
44+
cgh.host_task([f, queue](cl::sycl::interop_handle ih) {
3945
auto sc = RocsolverScopedContextHandler(queue, ih);
4046
f(sc);
47+
stream_wait(queue);
4148
});
4249
}
4350

0 commit comments

Comments
 (0)