Skip to content

Commit

Permalink
work around SR-13945 segfault (tensorflow#1141)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcrasi authored Dec 9, 2020
1 parent eaa1c8e commit e8686f8
Showing 1 changed file with 53 additions and 19 deletions.
72 changes: 53 additions & 19 deletions Sources/TensorFlow/Layers/Recurrent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -623,26 +623,10 @@ where Cell.TimeStepOutput: Mergeable {
) -> Output {
let forwardOutputs = forward(
inputs, initialState: initialForwardLayerState)

// TODO: Replace with inputs.reversed() after it becomes differentiable.
var inputsReversed = Input()

for forwardIndex in 0 ..< withoutDerivative(at: inputs.count) {
let backwardIndex = withoutDerivative(at: inputs.count - 1 - forwardIndex)
inputsReversed.append(inputs[backwardIndex])
}

let backwardOutputs = backward(
inputsReversed, initialState: initialBackwardLayerState)

var outputs = Output()

for forwardIndex in 0 ..< withoutDerivative(at: inputs.count) {
let backwardIndex = withoutDerivative(at: inputs.count - 1 - forwardIndex)
outputs.append(mergeFunction(forwardOutputs[forwardIndex], backwardOutputs[backwardIndex]))
}

return outputs
inputs.differentiableReversed(), initialState: initialBackwardLayerState)
return forwardOutputs.differentiableMerging(
backwardOutputs.differentiableReversed(), mergeFunction: mergeFunction)
}

@differentiable
Expand Down Expand Up @@ -703,3 +687,53 @@ public typealias SimpleRNNCell = BasicRNNCell

@available(*, deprecated, renamed: "BasicRNN")
public typealias SimpleRNN = BasicRNN

// MARK: - Workaround helpers.

// Array helpers with hand-written derivatives.
//
// Each helper is a plain array operation (`reversed`, `zip` + `map`) paired with an explicit
// `@derivative(of:)` function, so the compiler uses the registered pullback rather than
// differentiating the body itself. Both exist to work around the SR-13945 segfault.
fileprivate extension Array where Element: Differentiable {
/// Returns a reversed copy of `self`.
///
/// This has a custom derivative, which works around the SR-13945 segfault that you would
/// encounter if you tried to implement this at the callsite using a for loop.
@differentiable
func differentiableReversed() -> Self {
.init(self.reversed())
}

/// Custom derivative (VJP) of `differentiableReversed()`.
///
/// - Returns: The reversed array together with a pullback. The pullback reverses the
///   elements of the incoming tangent (`$0.base`), so the tangent for output position `i`
///   is routed back to the input position that produced it.
@derivative(of: differentiableReversed)
func vjpDifferentiableReversed()
-> (value: Self, pullback: (TangentVector) -> TangentVector)
{
return (self.differentiableReversed(), { .init(.init($0.base.reversed())) })
}

/// Returns `zip(self, other).map { mergeFunction($0.0, $0.1) }`.
///
/// This has a custom derivative, which works around the SR-13945 segfault that you would
/// encounter if you tried to implement this at the callsite using a for loop.
///
/// - Parameters:
///   - other: The array whose elements are paired position-by-position with those of `self`.
///   - mergeFunction: A differentiable function combining one element from each array.
/// - Returns: The merged elements. `zip` stops at the shorter of the two sequences, so the
///   result has as many elements as the shorter input.
///   NOTE(review): callers presumably always pass arrays of equal length -- confirm.
@differentiable
func differentiableMerging(
_ other: Self, mergeFunction: @differentiable (Element, Element) -> Element
) -> Self {
zip(self, other).map { mergeFunction($0.0, $0.1) }
}

/// Custom derivative (VJP) of `differentiableMerging(_:mergeFunction:)`.
///
/// - Returns: The merged array together with a pullback. The pullback yields two tangents,
///   the first for `self` and the second for `other`; no tangent is produced for
///   `mergeFunction`.
@derivative(of: differentiableMerging)
func vjpDifferentiableMerging(
_ other: Self, mergeFunction: @differentiable (Element, Element) -> Element
) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
// Evaluate `mergeFunction` once per element pair, capturing both its result and the
// pullback of that particular evaluation.
let valuesWithPullbacks = zip(self, other).map {
valueWithPullback(at: $0.0, $0.1, in: mergeFunction)
}
// Keep only the per-pair pullbacks for the returned closure to capture.
let pullbacks = valuesWithPullbacks.map { $0.pullback }
return (
valuesWithPullbacks.map { $0.value },
{ vs in
// Apply each stored pullback to the matching element of the incoming tangent; each
// call returns a pair (tangent for the `self` element, tangent for the `other` element).
let resultPairs = zip(vs.base, pullbacks).map { (v, pb) in
pb(v)
}
// Split the pairs into one tangent array per input.
return (.init(resultPairs.map { $0.0 }), .init(resultPairs.map { $0.1 }))
}
)
}
}

0 comments on commit e8686f8

Please sign in to comment.