1
+ module HyperDualNumbersExt
2
+
3
+ using HyperDualNumbers: Hyper
4
+ using Octavian: ArrayInterface,
5
+ @turbo , @tturbo ,
6
+ One, Zero,
7
+ indices, static
8
+ import Octavian: real_rep, _matmul!, _matmul_serial!
9
+
10
# Expose an array of `Hyper{T}` numbers as an array of their underlying `T` components.
# `reinterpret(reshape, T, a)` shares memory with `a`; the components of each element
# appear along a new leading dimension.
function real_rep(a::AbstractArray{DualT}) where {T,DualT<:Hyper{T}}
    return reinterpret(reshape, T, a)
end
12
# Non-copying view of the slice at index 1 of the leading dimension.
# For a matrix this is its first row; for a 3-d array it is the matrix `B[1, :, :]`.
function _view1(B::AbstractMatrix)
    return view(B, 1, :)
end

function _view1(B::AbstractArray{<:Any,3})
    return view(B, 1, :, :)
end
14
+
15
+ for AbstractVectorOrMatrix in (:AbstractVector , :AbstractMatrix )
16
+ # multiplication of dual vector/matrix by standard matrix from the left
17
# Kernel (using `@tturbo`) for `_C = α * A * _B + β * _C`, where `A` is a plain matrix
# and `_B`, `_C` hold `Hyper` numbers.
#
# `_B` and `_C` are viewed through `real_rep`, so their real components sit on the
# leading axis; the same product with `A` is applied to every component `l`.
# `nthread`, `MKN` and `contig_axis` are accepted but not used by this method.
@eval function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    A::AbstractMatrix,
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
    contig_axis = nothing
) where {T,DualT<:Hyper{T}}
    Cr = real_rep(_C)
    Br = real_rep(_B)

    @tturbo for n ∈ indices((Cr, Br), 3),
        m ∈ indices((Cr, A), (2, 1)),
        l ∈ indices((Cr, Br), 1)

        acc = zero(eltype(Cr))
        for k ∈ indices((A, Br), 2)
            acc += A[m, k] * Br[l, k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
43
+
44
+ # multiplication of dual matrix by standard vector/matrix from the right
45
# Kernel for `_C = α * _A * B + β * _C`, where `_A`, `_C` hold `Hyper` numbers and `B`
# is a plain vector/matrix.
#
# `MKN` is accepted but not used by this method.
@eval @inline function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    B::$(AbstractVectorOrMatrix),
    α = One(),
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing
) where {T,DualT<:Hyper{T}}
    flat_ok =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    if flat_ok
        # Dense column-major storage: reinterpret both arrays as plain `T` arrays
        # (no reshape) and hand the product to the standard method.
        _matmul!(reinterpret(T, _C), reinterpret(T, _A), B, α, β, nthread, nothing)
        return _C
    end

    # Otherwise loop explicitly over the component axis exposed by `real_rep`.
    Ar = real_rep(_A)
    Cr = real_rep(_C)

    @tturbo for n ∈ indices((Cr, B), (3, 2)),
        m ∈ indices((Cr, Ar), 2),
        l ∈ indices((Cr, Ar), 1)

        acc = zero(eltype(Cr))
        for k ∈ indices((Ar, B), (3, 1))
            acc += Ar[l, m, k] * B[k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
81
+
82
# Kernel for `_C = α * _A * _B + β * _C` where all three arrays hold `Hyper` numbers.
#
# The product is assembled in three passes over the component arrays from `real_rep`:
#   1. every component of `_A` times component 1 of `_B` (this pass also applies `β`);
#   2. component 1 of `_A` times components 2..4 of `_B`, accumulated into `_C`;
#   3. the cross term `A[2] * B[3] + A[3] * B[2]`, accumulated into component 4.
# NOTE(review): assumes the four components are stored as (value, ε₁, ε₂, ε₁ε₂) —
# confirm against the `Hyper` layout in HyperDualNumbers.
# `MKN` and `contig` are accepted but not used by this method.
@eval @inline function _matmul!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    _B::$(AbstractVectorOrMatrix){DualT},
    α = One(),
    β = Zero(),
    nthread::Nothing = nothing,
    MKN = nothing,
    contig = nothing
) where {T,DualT<:Hyper{T}}
    Ac = real_rep(_A)
    Cc = real_rep(_C)
    Bc = real_rep(_B)

    flat_ok =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    # Pass 1: all components of `_A` against the first component of `_B`.
    if flat_ok
        # Dense column-major storage: skip the reshape and use the standard method.
        Aflat = reinterpret(T, _A)
        Cflat = reinterpret(T, _C)
        _matmul!(Cflat, Aflat, _view1(Bc), α, β, nthread, nothing)
    else
        @tturbo for n ∈ indices((Cc, Bc), 3),
            m ∈ indices((Cc, Ac), 2),
            l ∈ indices((Cc, Ac), 1)

            acc = zero(eltype(Cc))
            for k ∈ indices((Ac, Bc), (3, 2))
                acc += Ac[l, m, k] * Bc[1, k, n]
            end
            Cc[l, m, n] = α * acc + β * Cc[l, m, n]
        end
    end

    # Pass 2: first component of `_A` against components 2..4 of `_B`.
    @tturbo for n ∈ indices((Bc, Cc), 3), m ∈ indices((Ac, Cc), 2), p ∈ 1:3
        acc = zero(eltype(Cc))
        for k ∈ indices((Ac, Bc), (3, 2))
            acc += Ac[1, m, k] * Bc[p+1, k, n]
        end
        Cc[p+1, m, n] = Cc[p+1, m, n] + α * acc
    end

    # Pass 3: cross term between components 2 and 3, stored in component 4.
    @tturbo for n ∈ indices((Bc, Cc), 3), m ∈ indices((Ac, Cc), 2)
        acc = zero(eltype(Cc))
        for k ∈ indices((Ac, Bc), (3, 2))
            acc += Ac[2, m, k] * Bc[3, k, n] + Ac[3, m, k] * Bc[2, k, n]
        end
        Cc[4, m, n] = Cc[4, m, n] + α * acc
    end

    return _C
end
133
+
134
+ # multiplication of dual vector/matrix by standard matrix from the left
135
# Kernel (using `@turbo`) for `_C = α * A * _B + β * _C`, where `A` is a plain matrix
# and `_B`, `_C` hold `Hyper` numbers.
#
# `_B` and `_C` are viewed through `real_rep`, so their real components sit on the
# leading axis; the same product with `A` is applied to every component `l`.
# `MKN` is accepted but not used by this method.
@eval function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    A::AbstractMatrix,
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β,
    MKN
) where {T,DualT<:Hyper{T}}
    Cr = real_rep(_C)
    Br = real_rep(_B)

    @turbo for n ∈ indices((Cr, Br), 3),
        m ∈ indices((Cr, A), (2, 1)),
        l ∈ indices((Cr, Br), 1)

        acc = zero(eltype(Cr))
        for k ∈ indices((A, Br), 2)
            acc += A[m, k] * Br[l, k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
159
+
160
+ # multiplication of dual matrix by standard vector/matrix from the right
161
# Kernel for `_C = α * _A * B + β * _C`, where `_A`, `_C` hold `Hyper` numbers and `B`
# is a plain vector/matrix; the explicit loop uses `@turbo`.
#
# `MKN` is accepted but not used by this method.
@eval @inline function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    B::$(AbstractVectorOrMatrix),
    α,
    β,
    MKN
) where {T,DualT<:Hyper{T}}
    flat_ok =
        Bool(ArrayInterface.is_dense(_C)) &&
        Bool(ArrayInterface.is_column_major(_C)) &&
        Bool(ArrayInterface.is_dense(_A)) &&
        Bool(ArrayInterface.is_column_major(_A))

    if flat_ok
        # Dense column-major storage: reinterpret both arrays as plain `T` arrays
        # (no reshape) and hand the product to the standard method.
        _matmul_serial!(reinterpret(T, _C), reinterpret(T, _A), B, α, β, nothing)
        return _C
    end

    # Otherwise loop explicitly over the component axis exposed by `real_rep`.
    Ar = real_rep(_A)
    Cr = real_rep(_C)

    @turbo for n ∈ indices((Cr, B), (3, 2)),
        m ∈ indices((Cr, Ar), 2),
        l ∈ indices((Cr, Ar), 1)

        acc = zero(eltype(Cr))
        for k ∈ indices((Ar, B), (3, 1))
            acc += Ar[l, m, k] * B[k, n]
        end
        Cr[l, m, n] = α * acc + β * Cr[l, m, n]
    end

    return _C
end
196
+
197
# Kernel for `_C = α * _A * _B + β * _C` where all three arrays hold `Hyper` numbers;
# every explicit loop uses `@turbo`.
#
# The product is assembled in three passes over the component arrays from `real_rep`:
#   1. every component of `_A` times component 1 of `_B` (this pass also applies `β`);
#   2. component 1 of `_A` times components 2..4 of `_B`, accumulated into `_C`;
#   3. the cross term `A[2] * B[3] + A[3] * B[2]`, accumulated into component 4.
# NOTE(review): assumes the four components are stored as (value, ε₁, ε₂, ε₁ε₂) —
# confirm against the `Hyper` layout in HyperDualNumbers.
# `MKN` is accepted but not used by this method.
@eval @inline function _matmul_serial!(
    _C::$(AbstractVectorOrMatrix){DualT},
    _A::AbstractMatrix{DualT},
    _B::$(AbstractVectorOrMatrix){DualT},
    α,
    β,
    MKN
) where {T,DualT<:Hyper{T}}
    A = real_rep(_A)
    C = real_rep(_C)
    B = real_rep(_B)

    # Pass 1: all components of `_A` against the first component of `_B`.
    if Bool(ArrayInterface.is_dense(_C)) &&
       Bool(ArrayInterface.is_column_major(_C)) &&
       Bool(ArrayInterface.is_dense(_A)) &&
       Bool(ArrayInterface.is_column_major(_A))
        # Dense column-major storage: skip the reshape and use the standard method.
        Ar = reinterpret(T, _A)
        Cr = reinterpret(T, _C)
        _matmul_serial!(Cr, Ar, _view1(B), α, β, nothing)
    else
        # The standard method cannot be used directly.
        @turbo for n ∈ indices((C, B), 3),
            m ∈ indices((C, A), 2),
            l ∈ indices((C, A), 1)

            Cₗₘₙ = zero(eltype(C))
            for k ∈ indices((A, B), (3, 2))
                Cₗₘₙ += A[l, m, k] * B[1, k, n]
            end
            C[l, m, n] = α * Cₗₘₙ + β * C[l, m, n]
        end
    end

    # Pass 2: first component of `_A` against components 2..4 of `_B`.
    @turbo for n ∈ indices((B, C), 3), m ∈ indices((A, C), 2), p ∈ 1:3
        Cₚₘₙ = zero(eltype(C))
        for k ∈ indices((A, B), (3, 2))
            Cₚₘₙ += A[1, m, k] * B[p+1, k, n]
        end
        C[p+1, m, n] = C[p+1, m, n] + α * Cₚₘₙ
    end

    # Pass 3: cross term between components 2 and 3, stored in component 4.
    # This must be `@turbo`, not `@tturbo`: the method is the serial kernel, so it must
    # not start threaded work of its own.
    @turbo for n ∈ indices((B, C), 3), m ∈ indices((A, C), 2)
        Cₘₙ = zero(eltype(C))
        for k ∈ indices((A, B), (3, 2))
            Cₘₙ += A[2, m, k] * B[3, k, n] + A[3, m, k] * B[2, k, n]
        end
        C[4, m, n] = C[4, m, n] + α * Cₘₙ
    end

    return _C
end
247
+ end # for
248
+
249
+ end # module
0 commit comments