@@ -327,22 +327,18 @@ some cases.
327
327
"""
328
328
function flatten (bc:: Broadcasted{Style} ) where {Style}
329
329
isflat (bc) && return bc
330
- # concatenate the nested arguments into {a, b, c, d}
330
+ # 1. concatenate the nested arguments into {a, b, c, d}
331
331
args = cat_nested (bc)
332
- # build a function `makeargs` that takes a "flat" argument list and
332
+ # 2. build a `Tuple` of functions `makeargs`
333
+ # The element of `makeargs` takes a "flat" argument list and
333
334
# and creates the appropriate input arguments for `f`, e.g.,
334
- # makeargs = (w, x, y, z) -> (w, g(x, y), z)
335
- #
336
- # `makeargs` is built recursively and looks a bit like this:
337
- # makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
338
- # = (w, g(x, y), makeargs2(z)...)
339
- # = (w, g(x, y), z)
340
- let makeargs = make_makeargs (()-> (), bc. args), f = bc. f
341
- newf = @inline function (args:: Vararg{Any,N} ) where N
342
- f (makeargs (args... )... )
343
- end
344
- return Broadcasted {Style} (newf, args, bc. axes)
345
- end
335
+ # makeargs[1] = ((w, x, y, z)) -> w
336
+ # makeargs[2] = ((w, x, y, z)) -> g(x, y)
337
+ # makeargs[2] = ((w, x, y, z)) -> z
338
+ makeargs, _ = make_makeargs (bc. args, 1 )
339
+ f = callable (bc. f)
340
+ @inline newf (args... ) = f (prepare_args (makeargs, args)... )
341
+ return Broadcasted {Style} (newf, args, bc. axes)
346
342
end
347
343
348
344
const NestedTuple = Tuple{<: Broadcasted ,Vararg{Any}}
@@ -351,78 +347,48 @@ _isflat(args::NestedTuple) = false
351
347
_isflat (args:: Tuple ) = _isflat (tail (args))
352
348
_isflat (args:: Tuple{} ) = true
353
349
354
- cat_nested (t:: Broadcasted , rest... ) = (cat_nested (t. args... )... , cat_nested (rest... )... )
355
- cat_nested (t:: Any , rest... ) = (t, cat_nested (rest... )... )
356
- cat_nested () = ()
350
+ cat_nested (bc:: Broadcasted ) = cat_nested (bc. args)
351
+ cat_nested (:: Tuple{} ) = ()
352
+ cat_nested (t:: Tuple{Any} ) = cat_nested1 (t[1 ])
353
+ cat_nested (t:: Tuple ) = (cat_nested1 (t[1 ])... , cat_nested (tail (t))... )
354
+ cat_nested1 (a) = (a,)
355
+ cat_nested1 (bc:: Broadcasted ) = cat_nested (bc. args)
356
+
357
+ struct Pick{N} <: Function end
358
+ (:: Pick{N} )(@nospecialize (args:: Tuple )) where {N} = args[N]
359
+
360
+ struct Callable{F} <: Function end
361
+ @inline (:: Callable{F} )(args... ) where {F} = F (args... )
362
+ callable (:: Type{F} ) where {F} = Callable {F} ()
363
+ callable (f) = f
357
364
358
365
"""
359
- make_makeargs(makeargs_tail::Function, t::Tuple ) -> Function
366
+ make_makeargs(t::Tuple, n::Int ) -> Tuple{Vararg{ Function}}
360
367
361
368
Each element of `t` is one (consecutive) node in a broadcast tree.
362
- Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
363
- to return a function that takes in flattened argument list and returns a
364
- tuple (each entry corresponding to an entry in `t`, having evaluated
365
- the corresponding element in the broadcast tree). As an additional
366
- complication, the passed in tuple may be longer than the number of leaves
367
- in the subtree described by `t`. The `makeargs_tail` function should
368
- be called on such additional arguments (but not the arguments consumed
369
- by `t`).
369
+ `n` denotes the position of the first argument needed by `t[1]`
370
+ in the flattened list. The returned `Tuple` are functions which
371
+ take in the (whole) flattened list and generate the inputs of the
372
+ corresponding broadcasted function.
370
373
"""
371
- @inline make_makeargs (makeargs_tail, t:: Tuple{} ) = makeargs_tail
372
- @inline function make_makeargs (makeargs_tail, t:: Tuple )
373
- makeargs = make_makeargs (makeargs_tail, tail (t))
374
- (head, tail... )-> (head, makeargs (tail... )... )
375
- end
376
- function make_makeargs (makeargs_tail, t:: Tuple{<:Broadcasted, Vararg{Any}} )
377
- bc = t[1 ]
378
- # c.f. the same expression in the function on leaf nodes above. Here
379
- # we recurse into siblings in the broadcast tree.
380
- let makeargs_tail = make_makeargs (makeargs_tail, tail (t)),
381
- # Here we recurse into children. It would be valid to pass in makeargs_tail
382
- # here, and not use it below. However, in that case, our recursion is no
383
- # longer purely structural because we're building up one argument (the closure)
384
- # while destructuing another.
385
- makeargs_head = make_makeargs ((args... )-> args, bc. args),
386
- f = bc. f
387
- # Create two functions, one that splits of the first length(bc.args)
388
- # elements from the tuple and one that yields the remaining arguments.
389
- # N.B. We can't call headargs on `args...` directly because
390
- # args is flattened (i.e. our children have not been evaluated
391
- # yet).
392
- headargs, tailargs = make_headargs (bc. args), make_tailargs (bc. args)
393
- return @inline function (args:: Vararg{Any,N} ) where N
394
- args1 = makeargs_head (args... )
395
- a, b = headargs (args1... ), makeargs_tail (tailargs (args1... )... )
396
- (f (a... ), b... )
397
- end
398
- end
374
+ @inline function make_makeargs (args:: Tuple , n:: Int )
375
+ head, n = make_makeargs1 (args[1 ], n)
376
+ rest, n = make_makeargs (tail (args), n)
377
+ (head, rest... ), n
399
378
end
379
+ make_makeargs (:: Tuple{} , n:: Int ) = (), n
400
380
401
- @inline function make_headargs (t:: Tuple )
402
- let headargs = make_headargs (tail (t))
403
- return @inline function (head, tail:: Vararg{Any,N} ) where N
404
- (head, headargs (tail... )... )
405
- end
406
- end
407
- end
408
- @inline function make_headargs (:: Tuple{} )
409
- return @inline function (tail:: Vararg{Any,N} ) where N
410
- ()
411
- end
381
+ @inline function make_makeargs1 (bc:: Broadcasted , n:: Int )
382
+ makeargs, n = make_makeargs (bc. args, n)
383
+ f = callable (bc. f)
384
+ @inline newf (args) = f (prepare_args (makeargs, args)... )
385
+ newf, n
412
386
end
387
+ @inline make_makeargs1 (_, n:: Int ) = Pick {n} (), n + 1
413
388
414
- @inline function make_tailargs (t:: Tuple )
415
- let tailargs = make_tailargs (tail (t))
416
- return @inline function (head, tail:: Vararg{Any,N} ) where N
417
- tailargs (tail... )
418
- end
419
- end
420
- end
421
- @inline function make_tailargs (:: Tuple{} )
422
- return @inline function (tail:: Vararg{Any,N} ) where N
423
- tail
424
- end
425
- end
389
+ @inline prepare_args (pf:: Tuple , @nospecialize (x:: Tuple )) = (pf[1 ](x), prepare_args (tail (pf), x)... )
390
+ @inline prepare_args (pf:: Tuple{Any} , @nospecialize (x:: Tuple )) = (pf[1 ](x),)
391
+ prepare_args (:: Tuple{} , :: Tuple ) = ()
426
392
427
393
# # Broadcasting utilities ##
428
394
0 commit comments