Skip to content

Commit a37ae26

Browse files
Add HyperDualNumbersExt (#179)
* Add HyperDualNumbersExt * add test and dependency * minor fix on test * tests fix with proper testsets * fix hyperduals ext * bump version * some format fix
1 parent 00d50b3 commit a37ae26

File tree

5 files changed

+357
-9
lines changed

5 files changed

+357
-9
lines changed

Project.toml

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Chris Elrod", "Dilum Aluthge", "Mason Protter", "contributors"]
4-
version = "0.3.22"
4+
version = "0.3.23"
55

66
[deps]
77
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
910
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
1011
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1112
ManualMemory = "d125e4d3-2237-4719-b19c-fa641b8a4667"
@@ -16,9 +17,16 @@ StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
1617
ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
1718
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1819

20+
[weakdeps]
21+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
22+
23+
[extensions]
24+
ForwardDiffExt = "ForwardDiff"
25+
1926
[compat]
2027
CPUSummary = "0.1.26, 0.2.1"
2128
ForwardDiff = "0.10"
29+
HyperDualNumbers = "4"
2230
IfElse = "0.1"
2331
LoopVectorization = "0.12.86"
2432
ManualMemory = "0.1.1"
@@ -30,13 +38,11 @@ ThreadingUtilities = "0.5"
3038
VectorizationBase = "0.21.15"
3139
julia = "1.6"
3240

33-
[extensions]
34-
ForwardDiffExt = "ForwardDiff"
35-
3641
[extras]
3742
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3843
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3944
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
45+
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
4046
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
4147
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4248
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
@@ -45,7 +51,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4551
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
4652

4753
[targets]
48-
test = ["Aqua", "BenchmarkTools", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]
49-
50-
[weakdeps]
51-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
54+
test = ["Aqua", "BenchmarkTools", "ForwardDiff", "HyperDualNumbers", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]

ext/HyperDualNumbersExt.jl

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
module HyperDualNumbersExt
2+
3+
using HyperDualNumbers: Hyper
4+
using Octavian: ArrayInterface,
5+
@turbo, @tturbo,
6+
One, Zero,
7+
indices, static
8+
import Octavian: real_rep, _matmul!, _matmul_serial!
9+
10+
# Reinterpret an array of `Hyper{T}` elements as an array of `T` without copying.
# `reinterpret(reshape, T, a)` exposes the scalar components of each element along
# a new leading axis, so a matrix of `Hyper{T}` becomes a 3-d array of `T` whose
# first index selects the component.
real_rep(a::AbstractArray{DualT}) where {T,DualT<:Hyper{T}} =
  reinterpret(reshape, T, a)
12+
# `_view1` returns a non-copying view of the first slice along the leading axis of
# `B`: a vector for a matrix input, a matrix for a 3-d array input.
_view1(B::AbstractMatrix) = view(B, 1, :)
_view1(B::AbstractArray{<:Any,3}) = view(B, 1, :, :)
14+
15+
for AbstractVectorOrMatrix in (:AbstractVector, :AbstractMatrix)
16+
# multiplication of dual vector/matrix by standard matrix from the left
17+
# Threaded `_matmul!` method for a plain matrix `A` multiplying a vector/matrix
# `_B` of `Hyper{T}` elements, writing into `_C` (also `Hyper{T}` elements).
#
# `_B` and `_C` are viewed through `real_rep`, so the first index `l` selects a
# scalar component of each element. Every component is updated independently:
#   C[l, m, n] = α * sum_k A[m, k] * B[l, k, n] + β * C[l, m, n]
#
# `nthread`, `MKN` and `contig_axis` are accepted for signature compatibility and
# are not used by this method. Returns `_C`.
@eval function _matmul!(
  _C::$(AbstractVectorOrMatrix){DualT},
  A::AbstractMatrix,
  _B::$(AbstractVectorOrMatrix){DualT},
  α,
  β = Zero(),
  nthread::Nothing = nothing,
  MKN = nothing,
  contig_axis = nothing
) where {T,DualT<:Hyper{T}}
  B = real_rep(_B)
  C = real_rep(_C)

  @tturbo for n in indices((C, B), 3),
    m in indices((C, A), (2, 1)),
    l in indices((C, B), 1)

    Cₗₘₙ = zero(eltype(C))
    for k in indices((A, B), 2)
      Cₗₘₙ += A[m, k] * B[l, k, n]
    end
    C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
  end

  return _C
end
43+
44+
# multiplication of dual matrix by standard vector/matrix from the right
45+
# Threaded `_matmul!` method for a matrix `_A` of `Hyper{T}` elements multiplying
# a plain vector/matrix `B`, writing into `_C` (`Hyper{T}` elements).
#
# Fast path: when both `_C` and `_A` are dense and column major, they are
# reinterpreted as arrays of `T` (no reshape) and the call is forwarded to
# `_matmul!` on those real arrays. `MKN` is deliberately replaced by `nothing` in
# the forwarded call.
#
# Fallback: `_A` and `_C` are viewed through `real_rep` (leading index `l` selects
# the component) and each component is updated as
#   C[l, m, n] = α * sum_k A[l, m, k] * B[k, n] + β * C[l, m, n]
#
# Returns `_C`.
@eval @inline function _matmul!(
  _C::$(AbstractVectorOrMatrix){DualT},
  _A::AbstractMatrix{DualT},
  B::$(AbstractVectorOrMatrix),
  α = One(),
  β = Zero(),
  nthread::Nothing = nothing,
  MKN = nothing
) where {T,DualT<:Hyper{T}}
  if Bool(ArrayInterface.is_dense(_C)) &&
     Bool(ArrayInterface.is_column_major(_C)) &&
     Bool(ArrayInterface.is_dense(_A)) &&
     Bool(ArrayInterface.is_column_major(_A))
    # we can avoid the reshape and call the standard method
    A = reinterpret(T, _A)
    C = reinterpret(T, _C)
    _matmul!(C, A, B, α, β, nthread, nothing)
  else
    # we cannot use the standard method directly
    A = real_rep(_A)
    C = real_rep(_C)

    @tturbo for n in indices((C, B), (3, 2)),
      m in indices((C, A), 2),
      l in indices((C, A), 1)

      Cₗₘₙ = zero(eltype(C))
      for k in indices((A, B), (3, 1))
        Cₗₘₙ += A[l, m, k] * B[k, n]
      end
      C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
    end
  end

  return _C
end
81+
82+
# Threaded `_matmul!` method where `_A`, `_B` and `_C` all hold `Hyper{T}`
# elements. The product is assembled in three passes over the real views
# `A`, `B`, `C` obtained from `real_rep` (leading index = component):
#
#   1. every component of `A` times component 1 of `B`, scaled by `α`, plus
#      `β * C` (this is the only pass that applies `β`);
#   2. component 1 of `A` times components 2:4 of `B`, accumulated into
#      components 2:4 of `C`;
#   3. the cross term `A[2] * B[3] + A[3] * B[2]`, accumulated into component 4.
#
# NOTE(review): this matches hyper-dual multiplication if components 2 and 3 are
# the two first-order parts and component 4 is the mixed part — confirm against
# the field order of `Hyper`.
#
# In pass 1, if `_C` and `_A` are dense and column major, the work is forwarded
# to `_matmul!` on `reinterpret(T, ...)` arrays with the first component of `B`.
# `MKN` and `contig` are accepted but unused here. Returns `_C`.
@eval @inline function _matmul!(
  _C::$(AbstractVectorOrMatrix){DualT},
  _A::AbstractMatrix{DualT},
  _B::$(AbstractVectorOrMatrix){DualT},
  α = One(),
  β = Zero(),
  nthread::Nothing = nothing,
  MKN = nothing,
  contig = nothing
) where {T,DualT<:Hyper{T}}
  A = real_rep(_A)
  C = real_rep(_C)
  B = real_rep(_B)
  if Bool(ArrayInterface.is_dense(_C)) &&
     Bool(ArrayInterface.is_column_major(_C)) &&
     Bool(ArrayInterface.is_dense(_A)) &&
     Bool(ArrayInterface.is_column_major(_A))
    # we can avoid the reshape and call the standard method
    Ar = reinterpret(T, _A)
    Cr = reinterpret(T, _C)
    _matmul!(Cr, Ar, _view1(B), α, β, nthread, nothing)
  else
    # we cannot use the standard method directly
    @tturbo for n in indices((C, B), 3),
      m in indices((C, A), 2),
      l in indices((C, A), 1)

      Cₗₘₙ = zero(eltype(C))
      for k in indices((A, B), (3, 2))
        Cₗₘₙ += A[l, m, k] * B[1, k, n]
      end
      C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
    end
  end

  # component 1 of A times components 2:4 of B (β was already applied above)
  @tturbo for n in indices((B, C), 3), m in indices((A, C), 2), p in 1:3
    Cₚₘₙ = zero(eltype(C))
    for k in indices((A, B), (3, 2))
      Cₚₘₙ += A[1, m, k] * B[p+1, k, n]
    end
    C[p+1, m, n] = C[p+1, m, n] + α * Cₚₘₙ
  end

  # cross term between components 2 and 3, accumulated into component 4
  @tturbo for n in indices((B, C), 3), m in indices((A, C), 2)
    Cₘₙ = zero(eltype(C))
    for k in indices((A, B), (3, 2))
      Cₘₙ += A[2, m, k] * B[3, k, n] + A[3, m, k] * B[2, k, n]
    end
    C[4, m, n] = C[4, m, n] + α * Cₘₙ
  end

  return _C
end
133+
134+
# multiplication of dual vector/matrix by standard matrix from the left
135+
# Single-threaded (`@turbo`) counterpart of the real-`A` times hyper-`_B` product.
#
# `_B` and `_C` (both with `Hyper{T}` elements) are viewed through `real_rep`, so
# the leading index `l` selects a scalar component. Each component is updated as
#   C[l, m, n] = α * sum_k A[m, k] * B[l, k, n] + β * C[l, m, n]
#
# `MKN` is accepted for signature compatibility and is not used. Returns `_C`.
@eval function _matmul_serial!(
  _C::$(AbstractVectorOrMatrix){DualT},
  A::AbstractMatrix,
  _B::$(AbstractVectorOrMatrix){DualT},
  α,
  β,
  MKN
) where {T,DualT<:Hyper{T}}
  B = real_rep(_B)
  C = real_rep(_C)

  @turbo for n in indices((C, B), 3),
    m in indices((C, A), (2, 1)),
    l in indices((C, B), 1)

    Cₗₘₙ = zero(eltype(C))
    for k in indices((A, B), 2)
      Cₗₘₙ += A[m, k] * B[l, k, n]
    end
    C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
  end

  return _C
end
159+
160+
# multiplication of dual matrix by standard vector/matrix from the right
161+
# Single-threaded counterpart of the hyper-`_A` times real-`B` product, writing
# into `_C` (`Hyper{T}` elements).
#
# Fast path: when `_C` and `_A` are dense and column major, both are
# reinterpreted as arrays of `T` and the call is forwarded to `_matmul_serial!`
# on those real arrays (with `nothing` in place of `MKN`).
#
# Fallback: over the `real_rep` views (leading index `l` = component),
#   C[l, m, n] = α * sum_k A[l, m, k] * B[k, n] + β * C[l, m, n]
#
# Returns `_C`.
@eval @inline function _matmul_serial!(
  _C::$(AbstractVectorOrMatrix){DualT},
  _A::AbstractMatrix{DualT},
  B::$(AbstractVectorOrMatrix),
  α,
  β,
  MKN
) where {T,DualT<:Hyper{T}}
  if Bool(ArrayInterface.is_dense(_C)) &&
     Bool(ArrayInterface.is_column_major(_C)) &&
     Bool(ArrayInterface.is_dense(_A)) &&
     Bool(ArrayInterface.is_column_major(_A))
    # we can avoid the reshape and call the standard method
    A = reinterpret(T, _A)
    C = reinterpret(T, _C)
    _matmul_serial!(C, A, B, α, β, nothing)
  else
    # we cannot use the standard method directly
    A = real_rep(_A)
    C = real_rep(_C)

    @turbo for n in indices((C, B), (3, 2)),
      m in indices((C, A), 2),
      l in indices((C, A), 1)

      Cₗₘₙ = zero(eltype(C))
      for k in indices((A, B), (3, 1))
        Cₗₘₙ += A[l, m, k] * B[k, n]
      end
      C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
    end
  end

  return _C
end
196+
197+
# Single-threaded `_matmul_serial!` method where `_A`, `_B` and `_C` all hold
# `Hyper{T}` elements. The product is assembled in three passes over the real
# views `A`, `B`, `C` obtained from `real_rep` (leading index = component):
#
#   1. every component of `A` times component 1 of `B`, scaled by `α`, plus
#      `β * C` (this is the only pass that applies `β`);
#   2. component 1 of `A` times components 2:4 of `B`, accumulated into
#      components 2:4 of `C`;
#   3. the cross term `A[2] * B[3] + A[3] * B[2]`, accumulated into component 4.
#
# NOTE(review): this matches hyper-dual multiplication if components 2 and 3 are
# the two first-order parts and component 4 is the mixed part — confirm against
# the field order of `Hyper`.
#
# All three passes use `@turbo` (not the threaded `@tturbo`) so that the serial
# entry point never spawns work on other threads. `MKN` is unused. Returns `_C`.
@eval @inline function _matmul_serial!(
  _C::$(AbstractVectorOrMatrix){DualT},
  _A::AbstractMatrix{DualT},
  _B::$(AbstractVectorOrMatrix){DualT},
  α,
  β,
  MKN
) where {T,DualT<:Hyper{T}}
  A = real_rep(_A)
  C = real_rep(_C)
  B = real_rep(_B)
  if Bool(ArrayInterface.is_dense(_C)) &&
     Bool(ArrayInterface.is_column_major(_C)) &&
     Bool(ArrayInterface.is_dense(_A)) &&
     Bool(ArrayInterface.is_column_major(_A))
    # we can avoid the reshape and call the standard method
    Ar = reinterpret(T, _A)
    Cr = reinterpret(T, _C)
    _matmul_serial!(Cr, Ar, _view1(B), α, β, nothing)
  else
    # we cannot use the standard method directly
    @turbo for n in indices((C, B), 3),
      m in indices((C, A), 2),
      l in indices((C, A), 1)

      Cₗₘₙ = zero(eltype(C))
      for k in indices((A, B), (3, 2))
        Cₗₘₙ += A[l, m, k] * B[1, k, n]
      end
      C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
    end
  end

  # component 1 of A times components 2:4 of B (β was already applied above)
  @turbo for n in indices((B, C), 3), m in indices((A, C), 2), p in 1:3
    Cₚₘₙ = zero(eltype(C))
    for k in indices((A, B), (3, 2))
      Cₚₘₙ += A[1, m, k] * B[p+1, k, n]
    end
    C[p+1, m, n] = C[p+1, m, n] + α * Cₚₘₙ
  end

  # cross term between components 2 and 3, accumulated into component 4
  @turbo for n in indices((B, C), 3), m in indices((A, C), 2)
    Cₘₙ = zero(eltype(C))
    for k in indices((A, B), (3, 2))
      Cₘₙ += A[2, m, k] * B[3, k, n] + A[3, m, k] * B[2, k, n]
    end
    C[4, m, n] = C[4, m, n] + α * Cₘₙ
  end

  return _C
end
247+
end # for
248+
249+
end # module

src/Octavian.jl

+3
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ if !isdefined(Base, :get_extension)
7373
include("../ext/ForwardDiffExt.jl")
7474
end
7575

76+
# TODO: confirm when this extension is needed; it is currently included unconditionally, unlike the guarded ForwardDiffExt include above
77+
include("../ext/HyperDualNumbersExt.jl")
78+
7679
@static if VERSION >= v"1.8.0-beta1"
7780
@setup_workload begin
7881
# Putting some things in `setup` can reduce the size of the

0 commit comments

Comments
 (0)