Skip to content

Commit 2ff3077

Browse files
committed
re-enable but don't require macro expansion to make @cse(@binarize(expr)) work
1 parent 372195e commit 2ff3077

File tree

2 files changed

+39
-29
lines changed

2 files changed

+39
-29
lines changed

src/CommonSubexpressions.jl

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ disqualify!(cache::Cache, s::Symbol) = push!(cache.disqualified_symbols, s)
2525
disqualify!(cache::Cache, expr::Expr) = foreach(arg -> disqualify!(cache, arg), expr.args)
2626

2727
# fallback for non-Expr arguments
28-
combine_subexprs!(setup, x, warn_enabled::Bool) = x
28+
combine_subexprs!(setup, x; warn=true, mod=nothing) = x
2929

3030
const standard_expression_forms = Set{Symbol}(
3131
(:call,
@@ -50,37 +50,42 @@ const assignment_expression_forms = Set{Symbol}(
5050
:(*=),
5151
:(/=)))
5252

53-
function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool)
53+
function combine_subexprs!(cache::Cache, expr::Expr;
54+
warn::Bool=true, mod::Union{Module, Nothing}=nothing)
5455
if expr.head == :macrocall
55-
# We don't recursively expand other macros, but we can perform CSE on
56-
# the expression inside the macro call.
57-
for i in 2:length(expr.args)
58-
expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled)
56+
if (mod === nothing)
57+
error("""
58+
`cse` cannot expand macro calls unless you explicitly pass in
59+
a `Module` in which to perform that expansion. You can pass
60+
`mod=@__MODULE__` to expand in the current module, or you can use
61+
the `@cse` macro which handles this automatically.""")
5962
end
63+
return combine_subexprs!(cache, macroexpand(mod, expr);
64+
warn=warn, mod=mod)
6065
elseif expr.head == :function
6166
# We can't continue CSE through a function definition, but we can
6267
# start over inside the body of the function:
6368
for i in 2:length(expr.args)
64-
expr.args[i] = combine_subexprs!(expr.args[i], warn_enabled)
69+
expr.args[i] = combine_subexprs!(expr.args[i]; warn=warn, mod=mod)
6570
end
6671
elseif expr.head == :line
6772
# nothing
6873
elseif expr.head in assignment_expression_forms
6974
disqualify!(cache, expr.args[1])
7075
for i in 2:length(expr.args)
71-
expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled)
76+
expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod)
7277
end
7378
elseif expr.head == :generator
7479
for i in vcat(2:length(expr.args), 1)
75-
expr.args[i] = combine_subexprs!(cache, expr.args[i], warn_enabled)
80+
expr.args[i] = combine_subexprs!(cache, expr.args[i]; warn=warn, mod=mod)
7681
end
7782
elseif expr.head in standard_expression_forms
7883
for (i, child) in enumerate(expr.args)
79-
expr.args[i] = combine_subexprs!(cache, child, warn_enabled)
84+
expr.args[i] = combine_subexprs!(cache, child; warn=warn, mod=mod)
8085
end
8186
if expr.head == :call
8287
for (i, child) in enumerate(expr.args)
83-
expr.args[i] = combine_subexprs!(cache, child, warn_enabled)
88+
expr.args[i] = combine_subexprs!(cache, child, warn=warn, mod=mod)
8489
end
8590
if all(!isa(arg, Expr) && !(arg in cache.disqualified_symbols) for arg in drop(expr.args, 1))
8691
combined_args = Symbol(expr.args...)
@@ -94,21 +99,23 @@ function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool)
9499
end
95100
end
96101
else
97-
warn_enabled && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)")
102+
warn && @warn("CommonSubexpressions can't yet handle expressions of this form: $(expr.head)")
98103
end
99104
return expr
100105
end
101106

102-
combine_subexprs!(x, warn_enabled::Bool = true) = x
107+
combine_subexprs!(x; warn=true, mod=nothing) = x
103108

104-
function combine_subexprs!(expr::Expr, warn_enabled::Bool)
109+
function combine_subexprs!(expr::Expr; warn=true, mod=nothing)
105110
cache = Cache()
106-
expr = combine_subexprs!(cache, expr, warn_enabled)
111+
expr = combine_subexprs!(cache, expr; warn=warn, mod=mod)
107112
Expr(:block, cache.setup..., expr)
108113
end
109114

110115
function parse_cse_args(args)
111-
params = Dict(:warn => true, :binarize => false)
116+
# Overly complicated way to look for `warn=true` or `warn=false`,
117+
# but should be easier to expand for other arguments later.
118+
params = Dict(:warn => true)
112119
for (i, arg) in enumerate(args)
113120
if @capture(arg, key_Symbol = val_Bool)
114121
if key in keys(params)
@@ -117,42 +124,45 @@ function parse_cse_args(args)
117124
error("Unrecognized key: $key")
118125
end
119126
elseif i == 1 && arg isa Bool
120-
Base.depwarn("The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead", :cse_kwargs)
127+
Base.depwarn("The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead", :cse_macro_kwargs)
121128
else
122-
error("Unrecognized argument: $arg. Expected `warn=<bool>` or `binarize=<bool>`")
129+
error("Unrecognized argument: $arg. Expected `warn=true` or `warn=false`")
123130

124131
end
125132
end
126133
params
127134
end
128135

129136
"""
130-
@cse(expr; warn=true, binarize=false)
137+
@cse(expr; warn=true)
131138
132139
Perform naive common subexpression elimination under the assumption
133140
that all functions called withing the body of the macro are pure,
134141
meaning that they have no side effects. See [Readme.md](https://github.com/rdeits/CommonSubexpressions.jl/blob/master/Readme.md)
135142
for more details.
136143
137-
This macro will not recursively expand macro calls within the resulting expression.
144+
This macro will recursively expand macro calls within the expression before
145+
performing subexpression elimination. A useful macro to combine with this is
146+
`@binarize`, which will turn n-ary function calls into nested binary calls and
147+
can therefore provide more opportunities for subexpression elimination. Usage:
148+
149+
@cse(@binarize(<your code here>))
138150
139151
If the macro encounters an expression which it does not know how to handle,
140152
it will return that expression unmodified. If `warn=true`, then it
141153
will also log a warning in that event.
142-
143-
If `binarize=true` is given, then all n-ary expressions will be recursively
144-
converted into nested binary expressions. See `@binarize` for more information.
145154
"""
146155
macro cse(expr, args...)
147156
params = parse_cse_args(args)
148-
if params[:binarize]
149-
expr = binarize(expr)
150-
end
151-
result = combine_subexprs!(expr, params[:warn])
157+
result = combine_subexprs!(expr, warn=params[:warn], mod=__module__)
152158
esc(result)
153159
end
154160

155-
cse(expr, warn_enabled::Bool = true) = combine_subexprs!(copy(expr), warn_enabled)
161+
Base.@deprecate cse(expr, warn_enabled::Bool) cse(expr, warn=warn_enabled)
162+
163+
function cse(expr; warn::Bool=true, mod::Union{Module, Nothing}=nothing)
164+
combine_subexprs!(copy(expr); warn=warn, mod=mod)
165+
end
156166

157167
function _binarize(expr::Expr)
158168
if @capture(expr, f_(a_, b_, c_, args__))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ module NestedMacroTest
226226
@test special_plus_calls[] == 2
227227

228228
special_plus_calls[] = 0
229-
@test(@cse(@special_math((1 + 2 + 3) + (1 + 2 + 4) + (1 + 2 + 5)), binarize=true) == 21)
229+
@test(@cse(@binarize(@special_math((1 + 2 + 3) + (1 + 2 + 4) + (1 + 2 + 5)))) == 21)
230230
# Test that the duplicate calls to `1 + 2` were eliminated
231231
@test special_plus_calls[] == 6
232232
end

0 commit comments

Comments
 (0)