@@ -465,7 +465,7 @@ function rrule(
465465 y = first (last (hobbits))
466466 project = ProjectTo (x)
467467 function foldl_pullback_tuple (dy)
468- trio = accumulate (_reverse1 (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
468+ trio = accumulate (reverse (hobbits); init= (0 , dy, 0 )) do (_, dc, _), (_, back)
469469 ds, da, db = back (dc)
470470 # Don't need to store every `da`, need one for the next iteration + the last.
471471 end
@@ -501,78 +501,43 @@ end
501501
502502# The implementation was originally for both tuples and arrays, although using accumulate
503503# to carry intermediate results along creates arrays of tuples which could be avoided.
504- # Using a loop can be a few times faster, this should be replaced.
505- # Note also that it does not return a gradient for `init`.
504+ # Using a loop can be a few times faster, this should be replaced:
505+ # https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305
506+
507+ # Note also that it does not return a gradient for `init`, now marked `@not_implemented`.
506508
# Reverse-mode rule for a left fold, attached to `Base.mapfoldl_impl` with `identity` as
# the mapped function.
#
# Forward pass: `accumulate` carries `(value, pullback)` pairs (the "hobbits") along `x`,
# calling `rrule_via_ad(config, op, a, b)` once per step. The fold result `y` is the value
# stored in the last pair.
#
# Reverse pass (`unfoldl`): walks the stored pullbacks in reverse order, feeding the
# gradient of the carried value into the previous step, and collects the gradients with
# respect to the elements of `x`.
#
# Returns `(y, unfoldl)`. The pullback returns tangents for
# `(mapfoldl_impl, identity, op, init, x)`; the gradient for `init` is `@not_implemented`.
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::typeof(identity),
    op::G,
    init,
    x::Union{AbstractArray,Tuple},
) where {G}
    # Without an explicit `init`, the first element of `x` seeds the fold.
    no_init = init === Base._InitialValue()
    start, list = if no_init
        Iterators.peel(x)
    else
        # Case with init keyword is simpler to understand first!
        init, x
    end
    # Each entry is `(c, back)`: the running value and the pullback of the step that made it.
    hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
        c, back = rrule_via_ad(config, op, a, b)
    end
    # `hobbits` is empty when `op` was never called (a single element without `init`, or an
    # empty `x` with `init`); the fold result is then `start` itself.
    y = isempty(hobbits) ? start : first(last(hobbits))
    axe = axes(x)
    project = ProjectTo(x)
    function unfoldl(dy)
        d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
        if isempty(hobbits)
            # No call to `op` happened, so there is no pullback to run. Without `init` the
            # only element of `x` is the output, and it receives `dy` directly.
            dx = no_init ? project(reshape([unthunk(dy)], axe)) : ZeroTangent()
            return (NoTangent(), NoTangent(), ZeroTangent(), d_init, dx)
        end
        # Each entry is `(ds, da, db)`; only `da` is carried into the next (earlier) step.
        trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
            ds, da, db = back(dc)
        end
        dop = sum(first, trio)
        dx = map(last, Iterators.reverse(trio))
        if no_init
            # `hobbits` is one short: the first element's gradient is the final `da`.
            dx = _vcat1(trio[end][2], dx)
        end
        return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe)))
    end
    return y, unfoldl
end
543538
544-
545- # ####
546- # #### Iterator-or-Tuple functions
547- # ####
548-
549- # This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
550- # and also provides some alternatives for versions of Julia where iterators weren't supported.
551- # Inspired by `Base._reverse`, used in defn of `foldr`.
552-
553- # To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
554- # be replaced by _peel1 like Iterators.peel
555-
556- _reverse1 (x) = Iterators. reverse (x)
557- _drop1 (x) = Iterators. drop (x, 1 )
558- _zip2 (x, y) = zip (x, y) # for `accumulate`, below
559-
560- _reverse1 (x:: Tuple ) = reverse (x)
561- _drop1 (x:: Tuple ) = Base. tail (x)
562- _zip2 (x:: Tuple{Vararg{Any,N}} , y:: Tuple{Vararg{Any,N}} ) where N = ntuple (i -> (x[i],y[i]), N)
563-
564- const _INIT = Base. _InitialValue ()
565-
# Prepend a single element `x` to the vector `ys`.
function _vcat1(x, ys::AbstractVector)
    return vcat(x, ys)
end

# When the element is itself an array, wrap it first so that it is added as ONE entry
# instead of being concatenated element-wise.
function _vcat1(x::AbstractArray, ys::AbstractVector)
    wrapped = [x]
    return vcat(wrapped, ys)
end
568- _vcat1 (x, ys:: Tuple ) = (x, ys... )
569-
570- _reshape1 (x:: AbstractArray , axe) = reshape (x, axe)
571- _reshape1 (x:: Tuple , axe) = x
572-
573- _no_tuple_tangent (dx:: Tangent ) = ChainRulesCore. backing (dx)
574- _no_tuple_tangent (dx) = dx
575-
576541
577542# ####
578543# #### `accumulate`
@@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx
584549# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`
585550
586551function rrule (
587- config:: RuleConfig{>:HasReverseMode} , :: typeof (Base. _accumulate!), op:: G , y, x:: AbstractVector , dims:: Nothing , init,
552+ config:: RuleConfig{>:HasReverseMode} ,
553+ :: typeof (Base. _accumulate!),
554+ op:: G , y:: AbstractVector ,
555+ x:: AbstractVector ,
556+ dims:: Nothing ,
557+ init,
588558 ) where {G}
589559
590- list, start = if init === nothing
591- _drop1 (x), first (x)
560+ start, list = if init === nothing
561+ Iterators . peel (x)
592562 else
593- x, something (init)
563+ something (init), x
594564 end
595565 hobbits = accumulate (list; init = (start, nothing )) do (a, _), b
596566 c, back = rrule_via_ad (config, op, a, b)
@@ -607,28 +577,24 @@ function rrule(
607577 axe = axes (x)
608578 project = ProjectTo (x)
# Pullback for the `accumulate` rule. `dy` is the cotangent of the accumulated output.
# Walks the stored `(value, pullback)` pairs in `hobbits` backwards; at every step the
# gradient carried from later steps (`dc`) is added to that step's own output gradient
# (`dz`) before calling the step's pullback.
# Returns tangents for `(_accumulate!, op, y, x, dims, init)`.
function decumulate(dy)
    dy_plain = unthunk(dy)
    # Here we rely on `zip` to stop early when `init === nothing` (`hobbits` is then one
    # shorter than `dy_plain`). Being explicit with
    # `Iterators.reverse(Iterators.drop(..., 1))` gets
    # "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
    rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain))
    trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
        ds, da, db = back(dc + dz)
        # Don't need to store every `da`, but need it for the next iteration, and the last one.
    end
    dop = sum(first, trio)
    dx = map(last, Iterators.reverse(trio))
    if init === nothing
        # `hobbits` is one short, and the first one is weird: the first element of `x`
        # is also the first output, so it receives that output's gradient too.
        dx = _vcat1(trio[end][2] + first(dy_plain), dx)
    end
    # Separate local instead of reassigning the argument `dy`.
    dB = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
    d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
    d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value=d_init_not)
    return (NoTangent(), dop, dB, project(reshape(dx, axe)), NoTangent(), d_init)
end
633- return _reshape1 (y, axe), decumulate
599+ return reshape (y, axe), decumulate
634600end
0 commit comments