Skip to content

Commit c43c82b

Browse files
authored
add a user define option for threshold in bregman (#3)
* add a user define option for threshold in bregman * have a custom thresholding function for first iter * update options for lambda function * sorry about slightly changed api * dispatch, make kwargs into options * doesnt change API * pre-process at bregmanparams * clean up tests and documentation * TD is at the end for funobj bregman * do defaults * fixed all * don't need to be in same type
1 parent 3f2822e commit c43c82b

File tree

6 files changed

+109
-101
lines changed

6 files changed

+109
-101
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "SlimOptim"
22
uuid = "e4c7bc62-5b23-4522-a1b9-71c2be45f1df"
33
authors = ["Mathias Louboutin <[email protected]>"]
4-
version = "0.1.7"
4+
version = "0.1.8"
55

66
[deps]
77
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
10+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

1213
[compat]
13-
julia = "1"
1414
LineSearches = "7.1.1"
15+
julia = "1"

examples/denoising.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ n = 256
1111
k = 4
1212
# Sparse in wavelet domain
1313
W = joDWT(n, n; DDT=Float32, RDT=Float32)
14-
# Or with curvelet ifi nstalled
14+
# Or with curvelet if installed
1515
# W = joCurvelet2D(128, 128; DDT=Float32, RDT=Float32)
1616
A = vcat([joRomberg(n, n; DDT=Float32, RDT=Float32) for i=1:k]...)
1717

@@ -20,11 +20,11 @@ imgn= img .+ .01f0*randn(Float32, size(img))
2020
b = A*vec(imgn)
2121

2222
# setup bregman
23-
opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true)
24-
opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true)
23+
opt = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, TD=W)
24+
opt2 = bregman_options(maxIter=200, verbose=2, quantile=.5, alpha=1, antichatter=true, spg=true, TD=W)
2525

26-
sol = bregman(A, W, zeros(Float32, n*n), b, opt)
27-
sol2 = bregman(A, W, zeros(Float32, n*n), b, opt2)
26+
sol = bregman(A, zeros(Float32, n*n), b, opt)
27+
sol2 = bregman(A, zeros(Float32, n*n), b, opt2)
2828

2929
figure()
3030
subplot(121)

examples/lsrtm_marmousi.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ function breg_obj(x)
5252
return .5f0*norm(r)^2, g[1:end]
5353
end
5454

55-
opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true)#, spg=true)
56-
sol = bregman(breg_obj, 0f0.*vec(m0), opt, C)
55+
opt = bregman_options(maxIter=5, verbose=2, quantile=.9, alpha=1, antichatter=true, TD=C)#, spg=true)
56+
sol = bregman(breg_obj, 0f0.*vec(m0), opt)

src/SlimOptim.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
module SlimOptim
55

6-
using Printf, LinearAlgebra, LineSearches
6+
using Printf, LinearAlgebra, LineSearches, Statistics
77

88
import LineSearches: BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe
99
export BackTracking, HagerZhang, Static, MoreThuente, StrongWolfe

src/bregman.jl

+63-60
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ mutable struct BregmanParams
77
maxIter
88
store_trace
99
antichatter
10-
quantile
1110
alpha
1211
spg
12+
TD
13+
λfunc
1314
end
1415

1516
"""
16-
bregman_options(;verbose=1, optTol=1e-6, progTol=1e-8, maxIter=20
17-
store_trace=false, linesearch=false, alpha=.25, spg=false)
17+
bregman_options(;verbose=1, progTol=1e-8, maxIter=20,
18+
store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing)
1819
1920
Options structure for the bregman iteration algorithm
2021
@@ -25,69 +26,87 @@ Options structure for the bregman iteration algorithm
2526
- `maxIter`: maximum number of iterations (default: 20)
2627
- `store_trace`: Whether to store the trace/history of x (default: false)
2728
- `antichatter`: Whether to use anti-chatter step length correction
28-
- `quantile`: Thresholding level as quantile value, (default=.95 i.e thresholds 95% of the vector)
2929
- `alpha`: Strong convexity modulus. (step length is ``α \\frac{||r||_2^2}{||g||_2^2}``)
30+
- `spg`: whether to use spg, default is false
31+
- `TD`: sparsifying transform (e.g. curvelet), default is identity (LinearAlgebra.I)
32+
- `λfunc`: a function to calculate threshold value, default is nothing
33+
- `λ`: a pre-set threshold, will only be used if `λfunc` is not defined, default is nothing
34+
- `quantile`: a percentage to calculate the threshold by quantile of the dual variable in 1st iteration, will only be used if neither `λfunc` nor `λ` are defined, default is .95 i.e thresholds 95% of the vector
3035
3136
"""
32-
bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, quantile=.95, alpha=.5, spg=false) =
33-
BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, quantile, alpha, spg)
37+
function bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing)
38+
if isnothing(λfunc)
39+
if ~isnothing(λ)
40+
λfunc = z->λ
41+
else
42+
λfunc = z->Statistics.quantile(abs.(z), quantile)
43+
end
44+
end
45+
return BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, alpha, spg, TD, λfunc)
46+
end
3447

3548
"""
36-
bregman(A, TD, x, b, options)
49+
bregman(A, x, b, options)
3750
3851
Linearized bregman iteration for the system
3952
4053
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
4154
4255
For example, for sparsity promoting denoising (i.e LSRTM)
4356
44-
# Arguments
57+
# Required arguments
4558
46-
- `TD`: curvelet transform
47-
- `A`: Forward operator (J or preconditioned J for LSRTM)
48-
- `b`: observed data
59+
- `A`: Forward operator (e.g. J or preconditioned J for LSRTM)
4960
- `x`: Initial guess
61+
- `b`: observed data
62+
63+
# Non-required arguments
64+
65+
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
5066
"""
51-
function bregman(A, TD, x::Array{T}, b, options) where {T}
67+
function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
5268
# residual function wrapper
5369
function obj(x)
5470
d = A*x
5571
fun = .5*norm(d - b)^2
5672
grad = A'*(d - b)
5773
return fun, grad
5874
end
59-
60-
return bregman(obj, x, options, TD)
75+
return bregman(obj, x, options)
76+
end
77+
78+
function bregman(A, TD, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options()) where {T1<:Number, T2<:Number}
79+
@warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams"
80+
options.TD = TD
81+
return bregman(A, x, b, options)
6182
end
6283

6384
"""
64-
bregman(fun, TD, x, b, options)
85+
bregman(funobj, x, options)
6586
6687
Linearized bregman iteration for the system
6788
6889
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
6990
70-
For example, for sparsity promoting denoising (i.e LSRTM)
91+
# Required arguments
7192
72-
# Arguments
73-
74-
- `TD`: curvelet transform
75-
- `fun`: residual function, return the tuple (``f = \\frac{1}{2}||Ax - b||_2``, ``g = A^T(Ax - b)``)
76-
- `b`: observed data
93+
- `funobj`: a function that calculates the objective value (`0.5 * norm(Ax-b)^2`) and the gradient (`A'(Ax-b)`)
7794
- `x`: Initial guess
7895
96+
# Non-required arguments
97+
98+
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
7999
"""
80-
function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams, TD=nothing) where {T}
100+
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options()) where {T}
81101
# Output Parameter Settings
82102
if options.verbose > 0
83103
@printf("Running linearized bregman...\n");
84104
@printf("Progress tolerance: %.2e\n",options.progTol)
85105
@printf("Maximum number of iterations: %d\n",options.maxIter)
86106
@printf("Anti-chatter correction: %d\n",options.antichatter)
87107
end
88-
isnothing(TD) && (TD = LinearAlgebra.I)
89-
# Intitalize variables
90-
z = TD*x
108+
# Initialize variables
109+
z = options.TD*x
91110
d = similar(z)
92111
options.spg && (gold = similar(x); xold=similar(x))
93112
if options.antichatter
@@ -96,8 +115,6 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams,
96115

97116
# Result structure
98117
sol = breglog(x, z)
99-
# Initialize λ
100-
λ = abs(T(0))
101118

102119
# Output Log
103120
if options.verbose > 0
@@ -108,60 +125,45 @@ function bregman(funobj::Function, x::AbstractArray{T}, options::BregmanParams,
108125
for i=1:options.maxIter
109126
f, g = funobj(x)
110127
# Preconditioned update direction
111-
d .= -TD*g
128+
d .= -options.TD*g
112129
# Step length
113130
t = (options.spg && i> 1) ? T(dot(x-xold, x-xold)/dot(x-xold, g-gold)) : T(options.alpha*f/norm(d)^2)
114131
t = abs(t)
115132
mul!(d, d, t)
116133

117134
# Anti-chatter
118135
if options.antichatter
136+
@assert isreal(z) "we currently do not support anti-chatter for complex numbers"
119137
@. tk = tk - sign(d)
120-
# Chatter correction
121-
inds_z = findall(abs.(z) .> λ)
122-
@views d[inds_z] .*= abs.(tk[inds_z])/i
138+
# Chatter correction after 1st iteration
139+
if i > 1
140+
inds_z = findall(abs.(z) .> sol.λ)
141+
@views d[inds_z] .*= abs.(tk[inds_z])/i
142+
end
123143
end
124144
# Update z variable
125145
@. z = z + d
126146
# Get λ at first iteration
127-
i == 1 && (λ = abs(T(quantile(abs.(z), options.quantile))))
147+
(i == 1) && (sol.λ = abs.(T.(options.λfunc(z))))
128148
# Save current state
129149
options.spg && (gold .= g; xold .= x)
130150
# Update x
131-
x = TD'*soft_thresholding(z, λ)
151+
x = options.TD'*soft_thresholding(z, sol.λ)
132152

133-
obj_fun = λ * norm(z, 1) + .5 * norm(z, 2)^2
134-
if options.verbose > 0
135-
@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, λ)
136-
end
153+
obj_fun = norm(sol.λ .* z, 1) + .5 * norm(z, 2)^2
154+
(options.verbose > 0) && (@printf("%10d %15.5e %15.5e %15.5e %15.5e \n",i, t, obj_fun, f, maximum(sol.λ)))
137155
norm(x - sol.x) < options.progTol && (@printf("Step size below progTol\n"); break;)
138156
update!(sol; iter=i, ϕ=obj_fun, residual=f, x=x, z=z, g=g, store_trace=options.store_trace)
139157
end
140158
return sol
141159
end
142160

143-
# Utility functions
144-
"""
145-
Simplified Quantile from Statistics.jl since we only need simplified version of it.
146-
"""
147-
function quantile(u::AbstractVector, p::Real)
148-
0 <= p <= 1 || throw(ArgumentError("input probability out of [0,1] range"))
149-
n = length(u)
150-
v = sort(u; alg=Base.QuickSort)
151-
152-
m = 1 - p
153-
aleph = n*p + oftype(p, m)
154-
j = clamp(trunc(Int, aleph), 1, n-1)
155-
γ = clamp(aleph - j, 0, 1)
156-
157-
n == 1 ? a = v[1] : a = v[j]
158-
n == 1 ? b = v[1] : b = v[j+1]
159-
160-
(isfinite(a) && isfinite(b)) ? q = a + γ*(b-a) : q = (1-γ)*a + γ*b
161-
return q
161+
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams, TD) where {T}
162+
@warn "deprecation warning: please put TD in options (BregmanParams) for version > 0.1.7; now overwritting TD in BregmanParams"
163+
options.TD = TD
164+
return bregman(funobj, x, options)
162165
end
163166

164-
165167
"""
166168
Bregman result structure
167169
"""
@@ -170,6 +172,7 @@ mutable struct BregmanIterations
170172
z
171173
g
172174
ϕ
175+
λ
173176
residual
174177
ϕ_trace
175178
r_trace
@@ -189,6 +192,6 @@ function update!(r::BregmanIterations; x=nothing, z=nothing, ϕ=nothing, residua
189192
(~isnothing(residual) && length(r.r_trace) == iter-1) && (push!(r.r_trace, residual))
190193
end
191194

192-
function breglog(init_x, init_z; f0=0, obj0=0)
193-
return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}())
194-
end
195+
function breglog(init_x, init_z; lambda0=0, f0=0, obj0=0)
196+
return BregmanIterations(1*init_x, 1*init_z, 0*init_z, f0, lambda0, obj0, Vector{}(), Vector{}(), Vector{}(), Vector{}())
197+
end

test/test_bregman.jl

+35-31
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,38 @@ using LinearAlgebra
55

66
N1 = 100
77
N2 = div(N1, 2) + 5
8-
A = randn(N1, N2)
9-
10-
x0 = 10 .* randn(N2)
11-
x0[abs.(x0) .< 1f-6] .= 1.0
12-
inds = rand(1:N2, div(N2, 4))
13-
ninds = [i for i=1:N2 if i ∉ inds]
14-
x0[inds] .= 0
15-
b = A*x0
16-
17-
function obj(x)
18-
fun = .5*norm(A*x - b)^2
19-
grad = A'*(A*x - b)
20-
return fun, grad
21-
end
22-
23-
opt = bregman_options(maxIter=200, progTol=0, verbose=2)
24-
sol = bregman(obj, 1 .+ randn(N2), opt)
25-
26-
@show sol.x[inds]
27-
@show x0[inds]
28-
@show sol.x[ninds]
29-
@show x0[ninds]
30-
31-
part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float64))
32-
part_nz = i -> norm(sol.x[i], 1)/N2
33-
@show part_nz(inds)
34-
@show part_n(ninds)
35-
36-
@test part_nz(inds) < 1f-1
37-
@test part_n(ninds) < 1f-1
38-
@test sol.residual/sol.r_trace[1] < 1f-1
8+
9+
@testset "Bregman test for type $(T)" for T = [Float32, ComplexF32]
10+
11+
A = randn(T, N1, N2)
12+
x0 = 10 .* randn(T, N2)
13+
x0[abs.(x0) .< 1f-6] .= 1.0
14+
inds = rand(1:N2, div(N2, 4))
15+
ninds = [i for i=1:N2 if i ∉ inds]
16+
x0[inds] .= 0
17+
b = A*x0
18+
19+
function obj(x)
20+
fun = .5*norm(A*x - b)^2
21+
grad = A'*(A*x - b)
22+
return fun, grad
23+
end
24+
25+
opt = bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32)
26+
sol = bregman(obj, 1 .+ randn(T, N2), opt)
27+
28+
@show sol.x[inds]
29+
@show x0[inds]
30+
@show sol.x[ninds]
31+
@show x0[ninds]
32+
33+
part_n = i -> norm(sol.x[i] - x0[i])/(norm(x0[i]) + norm(sol.x[i]) + eps(Float32))
34+
part_nz = i -> norm(sol.x[i], 1)/N2
35+
@show part_nz(inds)
36+
@show part_n(ninds)
37+
38+
@test part_nz(inds) < 1f-1
39+
@test part_n(ninds) < 1f-1
40+
@test sol.residual/sol.r_trace[1] < 1f-1
41+
42+
end

0 commit comments

Comments
 (0)