1+ module HyperDualNumbersExt
2+
3+ using HyperDualNumbers: Hyper
4+ using Octavian: ArrayInterface,
5+ @turbo , @tturbo ,
6+ One, Zero,
7+ indices, static
8+ import Octavian: real_rep, _matmul!, _matmul_serial!
9+
# Expose an array of hyper-dual numbers as an array of its scalar parts.
# `reinterpret(reshape, T, a)` does not copy; the scalar components of each
# `Hyper{T}` element are addressed through a new leading dimension.
function real_rep(a::AbstractArray{DualT}) where {T,DualT<:Hyper{T}}
    return reinterpret(reshape, T, a)
end
# Non-copying slice at index 1 of the leading dimension.
# For a matrix this is a vector view, for a 3-d array a matrix view.
function _view1(B::AbstractMatrix)
    return view(B, 1, :)
end
function _view1(B::AbstractArray{<:Any,3})
    return view(B, 1, :, :)
end
14+
15+ for AbstractVectorOrMatrix in (:AbstractVector , :AbstractMatrix )
# Threaded product of a standard (non-dual) matrix `A` with a hyper-dual
# vector/matrix `_B` from the left, written in place:
#     _C = α * A * _B + β * _C
# `A` has no dual parts, so the same real product is applied independently to
# every scalar component, i.e. to every index `l` of the leading dimension
# of the reinterpreted arrays.
# `nthread`, `MKN` and `contig_axis` are accepted but not used in this method.
@eval function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    A::AbstractMatrix,
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
    contig_axis = nothing,
) where {T,DualT<:Hyper{T}}
    # scalar views: first index selects the hyper-dual component
    Cr = real_rep(_C)
    Br = real_rep(_B)

    @tturbo for n in indices((Cr, Br), 3),
        m in indices((Cr, A), (2, 1)),
        l in indices((Cr, Br), 1)

        acc = zero(eltype(Cr))
        for k in indices((A, Br), 2)
            acc += A[m, k] * Br[l, k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
43+
# Threaded product of a hyper-dual matrix `_A` with a standard (non-dual)
# vector/matrix `B` from the right, written in place:
#     _C = α * _A * B + β * _C
# `MKN` is accepted but not used in this method.
@eval @inline function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    B::$(AbstractVectorOrMatrix),
    α = One(),
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
) where {T,DualT<:Hyper{T}}
    contiguous =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    if contiguous
        # Dense column-major storage: reinterpreting to `T` (without reshape)
        # gives plain real arrays, so the standard real kernel can be reused.
        Aflat = reinterpret(T, _A)
        Cflat = reinterpret(T, _C)
        _matmul!(Cflat, Aflat, B, α, β, nthread, nothing)
    else
        # Generic layout: loop explicitly over the component index `l`.
        Ar = real_rep(_A)
        Cr = real_rep(_C)

        @tturbo for n in indices((Cr, B), (3, 2)),
            m in indices((Cr, Ar), 2),
            l in indices((Cr, Ar), 1)

            acc = zero(eltype(Cr))
            for k in indices((Ar, B), (3, 1))
                acc += Ar[l, m, k] * B[k, n]
            end
            Cr[l, m, n] = α * acc + β * Cr[l, m, n]
        end
    end

    return _C
end
81+
# Threaded product of two hyper-dual operands, written in place:
#     _C = α * _A * _B + β * _C
# The product is assembled in three passes over the scalar components
# (leading index of the reinterpreted arrays; presumably ordered as
# value, ε₁, ε₂, ε₁ε₂ — confirm against the `Hyper` field layout):
#   1. every component of A times component 1 of B (this pass applies β),
#   2. component 1 of A times components 2:4 of B,
#   3. the cross term A[2] * B[3] + A[3] * B[2] into component 4.
# `MKN` and `contig` are accepted but not used in this method.
@eval @inline function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    _B::$(AbstractVectorOrMatrix){DualT},
    α = One(),
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
    contig = nothing,
) where {T,DualT<:Hyper{T}}
    Ar = real_rep(_A)
    Cr = real_rep(_C)
    Br = real_rep(_B)

    contiguous =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    # Pass 1: all components of A against component 1 of B.
    if contiguous
        # Dense column-major storage: hand the flattened real arrays to the
        # standard real kernel.
        Aflat = reinterpret(T, _A)
        Cflat = reinterpret(T, _C)
        _matmul!(Cflat, Aflat, _view1(Br), α, β, nthread, nothing)
    else
        @tturbo for n in indices((Cr, Br), 3),
            m in indices((Cr, Ar), 2),
            l in indices((Cr, Ar), 1)

            acc = zero(eltype(Cr))
            for k in indices((Ar, Br), (3, 2))
                acc += Ar[l, m, k] * Br[1, k, n]
            end
            Cr[l, m, n] = α * acc + β * Cr[l, m, n]
        end
    end

    # Pass 2: component 1 of A against components 2:4 of B (accumulated).
    @tturbo for n in indices((Br, Cr), 3), m in indices((Ar, Cr), 2), p in 1:3
        acc = zero(eltype(Cr))
        for k in indices((Ar, Br), (3, 2))
            acc += Ar[1, m, k] * Br[p+1, k, n]
        end
        Cr[p+1, m, n] = Cr[p+1, m, n] + α * acc
    end

    # Pass 3: cross term of components 2 and 3, accumulated into component 4.
    @tturbo for n in indices((Br, Cr), 3), m in indices((Ar, Cr), 2)
        acc = zero(eltype(Cr))
        for k in indices((Ar, Br), (3, 2))
            acc += Ar[2, m, k] * Br[3, k, n] + Ar[3, m, k] * Br[2, k, n]
        end
        Cr[4, m, n] = Cr[4, m, n] + α * acc
    end

    return _C
end
133+
# Single-threaded (`@turbo`) product of a standard (non-dual) matrix `A` with a
# hyper-dual vector/matrix `_B` from the left, written in place:
#     _C = α * A * _B + β * _C
# `A` has no dual parts, so the same real product is applied independently to
# every scalar component (index `l` of the reinterpreted arrays).
# `MKN` is accepted but not used in this method.
@eval function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    A::AbstractMatrix,
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β,
    MKN,
) where {T,DualT<:Hyper{T}}
    # scalar views: first index selects the hyper-dual component
    Cr = real_rep(_C)
    Br = real_rep(_B)

    @turbo for n in indices((Cr, Br), 3),
        m in indices((Cr, A), (2, 1)),
        l in indices((Cr, Br), 1)

        acc = zero(eltype(Cr))
        for k in indices((A, Br), 2)
            acc += A[m, k] * Br[l, k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
159+
# Single-threaded product of a hyper-dual matrix `_A` with a standard
# (non-dual) vector/matrix `B` from the right, written in place:
#     _C = α * _A * B + β * _C
# `MKN` is accepted but not used in this method.
@eval @inline function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    B::$(AbstractVectorOrMatrix),
    α,
    β,
    MKN,
) where {T,DualT<:Hyper{T}}
    contiguous =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    if contiguous
        # Dense column-major storage: reinterpreting to `T` (without reshape)
        # gives plain real arrays, so the standard serial kernel can be reused.
        Aflat = reinterpret(T, _A)
        Cflat = reinterpret(T, _C)
        _matmul_serial!(Cflat, Aflat, B, α, β, nothing)
    else
        # Generic layout: loop explicitly over the component index `l`.
        Ar = real_rep(_A)
        Cr = real_rep(_C)

        @turbo for n in indices((Cr, B), (3, 2)),
            m in indices((Cr, Ar), 2),
            l in indices((Cr, Ar), 1)

            acc = zero(eltype(Cr))
            for k in indices((Ar, B), (3, 1))
                acc += Ar[l, m, k] * B[k, n]
            end
            Cr[l, m, n] = α * acc + β * Cr[l, m, n]
        end
    end

    return _C
end
196+
# Single-threaded product of two hyper-dual operands, written in place:
#     _C = α * _A * _B + β * _C
# The product is assembled in three passes over the scalar components
# (leading index of the reinterpreted arrays; presumably ordered as
# value, ε₁, ε₂, ε₁ε₂ — confirm against the `Hyper` field layout):
#   1. every component of A times component 1 of B (this pass applies β),
#   2. component 1 of A times components 2:4 of B,
#   3. the cross term A[2] * B[3] + A[3] * B[2] into component 4.
# Every loop uses `@turbo` (never `@tturbo`): this is the serial entry point
# and must not start threads.
# `MKN` is accepted but not used in this method.
@eval @inline function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β,
    MKN,
) where {T,DualT<:Hyper{T}}
    A = real_rep(_A)
    C = real_rep(_C)
    B = real_rep(_B)

    # Pass 1: all components of A against component 1 of B.
    if Bool(ArrayInterface.is_dense(_C)) &&
       Bool(ArrayInterface.is_column_major(_C)) &&
       Bool(ArrayInterface.is_dense(_A)) &&
       Bool(ArrayInterface.is_column_major(_A))
        # we can avoid the reshape and call the standard serial method
        Ar = reinterpret(T, _A)
        Cr = reinterpret(T, _C)
        _matmul_serial!(Cr, Ar, _view1(B), α, β, nothing)
    else
        # we cannot use the standard method directly
        @turbo for n ∈ indices((C, B), 3),
            m ∈ indices((C, A), 2),
            l in indices((C, A), 1)

            Cₗₘₙ = zero(eltype(C))
            for k ∈ indices((A, B), (3, 2))
                Cₗₘₙ += A[l, m, k] * B[1, k, n]
            end
            C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
        end
    end

    # Pass 2: component 1 of A against components 2:4 of B (accumulated).
    @turbo for n ∈ indices((B, C), 3), m ∈ indices((A, C), 2), p ∈ 1:3
        Cₚₘₙ = zero(eltype(C))
        for k ∈ indices((A, B), (3, 2))
            Cₚₘₙ += A[1, m, k] * B[p+1, k, n]
        end
        C[p+1, m, n] = C[p+1, m, n] + α * Cₚₘₙ
    end

    # Pass 3: cross term of components 2 and 3, accumulated into component 4.
    # Was `@tturbo` (threaded), which contradicted the serial contract.
    @turbo for n ∈ indices((B, C), 3), m ∈ indices((A, C), 2)
        Cₘₙ = zero(eltype(C))
        for k ∈ indices((A, B), (3, 2))
            Cₘₙ += A[2, m, k] * B[3, k, n] + A[3, m, k] * B[2, k, n]
        end
        C[4, m, n] = C[4, m, n] + α * Cₘₙ
    end

    return _C
end
247+ end # for
248+
249+ end # module
0 commit comments