Skip to content

Commit b4d81fe

Browse files
author
msol
committed
fix bug related to change behavior of threadid()
1 parent 1b1ed3c commit b4d81fe

File tree

4 files changed

+71
-33
lines changed

4 files changed

+71
-33
lines changed

src/byrow/hp_row_functions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function _hp_row_generic_vec!(res, ds, f, colsidx, ::Val{T}) where T
9292
max_cz = length(res) - 1000 - (loopsize - 1)*1000
9393
inmat_all = [Matrix{T}(undef, length(colsidx), max_cz) for i in 1:nt]
9494
# make sure that the variable inside the loop are not the same as the out of scope one
95-
Threads.@threads for i in 1:loopsize
95+
Threads.@threads :static for i in 1:loopsize
9696
t_st = i*1000 + 1
9797
i == loopsize ? t_en = length(res) : t_en = (i+1)*1000
9898
_fill_matrix!(inmat_all[Threads.threadid()], all_data, t_st:t_en, colsidx)

src/sort/int.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ end
9898
function _sort_chunks_int_right!(x, idx::Vector{<:Integer}, idx_cpy, where, number_of_chunks, rangelen, minval, o::Ordering)
9999
cz = div(length(x), number_of_chunks)
100100
en = length(x)
101-
Threads.@threads for i in 1:number_of_chunks
101+
Threads.@threads :static for i in 1:number_of_chunks
102102
ds_sort_int_missatright!(x, idx, idx_cpy, where[Threads.threadid()], (i-1)*cz+1,i*cz, rangelen, minval)
103103
end
104104
# take care of the last few observations
@@ -111,7 +111,7 @@ end
111111
function _sort_chunks_int_left!(x, idx::Vector{<:Integer}, idx_cpy, where, number_of_chunks, rangelen, minval, o::Ordering)
112112
cz = div(length(x), number_of_chunks)
113113
en = length(x)
114-
Threads.@threads for i in 1:number_of_chunks
114+
Threads.@threads :static for i in 1:number_of_chunks
115115
ds_sort_int_missatleft!(x, idx, idx_cpy, where[Threads.threadid()], (i-1)*cz+1,i*cz, rangelen, minval)
116116
end
117117
# take care of the last few observations
@@ -262,7 +262,7 @@ function _ds_sort_int_missatright_nopermx_threaded!(x, original_P, copy_P, lo, h
262262
where[i][1] = 1
263263
where[i][2] = 1
264264
end
265-
Threads.@threads for i = lo:hi
265+
Threads.@threads :static for i = lo:hi
266266
@inbounds ismissing(x[i]) ? where[Threads.threadid()][rangelen+3] += 1 : where[Threads.threadid()][Int(x[i]) + offs + 2] += 1
267267
end
268268
for j in 3:length(where[1])
@@ -306,7 +306,7 @@ function _ds_sort_int_missatright_nopermx_threaded!(x, original_P, rangelen, min
306306
where[i][1] = 1
307307
where[i][2] = 1
308308
end
309-
Threads.@threads for i = 1:length(x)
309+
Threads.@threads :static for i = 1:length(x)
310310
@inbounds ismissing(x[i]) ? where[Threads.threadid()][rangelen+3] += 1 : where[Threads.threadid()][Int(x[i]) + offs + 2] += 1
311311
end
312312
for j in 3:length(where[1])
@@ -348,7 +348,7 @@ function _ds_sort_int_missatleft_nopermx_threaded!(x, original_P, copy_P, lo, hi
348348
where[i][1] = 1
349349
where[i][2] = 1
350350
end
351-
Threads.@threads for i = lo:hi
351+
Threads.@threads :static for i = lo:hi
352352
@inbounds ismissing(x[i]) ? where[Threads.threadid()][3] += 1 : where[Threads.threadid()][Int(x[i]) + offs + 3] += 1
353353
end
354354
for j in 3:length(where[1])
@@ -392,7 +392,7 @@ function _ds_sort_int_missatleft_nopermx_threaded!(x, original_P, rangelen, minv
392392
where[i][1] = 1
393393
where[i][2] = 1
394394
end
395-
Threads.@threads for i = 1:length(x)
395+
Threads.@threads :static for i = 1:length(x)
396396
@inbounds ismissing(x[i]) ? where[Threads.threadid()][3] += 1 : where[Threads.threadid()][Int(x[i]) + offs + 3] += 1
397397
end
398398
for j in 3:length(where[1])

src/sort/sort.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,21 @@ end
213213

214214
function _issorted_check_for_each_range(v, starts, lastvalid, _ord, nrows; threads = true)
215215
part_res = ones(Bool, threads ? Threads.nthreads() : 1)
216-
@_threadsfor threads for rng in 1:lastvalid
217-
lo = starts[rng]
218-
rng == lastvalid ? hi = nrows : hi = starts[rng+1] - 1
219-
part_res[Threads.threadid()] = _issorted_barrier(v, _ord, lo, hi)
220-
!part_res[Threads.threadid()] && break
216+
if threads
217+
218+
Threads.@threads :static for rng in 1:lastvalid
219+
lo = starts[rng]
220+
rng == lastvalid ? hi = nrows : hi = starts[rng+1] - 1
221+
part_res[Threads.threadid()] = _issorted_barrier(v, _ord, lo, hi)
222+
!part_res[Threads.threadid()] && break
223+
end
224+
else
225+
for rng in 1:lastvalid
226+
lo = starts[rng]
227+
rng == lastvalid ? hi = nrows : hi = starts[rng+1] - 1
228+
part_res[Threads.threadid()] = _issorted_barrier(v, _ord, lo, hi)
229+
!part_res[Threads.threadid()] && break
230+
end
221231
end
222232
all(part_res)
223233
end

src/sort/sortperm.jl

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
# we should find starts here
3030
function fast_sortperm_int_threaded!(x, original_P, copy_P, ranges, rangelen, minval, misatleft, last_valid_range, ::Val{T}) where T
3131
starts = [T[] for i in 1:Threads.nthreads()]
32-
Threads.@threads for i in 1:last_valid_range
32+
Threads.@threads :static for i in 1:last_valid_range
3333
rangestart = ranges[i]
3434
i == last_valid_range ? rangeend = length(x) : rangeend = ranges[i+1] - 1
3535
# if (rangeend - rangestart) == 0
@@ -105,29 +105,57 @@ function fast_sortperm_int!(x, original_P, copy_P, ranges, rangelen, minval, mis
105105
end
106106

107107
function _sortperm_int!(idx, idx_cpy, x, ranges, where, last_valid_range, missingatleft, ord, a; threads = true)
108-
@_threadsfor threads for i in 1:last_valid_range
109-
rangestart = ranges[i]
110-
i == last_valid_range ? rangeend = length(x) : rangeend = ranges[i+1] - 1
111-
if (rangeend - rangestart + 1) == 1
112-
continue
113-
end
114-
_minval = stat_minimum(x, lo = rangestart, hi = rangeend)
115-
if ismissing(_minval)
116-
continue
117-
else
118-
minval::Int = _minval
108+
if threads
109+
Threads.@threads :static for i in 1:last_valid_range
110+
rangestart = ranges[i]
111+
i == last_valid_range ? rangeend = length(x) : rangeend = ranges[i+1] - 1
112+
if (rangeend - rangestart + 1) == 1
113+
continue
114+
end
115+
_minval = stat_minimum(x, lo = rangestart, hi = rangeend)
116+
if ismissing(_minval)
117+
continue
118+
else
119+
minval::Int = _minval
120+
end
121+
maxval::Int = stat_maximum(x, lo = rangestart, hi = rangeend)
122+
# the overflow is check before calling _sortperm_int!
123+
rangelen = maxval - minval + 1
124+
if rangelen < div(rangeend - rangestart + 1, 2)
125+
if missingatleft
126+
ds_sort_int_missatleft!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
127+
else
128+
ds_sort_int_missatright!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
129+
end
130+
else
131+
ds_sort!(x, idx, rangestart, rangeend, a, ord)
132+
end
119133
end
120-
maxval::Int = stat_maximum(x, lo = rangestart, hi = rangeend)
121-
# the overflow is check before calling _sortperm_int!
122-
rangelen = maxval - minval + 1
123-
if rangelen < div(rangeend - rangestart + 1, 2)
124-
if missingatleft
125-
ds_sort_int_missatleft!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
134+
else
135+
for i in 1:last_valid_range
136+
rangestart = ranges[i]
137+
i == last_valid_range ? rangeend = length(x) : rangeend = ranges[i+1] - 1
138+
if (rangeend - rangestart + 1) == 1
139+
continue
140+
end
141+
_minval = stat_minimum(x, lo = rangestart, hi = rangeend)
142+
if ismissing(_minval)
143+
continue
126144
else
127-
ds_sort_int_missatright!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
145+
minval::Int = _minval
146+
end
147+
maxval::Int = stat_maximum(x, lo = rangestart, hi = rangeend)
148+
# the overflow is check before calling _sortperm_int!
149+
rangelen = maxval - minval + 1
150+
if rangelen < div(rangeend - rangestart + 1, 2)
151+
if missingatleft
152+
ds_sort_int_missatleft!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
153+
else
154+
ds_sort_int_missatright!(x, idx, idx_cpy, where[Threads.threadid()], rangestart, rangeend, rangelen, minval)
155+
end
156+
else
157+
ds_sort!(x, idx, rangestart, rangeend, a, ord)
128158
end
129-
else
130-
ds_sort!(x, idx, rangestart, rangeend, a, ord)
131159
end
132160
end
133161
end

0 commit comments

Comments
 (0)