Skip to content

Commit 035e20a

Browse files
test: add test for broadcast autodiff
1 parent a8d416a commit 035e20a

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/adjoints.jl

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ function loss6(x)
3737
sum(abs2, Array(_prob.u0))
3838
end
3939

40+
function loss7(x)
41+
_x = VectorOfArray([x .* i for i in 1:5])
42+
return sum(abs2, x .- 1)
43+
end
44+
4045
x = float.(6:10)
4146
loss(x)
4247
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@@ -45,3 +50,4 @@ loss(x)
4550
@test Zygote.gradient(loss4, x)[1] == ForwardDiff.gradient(loss4, x)
4651
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
4752
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
53+
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)

0 commit comments

Comments
 (0)