1
1
const IndexType = Union{Int,Base. AbstractCartesianIndex}
2
-
3
- struct Partial{Order,T<: IndexType }
4
- indices:: NTuple{Order,T}
2
+ struct Partial{Order,T<: Tuple{Vararg{IndexType,Order}} }
3
+ indices:: T
5
4
end
6
5
7
- # TODO : this is not ideal... how does NTuple{0,Int} <: Tuple{} work??
8
- Partial () = Partial {0,Int} (())
9
- function Partial (indices:: Integer... )
10
- return Partial {length(indices),Int} (indices)
11
- end
12
- function Partial (indices:: Base.AbstractCartesianIndex... )
13
- return Partial {length(indices),Base.AbstractCartesianIndex} (indices)
6
+ Partial (:: Tuple{} ) = Partial {0,Tuple{}} (())
7
+ function Partial (indices:: Tuple{Vararg{T}} ) where {T<: IndexType }
8
+ Ord = length (indices)
9
+ return Partial {Ord,NTuple{Ord,T}} (indices)
14
10
end
15
- partial (indices... ) = Partial (indices... )
11
+ partial (indices:: Tuple{Vararg{T}} ) where {T<: IndexType } = Partial (indices)
12
+ partial (indices:: IndexType... ) = Partial (indices)
16
13
17
14
# # show helpers
18
15
@@ -22,18 +19,18 @@ lower_digits(idx::Base.AbstractCartesianIndex) = join(map(lower_digits, Tuple(id
22
19
# ## Fallbacks
23
20
compact_representation (p:: Partial ) = compact_representation (MIME " text/plain" (), p)
24
21
compact_representation (:: MIME , p:: Partial ) = compact_representation (p)
25
- detailed_representation (p:: Partial ) = """ : Partial ($(join (p. indices," ," )) )"""
26
- detailed_representation (p:: Partial{0} ) = """ : Partial () a zero order derivative"""
22
+ detailed_representation (p:: Partial ) = """ : partial ($(join (p. indices," ," )) )"""
23
+ detailed_representation (p:: Partial{0,Tuple{}} ) = """ : partial () a zero order derivative"""
27
24
28
25
# ## text/plain
29
- compact_representation (:: MIME"text/plain" , :: Partial{0} ) = " id"
26
+ compact_representation (:: MIME"text/plain" , :: Partial{0,Tuple{} } ) = " id"
30
27
function compact_representation (:: MIME"text/plain" , p:: Partial )
31
28
lower_numbers = map (lower_digits, p. indices)
32
29
return join ([" ∂$(x) " for x in lower_numbers])
33
30
end
34
31
35
32
# ## text/html
36
- compact_representation (:: MIME"text/html" , :: Partial{0} ) = """ <span class="text-muted" title="a zero order derivative">id</span>"""
33
+ compact_representation (:: MIME"text/html" , :: Partial{0,Tuple{} } ) = """ <span class="text-muted" title="a zero order derivative">id</span>"""
37
34
function compact_representation (:: MIME"text/html" , p:: Partial )
38
35
return join (map (n -> " ∂<sub>$(n) </sub>" , Tuple (p. indices)), " " )
39
36
end
54
51
55
52
const DiffPt{T} = Tuple{T,Partial}
56
53
57
- gradient (dim:: Integer ) = mappedarray (partial, Base. OneTo (dim))
58
- hessian (dim:: Integer ) = mappedarray (partial, productArray (Base. OneTo (dim), Base. OneTo (dim)))
59
- fullderivative (order:: Integer ,dim:: Integer ) = mappedarray (partial, productArray (ntuple (_-> Base. OneTo (dim), order)... ))
54
+ function fullderivative (:: Val{order} , input_indices:: AbstractVector{Int} ) where {order}
55
+ return mappedarray (partial, productArray (ntuple (_ -> input_indices, Val {order} ())... ))
56
+ end
57
+ fullderivative (:: Val{order} , dim:: Integer ) where {order} = fullderivative (Val {order} (), Base. OneTo (dim))
58
+ function fullderivative (:: Val{order} , input_indices:: AbstractArray{T,N} ) where {order,N,T<: Base.AbstractCartesianIndex{N} }
59
+ return mappedarray (partial, productArray (ntuple (_ -> input_indices, Val {order} ())... ))
60
+ end
61
+
62
+ gradient (input_indices:: AbstractArray ) = fullderivative (Val (1 ), input_indices)
63
+ gradient (dim:: Integer ) = fullderivative (Val (1 ), dim)
64
+
65
+ hessian (input_indices:: AbstractArray ) = fullderivative (Val (2 ), input_indices)
66
+ hessian (dim:: Integer ) = fullderivative (Val (2 ), dim)
60
67
61
68
# idea: lazy mappings can be undone (extract original range -> towards a specialization speedup of broadcasting over multiple derivatives using backwardsdiff)
62
- const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1 ,Int},1 ,T,typeof (partial)}
69
+ const MappedPartialVec{T} = ReadonlyMappedArray{Partial{1 ,Tuple{ Int} },1 ,T,typeof (partial)}
63
70
function extract_range (p_map:: MappedPartialVec{T} ) where {T<: AbstractUnitRange{Int} }
64
71
return p_map. data:: T
65
72
end
0 commit comments