Skip to content

Commit cea4ddf

Browse files
Fixes for #67 (#68)
* Fix identity * Bump patch version * Tightens testing + fixes bugs * Cleans up test Co-Authored-By: Lyndon White <[email protected]> * Documents behaviour. Co-Authored-By: Lyndon White <[email protected]> * Checks to_vec for loud failure * Uses oftype for conversion * Documents to_vec better Co-authored-by: Lyndon White <[email protected]>
1 parent faa5dd0 commit cea4ddf

File tree

6 files changed

+54
-25
lines changed

6 files changed

+54
-25
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.9.3"
3+
version = "0.9.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -12,6 +12,7 @@ julia = "1"
1212
[extras]
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1516

1617
[targets]
17-
test = ["Random", "Test"]
18+
test = ["Random", "StaticArrays", "Test"]

src/grad.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function jacobian(fdm, f, x::Vector{<:Number}; len=nothing)
1616
return fdm(zero(eltype(x))) do ε
1717
xn = x[n]
1818
x[n] = xn + ε
19-
ret = first(to_vec(f(x)))
19+
ret = copy(first(to_vec(f(x)))) # copy required incase `f(x)` returns something that aliases `x`
2020
x[n] = xn # Can't do `x[n] -= ϵ` as floating-point math is not associative
2121
return ret
2222
end

src/to_vec.jl

+25-19
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,38 @@ transformation.
66
"""
77
function to_vec(x::Number)
88
function Number_from_vec(x_vec)
9-
return first(x_vec)
9+
return oftype(x, first(x_vec))
1010
end
1111
return [x], Number_from_vec
1212
end
1313

14-
# Vectors
14+
# Base case -- if x is already a Vector{<:Number} there's no conversion necessary.
1515
to_vec(x::Vector{<:Number}) = (x, identity)
16-
function to_vec(x::Vector)
16+
17+
function to_vec(x::AbstractVector)
1718
x_vecs_and_backs = map(to_vec, x)
1819
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
1920
function Vector_from_vec(x_vec)
2021
sz = cumsum(map(length, x_vecs))
21-
return [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
22+
x_Vec = [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
23+
return oftype(x, x_Vec)
2224
end
2325
return vcat(x_vecs...), Vector_from_vec
2426
end
2527

26-
# Arrays
27-
function to_vec(x::Array{<:Number})
28-
function Array_from_vec(x_vec)
29-
return reshape(x_vec, size(x))
30-
end
31-
return vec(x), Array_from_vec
32-
end
28+
function to_vec(x::AbstractArray)
29+
30+
x_vec, from_vec = to_vec(vec(x))
3331

34-
function to_vec(x::Array)
35-
x_vec, back = to_vec(reshape(x, :))
3632
function Array_from_vec(x_vec)
37-
return reshape(back(x_vec), size(x))
33+
return oftype(x, reshape(from_vec(x_vec), size(x)))
3834
end
35+
3936
return x_vec, Array_from_vec
4037
end
4138

42-
# AbstractArrays
39+
# Some specific subtypes of AbstractArray.
40+
4341
function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular}
4442
x_vec, back = to_vec(Matrix(x))
4543
function AbstractTriangular_from_vec(x_vec)
@@ -63,17 +61,25 @@ function to_vec(X::Diagonal)
6361
end
6462

6563
function to_vec(X::Transpose)
64+
65+
x_vec, x_from_vec = to_vec(X.parent)
66+
6667
function Transpose_from_vec(x_vec)
67-
return Transpose(permutedims(reshape(x_vec, size(X))))
68+
return Transpose(x_from_vec(x_vec))
6869
end
69-
return vec(Matrix(X)), Transpose_from_vec
70+
71+
return x_vec, Transpose_from_vec
7072
end
7173

7274
function to_vec(X::Adjoint)
75+
76+
x_vec, x_from_vec = to_vec(X.parent)
77+
7378
function Adjoint_from_vec(x_vec)
74-
return Adjoint(conj!(permutedims(reshape(x_vec, size(X)))))
79+
return Adjoint(x_from_vec(x_vec))
7580
end
76-
return vec(Matrix(X)), Adjoint_from_vec
81+
82+
return x_vec, Adjoint_from_vec
7783
end
7884

7985
# Non-array data structures

test/grad.jl

+3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec
5252
@test Ac == A
5353
check_jac_and_jvp_and_j′vp(fdm, x->sin.(A * x), ȳ, x, ẋ, cos.(A * x) .* A)
5454
@test Ac == A
55+
56+
# Prevent regression against https://github.com/JuliaDiff/FiniteDifferences.jl/issues/67
57+
@test first(jacobian(fdm, identity, x)) one(Matrix{T}(undef, length(x), length(x)))
5558
end
5659

5760
@testset "multi vars jacobian/grad" begin

test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using FiniteDifferences, Test, Random, Printf, LinearAlgebra
1+
using FiniteDifferences, Test, Random, Printf, LinearAlgebra, StaticArrays
22

33
@testset "FiniteDifferences" begin
44
include("methods.jl")

test/to_vec.jl

+21-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,26 @@ end
1111
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
1212
Base.length(x::DummyType) = size(x.X, 1)
1313

14-
function test_to_vec(x)
14+
# A dummy FillVector. This is a type for which the fallback implementation of
15+
# `to_vec` should fail loudly.
16+
struct FillVector <: AbstractVector{Float64}
17+
x::Float64
18+
len::Int
19+
end
20+
21+
Base.size(x::FillVector) = (x.len,)
22+
Base.getindex(x::FillVector, n::Int) = x.x
23+
24+
function test_to_vec(x::T) where {T}
1525
x_vec, back = to_vec(x)
1626
@test x_vec isa Vector
1727
@test x == back(x_vec)
28+
@test back(x_vec) isa T
1829
return nothing
1930
end
2031

2132
@testset "to_vec" begin
22-
@testset "$T" for T in (Float64, ComplexF64)
33+
@testset "$T" for T in (Float32, ComplexF32, Float64, ComplexF64)
2334
if T == Float64
2435
test_to_vec(1.0)
2536
test_to_vec(1)
@@ -38,6 +49,8 @@ end
3849
test_to_vec(Symmetric(randn(T, 11, 11)))
3950
test_to_vec(Diagonal(randn(T, 7)))
4051
test_to_vec(DummyType(randn(T, 2, 9)))
52+
test_to_vec(SVector{2, T}(1.0, 2.0))
53+
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
4154

4255
@testset "$Op" for Op in (Adjoint, Transpose)
4356
test_to_vec(Op(randn(T, 4, 4)))
@@ -62,4 +75,10 @@ end
6275
end
6376
end
6477
end
78+
79+
@testset "FillVector" begin
80+
x = FillVector(5.0, 10)
81+
x_vec, from_vec = to_vec(x)
82+
@test_throws MethodError from_vec(randn(10))
83+
end
6584
end

0 commit comments

Comments
 (0)