
Add a preference option for DFTK threading #972


Merged · 14 commits · May 31, 2024
8 changes: 4 additions & 4 deletions docs/src/tricks/parallelization.md
@@ -190,13 +190,13 @@ BLAS.set_num_threads(N)
where `N` is the number of threads you desire.
To **check the number of BLAS threads** currently used, you can use `BLAS.get_num_threads()`.

### Julia threads
On top of BLAS threading DFTK uses Julia threads (`Threads.@threads`)
### DFTK threads
On top of BLAS threading DFTK uses Julia threads
in a couple of places to parallelize over ``k``-points (density computation)
or bands (Hamiltonian application).
The number of threads used for these aspects is controlled by the
The number of threads used for these aspects is controlled by default by the
flag `-t` passed to Julia or the *environment variable* `JULIA_NUM_THREADS`.
To **check the number of Julia threads** use `Threads.nthreads()`.
It can also be set through `setup_threading(; n_DFTK)`.
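As a rough illustration (thread counts here are arbitrary, and the stored preference only takes effect after a restart):

```julia
# Start Julia with 8 threads, e.g. `julia -t 8` or `JULIA_NUM_THREADS=8 julia`.
using DFTK
Threads.nthreads()           # → 8; by default DFTK threading follows this value

# Optionally pin DFTK's own threading to fewer threads than Julia provides:
setup_threading(; n_DFTK=4)  # stored as a preference, effective after restarting Julia
```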

### FFT threads
Since FFT threading is only used in DFTK inside the regions already parallelized
77 changes: 56 additions & 21 deletions src/common/threading.jl
@@ -1,42 +1,77 @@
import FFTW
using LinearAlgebra

function setup_threading(; n_fft=1, n_blas=Threads.nthreads())
n_julia = Threads.nthreads()
function setup_threading(; n_fft=1, n_blas=Threads.nthreads(), n_DFTK=nothing)
    if !isnothing(n_DFTK)
set_DFTK_threads!(n_DFTK)
end
n_DFTK = @load_preference("DFTK_threads", Threads.nthreads())
FFTW.set_num_threads(n_fft)
BLAS.set_num_threads(n_blas)
mpi_master() && @info "Threading setup:" n_fft n_blas n_julia
mpi_master() && @info "Threading setup: $n_julia Julia threads, $n_DFTK DFTK threads, $n_fft FFT threads, $n_blas BLAS threads"

end
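A usage sketch of the new keyword, based on the signature above (the thread counts are arbitrary):

```julia
# Use 2 DFTK threads and 4 BLAS threads, keep FFTW single-threaded.
# Passing n_DFTK stores the "DFTK_threads" preference via set_DFTK_threads!,
# so the choice persists across Julia sessions (it takes effect after a restart).
setup_threading(; n_fft=1, n_blas=4, n_DFTK=2)

# Without n_DFTK the stored preference (or Threads.nthreads() as fallback)
# is only read back and reported.
setup_threading()
```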

"""
Convenience function to disable all threading in DFTK and assert that Julia threading
is off as well.
"""
function disable_threading()
n_julia = Threads.nthreads()
n_julia > 1 && mpi_master() && error(
"Julia currently uses $n_julia threads. Ensure that the environment variable " *
"JULIA_NUM_THREADS is unset and julia does not get the `-t` flag passed."
)
@assert n_julia == 1 # To exit in non-master MPI nodes
setup_threading(;n_fft=1, n_blas=1)
setup_threading(;n_fft=1, n_blas=1, n_DFTK=1)
end

# Parallelization loop breaking range into chunks.
function parallel_loop_over_range(fun, storages::AbstractVector, range)
chunk_length = cld(length(range), Threads.nthreads())
function set_DFTK_threads!(n)
if @load_preference("DFTK_threads", nothing) != n
@info "DFTK_threads preference changed. Restart julia to see the effect."
end
@set_preferences!("DFTK_threads" => n)
end
function set_DFTK_threads!()
@delete_preferences!("DFTK_threads")
end
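For context, a sketch of the Preferences.jl calls these helpers build on; the `"DFTK_threads"` key is the one used above, the rest is illustrative:

```julia
using DFTK, Preferences

# What set_DFTK_threads!(4) boils down to: write the key into the active
# project's LocalPreferences.toml under the [DFTK] section.
set_preferences!(DFTK, "DFTK_threads" => 4; force=true)

# What the fallback read above looks like when done from outside the package:
load_preference(DFTK, "DFTK_threads", Threads.nthreads())

# And the analogue of set_DFTK_threads!() with no argument:
delete_preferences!(DFTK, "DFTK_threads"; force=true)
```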

iszero(chunk_length) && return storages
"""
Parallelize a loop, calling `fun(i)` for side effects for all `i` in `range`.
If `allocate_local_storage` is not `nothing`, `fun` is called as `fun(i, st)`, where
`st` is a thread-local temporary storage allocated by `allocate_local_storage()`.
"""
function parallel_loop_over_range(fun, range; allocate_local_storage=nothing)
nthreads = @load_preference("DFTK_threads", Threads.nthreads())
if !isnothing(allocate_local_storage)
storages = [allocate_local_storage() for _ = 1:nthreads]
else
storages = nothing
end
parallel_loop_over_range(fun, range, storages)
end
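A toy usage sketch for the two call forms (the accumulation example is invented; the helper is internal to DFTK, so it is qualified explicitly here):

```julia
using DFTK: parallel_loop_over_range

# With thread-local storage: each task accumulates into its own Ref and the
# caller reduces over the returned storages afterwards.
storages = parallel_loop_over_range(1:100;
                                    allocate_local_storage=() -> Ref(0)) do i, st
    st[] += i^2
end
sum(st[] for st in storages) == sum(abs2, 1:100)   # true

# Without storage, fun(i) is called purely for its side effects, e.g. writing
# into slot i of a preallocated output array.
out = zeros(100)
parallel_loop_over_range(1:100) do i
    out[i] = sqrt(i)
end
```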
# Private interface, called by the public method above.
function parallel_loop_over_range(fun, range, storages)
    # When storages === nothing, fall back to the DFTK_threads preference for the thread count.
    if isnothing(storages)
        nthreads = @load_preference("DFTK_threads", Threads.nthreads())
    else
        nthreads = length(storages)
    end
    chunk_length = cld(length(range), nthreads)

@sync for (ichunk, chunk) in enumerate(Iterators.partition(range, chunk_length))
Threads.@spawn for idc in chunk # spawn a task per chunk
fun(storages[ichunk], idc)
    # This tensorized if is ugly, but this is potentially
    # performance critical and factoring it out is more trouble
    # than it's worth.
    if nthreads == 1
        for i in range
            if isnothing(storages)
                fun(i)
            else
                fun(i, storages[1])
            end
        end
    elseif length(range) == 0
        # do nothing
    else
        @sync for (ichunk, chunk) in enumerate(Iterators.partition(range, chunk_length))
            Threads.@spawn for i in chunk  # spawn a task per chunk
                if isnothing(storages)
                    fun(i)
                else
                    fun(i, storages[ichunk])
                end
            end
        end
    end

    return storages
end
function parallel_loop_over_range(fun, allocate_local_storage::Function, range)
storages = [allocate_local_storage() for _ = 1:Threads.nthreads()]
parallel_loop_over_range(fun, storages, range)
end
4 changes: 2 additions & 2 deletions src/densities.jl
@@ -26,7 +26,7 @@ using an optional `occupation_threshold`. By default all occupation numbers are
# We split the total iteration range (ik, n) in chunks, and parallelize over them.
range = [(ik, n) for ik = 1:length(basis.kpoints) for n = mask_occ[ik]]

storages = parallel_loop_over_range(allocate_local_storage, range) do storage, kn
storages = parallel_loop_over_range(range; allocate_local_storage) do kn, storage
(ik, n) = kn
kpt = basis.kpoints[ik]
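Abridged sketch of the accumulate-then-reduce pattern this hunk uses: every task receives its own density buffer from `allocate_local_storage`, and the per-thread buffers are summed after the loop (all names and sizes below are stand-ins, not code from the file):

```julia
using DFTK: parallel_loop_over_range

fft_size = (16, 16, 16)
kn_range = [(ik, n) for ik in 1:4 for n in 1:8]      # stand-in (k-point, band) pairs
allocate_local_storage() = (; ρ=zeros(Float64, fft_size...))

storages = parallel_loop_over_range(kn_range; allocate_local_storage) do kn, storage
    (ik, n) = kn
    storage.ρ .+= 1.0          # stand-in for the |ψ_{n,k}|² contribution
end
ρ = sum(st.ρ for st in storages)
```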

@@ -76,7 +76,7 @@ end
# |ψ_{n,k}|² is 2 ψ_{n,k} * δψ_{n,k+q}.
# Hence, we first get the δψ_{[k+q]} as δψ_{k+q}…
δψ_plus_k = transfer_blochwave_equivalent_to_actual(basis, δψ, q)
storages = parallel_loop_over_range(allocate_local_storage, range) do storage, kn
storages = parallel_loop_over_range(range; allocate_local_storage) do kn, storage
(ik, n) = kn

kpt = basis.kpoints[ik]
2 changes: 1 addition & 1 deletion src/eigen/lobpcg_hyper_impl.jl
@@ -250,7 +250,7 @@ end
# Find X that is orthogonal, and B-orthogonal to Y, up to a tolerance tol.
@timing "ortho! X vs Y" function ortho!(X::AbstractArray{T}, Y, BY; tol=2eps(real(T))) where {T}
# normalize to try to cheaply improve conditioning
Threads.@threads for i=1:size(X,2)
parallel_loop_over_range(1:size(X, 2)) do i
n = norm(@views X[:,i])
@views X[:,i] ./= n
end
4 changes: 2 additions & 2 deletions src/eigen/preconditioners.jl
@@ -50,7 +50,7 @@ end
if P.mean_kin === nothing
ldiv!(Y, Diagonal(P.kin .+ P.default_shift), R)
else
Threads.@threads for n = 1:size(Y, 2)
parallel_loop_over_range(1:size(Y, 2)) do n
Y[:, n] .= P.mean_kin[n] ./ (P.mean_kin[n] .+ P.kin) .* R[:, n]
end
end
@@ -64,7 +64,7 @@ ldiv!(P::PreconditionerTPA, R) = ldiv!(R, P, R)
if P.mean_kin === nothing
mul!(Y, Diagonal(P.kin .+ P.default_shift), R)
else
Threads.@threads for n = 1:size(Y, 2)
parallel_loop_over_range(1:size(Y, 2)) do n
Y[:, n] .= (P.mean_kin[n] .+ P.kin) ./ P.mean_kin[n] .* R[:, n]
end
end
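The `ldiv!` and `mul!` branches above apply elementwise factors that are exact reciprocals of each other, as a small standalone check illustrates (the arrays stand in for `P.kin`, `P.mean_kin` and `R`):

```julia
kin      = rand(100)           # stand-in for kinetic energies of the basis vectors
mean_kin = [2.0, 3.5]          # stand-in for per-band mean kinetic energies
R        = rand(100, 2)
Y        = similar(R)

for n in 1:size(R, 2)          # the ldiv! branch: apply P⁻¹
    Y[:, n] .= mean_kin[n] ./ (mean_kin[n] .+ kin) .* R[:, n]
end
for n in 1:size(Y, 2)          # the mul! branch: apply P
    Y[:, n] .= (mean_kin[n] .+ kin) ./ mean_kin[n] .* Y[:, n]
end
Y ≈ R                          # true: the two factors cancel
```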
4 changes: 2 additions & 2 deletions src/terms/Hamiltonian.jl
@@ -102,7 +102,7 @@ Base.:*(H::Hamiltonian, ψ) = mul!(deepcopy(ψ), H, ψ)
ψ_real = similar(ψ, complex(T), H.basis.fft_size...),
Hψ_real = similar(Hψ, complex(T), H.basis.fft_size...))
end
parallel_loop_over_range(allocate_local_storage, 1:size(ψ, 2)) do storage, iband
parallel_loop_over_range(1:size(ψ, 2); allocate_local_storage) do iband, storage
to = TimerOutput() # Thread-local timer output

# Take ψi, IFFT it to ψ_real, apply each term to Hψ_fourier and Hψ_real, and add it
Expand Down Expand Up @@ -140,7 +140,7 @@ end
# Notice that we use unnormalized plans for extra speed
potential = H.local_op.potential / prod(H.basis.fft_size)

parallel_loop_over_range(H.scratch, 1:n_bands) do storage, iband
parallel_loop_over_range(1:n_bands, H.scratch) do iband, storage
to = TimerOutput() # Thread-local timer output
ψ_real = storage.ψ_reals
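A standalone illustration of why dividing the potential by `prod(fft_size)` compensates for the unnormalized plans; FFTW's unnormalized backward transform scales by the array length:

```julia
using FFTW

x = rand(ComplexF64, 8)
N = length(x)
V = rand(8)                        # stand-in real-space potential

bfft(fft(x)) ≈ N .* x              # true: unnormalized round trip gains a factor N

z_ref  = ifft(V .* fft(x))         # normalized inverse transform as reference
z_fast = bfft((V ./ N) .* fft(x))  # unnormalized plans with the rescaled potential
z_ref ≈ z_fast                     # true: absorbing 1/N into V restores the scaling
```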
