Skip to content

Commit b3989ed

Browse files
committed
add threads static to aid thread API evolution
1 parent 9010b7f commit b3989ed

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ Command-line option changes
9797

9898
Multi-threading changes
9999
-----------------------
100+
* `@threads` now allows an optional schedule argument. Use `@threads static ...` to
101+
ensure that the same schedule will be used as in past versions; the default schedule
102+
is likely to change in the future.
100103

101104

102105
Build system changes

base/threadingconstructs.jl

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ function _threadsfor(iter,lbody)
8282
end
8383
end
8484
end
85-
if threadid() != 1
86-
# only thread 1 can enter/exit _threadedregion
85+
if threadid() != 1 || ccall(:jl_in_threaded_region, Cint, ()) != 0
86+
# only use threads when called from thread 1, outside @threads
8787
Base.invokelatest(threadsfor_fun, true)
8888
else
8989
threading_run(threadsfor_fun)
@@ -93,31 +93,45 @@ function _threadsfor(iter,lbody)
9393
end
9494

9595
"""
96-
Threads.@threads
96+
Threads.@threads [schedule] for ... end
9797
98-
A macro to parallelize a for-loop to run with multiple threads. This spawns [`nthreads()`](@ref)
99-
number of threads, splits the iteration space amongst them, and iterates in parallel.
100-
A barrier is placed at the end of the loop which waits for all the threads to finish
101-
execution, and the loop returns.
98+
A macro to parallelize a `for` loop to run with multiple threads. Splits the iteration
99+
space among multiple tasks and runs those tasks on threads according to a scheduling
100+
policy.
101+
A barrier is placed at the end of the loop which waits for all tasks to finish
102+
execution.
103+
104+
The `schedule` argument can be used to request a particular scheduling policy.
105+
The only currently supported value is `static`, which creates one task per thread
106+
and divides the iterations equally among them. If called from inside another
107+
`@threads` loop, or from a thread other than 1, then all iterations run on the
108+
current thread.
109+
110+
The default schedule (used when no `schedule` argument is present) is subject to change.
111+
112+
!!! compat "Julia 1.5"
113+
The `schedule` argument is available as of Julia 1.5.
102114
"""
103115
macro threads(args...)
104116
na = length(args)
105-
if na != 1
117+
if na == 2
118+
sched, ex = args
119+
elseif na == 1
120+
sched = :static
121+
ex = args[1]
122+
else
106123
throw(ArgumentError("wrong number of arguments in @threads"))
107124
end
108-
ex = args[1]
109-
if !isa(ex, Expr)
110-
throw(ArgumentError("need an expression argument to @threads"))
125+
if !(isa(ex, Expr) && ex.head === :for)
126+
throw(ArgumentError("@threads requires a `for` loop expression"))
111127
end
112-
if ex.head === :for
113-
if ex.args[1] isa Expr && ex.args[1].head === :(=)
114-
return _threadsfor(ex.args[1], ex.args[2])
115-
else
116-
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
117-
end
118-
else
119-
throw(ArgumentError("unrecognized argument to @threads"))
128+
if sched != :static
129+
throw(ArgumentError("unsupported schedule argument in @threads"))
130+
end
131+
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
132+
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
120133
end
134+
return _threadsfor(ex.args[1], ex.args[2])
121135
end
122136

123137
"""

src/threading.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,11 @@ void jl_start_threads(void)
481481

482482
unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
483483

484+
JL_DLLEXPORT int jl_in_threaded_region(void)
485+
{
486+
return _threadedregion != 0;
487+
}
488+
484489
JL_DLLEXPORT void jl_enter_threaded_region(void)
485490
{
486491
_threadedregion += 1;

test/threads_exec.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,16 @@ let a = zeros(nthreads())
729729
@test a == [1:nthreads();]
730730
end
731731

732+
# static schedule
733+
function _atthreads_static_schedule()
734+
ids = zeros(Int, nthreads())
735+
Threads.@threads static for i = 1:nthreads()
736+
ids[i] = Threads.threadid()
737+
end
738+
return ids
739+
end
740+
@test _atthreads_static_schedule() == [1:nthreads();]
741+
732742
try
733743
@macroexpand @threads(for i = 1:10, j = 1:10; end)
734744
catch ex

0 commit comments

Comments
 (0)