@@ -25,7 +25,7 @@ disqualify!(cache::Cache, s::Symbol) = push!(cache.disqualified_symbols, s)
25
25
disqualify! (cache:: Cache , expr:: Expr ) = foreach (arg -> disqualify! (cache, arg), expr. args)
26
26
27
27
# fallback for non-Expr arguments
28
- combine_subexprs! (setup, x, warn_enabled :: Bool ) = x
28
+ combine_subexprs! (setup, x; warn = true , mod = nothing ) = x
29
29
30
30
const standard_expression_forms = Set {Symbol} (
31
31
(:call ,
@@ -50,37 +50,42 @@ const assignment_expression_forms = Set{Symbol}(
50
50
:(*= ),
51
51
:(/= )))
52
52
53
- function combine_subexprs! (cache:: Cache , expr:: Expr , warn_enabled:: Bool )
53
+ function combine_subexprs! (cache:: Cache , expr:: Expr ;
54
+ warn:: Bool = true , mod:: Union{Module, Nothing} = nothing )
54
55
if expr. head == :macrocall
55
- # We don't recursively expand other macros, but we can perform CSE on
56
- # the expression inside the macro call.
57
- for i in 2 : length (expr. args)
58
- expr. args[i] = combine_subexprs! (expr. args[i], warn_enabled)
56
+ if (mod === nothing )
57
+ error ("""
58
+ `cse` cannot expand macro calls unless you explicitly pass in
59
+ a `Module` in which to perform that expansion. You can pass
60
+ `mod=@__MODULE__` to expand in the current module, or you can use
61
+ the `@cse` macro which handles this automatically.""" )
59
62
end
63
+ return combine_subexprs! (cache, macroexpand (mod, expr);
64
+ warn= warn, mod= mod)
60
65
elseif expr. head == :function
61
66
# We can't continue CSE through a function definition, but we can
62
67
# start over inside the body of the function:
63
68
for i in 2 : length (expr. args)
64
- expr. args[i] = combine_subexprs! (expr. args[i], warn_enabled )
69
+ expr. args[i] = combine_subexprs! (expr. args[i]; warn = warn, mod = mod )
65
70
end
66
71
elseif expr. head == :line
67
72
# nothing
68
73
elseif expr. head in assignment_expression_forms
69
74
disqualify! (cache, expr. args[1 ])
70
75
for i in 2 : length (expr. args)
71
- expr. args[i] = combine_subexprs! (cache, expr. args[i], warn_enabled )
76
+ expr. args[i] = combine_subexprs! (cache, expr. args[i]; warn = warn, mod = mod )
72
77
end
73
78
elseif expr. head == :generator
74
79
for i in vcat (2 : length (expr. args), 1 )
75
- expr. args[i] = combine_subexprs! (cache, expr. args[i], warn_enabled )
80
+ expr. args[i] = combine_subexprs! (cache, expr. args[i]; warn = warn, mod = mod )
76
81
end
77
82
elseif expr. head in standard_expression_forms
78
83
for (i, child) in enumerate (expr. args)
79
- expr. args[i] = combine_subexprs! (cache, child, warn_enabled )
84
+ expr. args[i] = combine_subexprs! (cache, child; warn = warn, mod = mod )
80
85
end
81
86
if expr. head == :call
82
87
for (i, child) in enumerate (expr. args)
83
- expr. args[i] = combine_subexprs! (cache, child, warn_enabled )
88
+ expr. args[i] = combine_subexprs! (cache, child, warn = warn, mod = mod )
84
89
end
85
90
if all (! isa (arg, Expr) && ! (arg in cache. disqualified_symbols) for arg in drop (expr. args, 1 ))
86
91
combined_args = Symbol (expr. args... )
@@ -94,21 +99,23 @@ function combine_subexprs!(cache::Cache, expr::Expr, warn_enabled::Bool)
94
99
end
95
100
end
96
101
else
97
- warn_enabled && @warn (" CommonSubexpressions can't yet handle expressions of this form: $(expr. head) " )
102
+ warn && @warn (" CommonSubexpressions can't yet handle expressions of this form: $(expr. head) " )
98
103
end
99
104
return expr
100
105
end
101
106
102
- combine_subexprs! (x, warn_enabled :: Bool = true ) = x
107
+ combine_subexprs! (x; warn = true , mod = nothing ) = x
103
108
104
- function combine_subexprs! (expr:: Expr , warn_enabled :: Bool )
109
+ function combine_subexprs! (expr:: Expr ; warn = true , mod = nothing )
105
110
cache = Cache ()
106
- expr = combine_subexprs! (cache, expr, warn_enabled )
111
+ expr = combine_subexprs! (cache, expr; warn = warn, mod = mod )
107
112
Expr (:block , cache. setup... , expr)
108
113
end
109
114
110
115
function parse_cse_args (args)
111
- params = Dict (:warn => true , :binarize => false )
116
+ # Overly complicated way to look for `warn=true` or `warn=false`,
117
+ # but should be easier to expand for other arguments later.
118
+ params = Dict (:warn => true )
112
119
for (i, arg) in enumerate (args)
113
120
if @capture (arg, key_Symbol = val_Bool)
114
121
if key in keys (params)
@@ -117,42 +124,45 @@ function parse_cse_args(args)
117
124
error (" Unrecognized key: $key " )
118
125
end
119
126
elseif i == 1 && arg isa Bool
120
- Base. depwarn (" The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead" , :cse_kwargs )
127
+ Base. depwarn (" The `warn_enabled` positional argument is deprecated. Please use `warn=true` or `warn=false` instead" , :cse_macro_kwargs )
121
128
else
122
- error (" Unrecognized argument: $arg . Expected `warn=<bool> ` or `binarize=<bool> `" )
129
+ error (" Unrecognized argument: $arg . Expected `warn=true ` or `warn=false `" )
123
130
124
131
end
125
132
end
126
133
params
127
134
end
128
135
129
136
"""
130
- @cse(expr; warn=true, binarize=false )
137
+ @cse(expr; warn=true)
131
138
132
139
Perform naive common subexpression elimination under the assumption
133
140
that all functions called withing the body of the macro are pure,
134
141
meaning that they have no side effects. See [Readme.md](https://github.com/rdeits/CommonSubexpressions.jl/blob/master/Readme.md)
135
142
for more details.
136
143
137
- This macro will not recursively expand macro calls within the resulting expression.
144
+ This macro will recursively expand macro calls within the expression before
145
+ performing subexpression elimination. A useful macro to combine with this is
146
+ `@binarize`, which will turn n-ary function calls into nested binary calls and
147
+ can therefore provide more opportunities for subexpression elimination. Usage:
148
+
149
+ @cse(@binarize(<your code here>))
138
150
139
151
If the macro encounters an expression which it does not know how to handle,
140
152
it will return that expression unmodified. If `warn=true`, then it
141
153
will also log a warning in that event.
142
-
143
- If `binarize=true` is given, then all n-ary expressions will be recursively
144
- converted into nested binary expressions. See `@binarize` for more information.
145
154
"""
146
155
macro cse (expr, args... )
147
156
params = parse_cse_args (args)
148
- if params[:binarize ]
149
- expr = binarize (expr)
150
- end
151
- result = combine_subexprs! (expr, params[:warn ])
157
+ result = combine_subexprs! (expr, warn= params[:warn ], mod= __module__)
152
158
esc (result)
153
159
end
154
160
155
- cse (expr, warn_enabled:: Bool = true ) = combine_subexprs! (copy (expr), warn_enabled)
161
+ Base. @deprecate cse (expr, warn_enabled:: Bool ) cse (expr, warn= warn_enabled)
162
+
163
+ function cse (expr; warn:: Bool = true , mod:: Union{Module, Nothing} = nothing )
164
+ combine_subexprs! (copy (expr); warn= warn, mod= mod)
165
+ end
156
166
157
167
function _binarize (expr:: Expr )
158
168
if @capture (expr, f_ (a_, b_, c_, args__))
0 commit comments