diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 245335774..e48416739 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) end return y, logdet_pullback end + +##### +##### Tridiagonal +##### + +function rrule(::Type{Tridiagonal}, dl, d, du) + y = Tridiagonal(dl, d, du) + function Tridiagonal_pullback(ȳ) + ∂y = unthunk(ȳ) + return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1)) + end + return y, Tridiagonal_pullback +end diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 1b83cc394..c46460f46 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -161,4 +161,8 @@ end end end + + @testset "Tridiagonal" begin + test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0]) + end end