22
33using Base: FastMath
44
5-
65# # helpers
76
87within (lower, upper) = (val) -> lower <= val <= upper
@@ -103,17 +102,98 @@ end
103102
104103@device_override Base. log (x:: Float64 ) = ccall (" extern __nv_log" , llvmcall, Cdouble, (Cdouble,), x)
105104@device_override Base. log (x:: Float32 ) = ccall (" extern __nv_logf" , llvmcall, Cfloat, (Cfloat,), x)
105+ @device_override function Base. log (x:: Float16 )
106+ log_x = @asmcall (""" {.reg.b32 f, C;
107+ .reg.b16 r,h;
108+ mov.b16 h,\$ 1;
109+ cvt.f32.f16 f,h;
110+ lg2.approx.ftz.f32 f,f;
111+ mov.b32 C, 0x3f317218U;
112+ mul.f32 f,f,C;
113+ cvt.rn.f16.f32 r,f;
114+ .reg.b16 spc, ulp, p;
115+ mov.b16 spc, 0X160DU;
116+ mov.b16 ulp, 0x9C00U;
117+ set.eq.f16.f16 p, h, spc;
118+ fma.rn.f16 r,p,ulp,r;
119+ mov.b16 spc, 0X3BFEU;
120+ mov.b16 ulp, 0x8010U;
121+ set.eq.f16.f16 p, h, spc;
122+ fma.rn.f16 r,p,ulp,r;
123+ mov.b16 spc, 0X3C0BU;
124+ mov.b16 ulp, 0x8080U;
125+ set.eq.f16.f16 p, h, spc;
126+ fma.rn.f16 r,p,ulp,r;
127+ mov.b16 spc, 0X6051U;
128+ mov.b16 ulp, 0x1C00U;
129+ set.eq.f16.f16 p, h, spc;
130+ fma.rn.f16 r,p,ulp,r;
131+ mov.b16 \$ 0,r;
132+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
133+ return log_x
134+ end
135+
106136@device_override FastMath. log_fast (x:: Float32 ) = ccall (" extern __nv_fast_logf" , llvmcall, Cfloat, (Cfloat,), x)
107137
108138@device_override Base. log10 (x:: Float64 ) = ccall (" extern __nv_log10" , llvmcall, Cdouble, (Cdouble,), x)
109139@device_override Base. log10 (x:: Float32 ) = ccall (" extern __nv_log10f" , llvmcall, Cfloat, (Cfloat,), x)
140+ @device_override function Base. log10 (x:: Float16 )
141+ log_x = @asmcall (""" {.reg.b16 h, r;
142+ .reg.b32 f, C;
143+ mov.b16 h, \$ 1;
144+ cvt.f32.f16 f, h;
145+ lg2.approx.ftz.f32 f, f;
146+ mov.b32 C, 0x3E9A209BU;
147+ mul.f32 f,f,C;
148+ cvt.rn.f16.f32 r, f;
149+ .reg.b16 spc, ulp, p;
150+ mov.b16 spc, 0x338FU;
151+ mov.b16 ulp, 0x1000U;
152+ set.eq.f16.f16 p, h, spc;
153+ fma.rn.f16 r,p,ulp,r;
154+ mov.b16 spc, 0x33F8U;
155+ mov.b16 ulp, 0x9000U;
156+ set.eq.f16.f16 p, h, spc;
157+ fma.rn.f16 r,p,ulp,r;
158+ mov.b16 spc, 0x57E1U;
159+ mov.b16 ulp, 0x9800U;
160+ set.eq.f16.f16 p, h, spc;
161+ fma.rn.f16 r,p,ulp,r;
162+ mov.b16 spc, 0x719DU;
163+ mov.b16 ulp, 0x9C00U;
164+ set.eq.f16.f16 p, h, spc;
165+ fma.rn.f16 r,p,ulp,r;
166+ mov.b16 \$ 0, r;
167+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
168+ return log_x
169+ end
110170@device_override FastMath. log10_fast (x:: Float32 ) = ccall (" extern __nv_fast_log10f" , llvmcall, Cfloat, (Cfloat,), x)
111171
112172@device_override Base. log1p (x:: Float64 ) = ccall (" extern __nv_log1p" , llvmcall, Cdouble, (Cdouble,), x)
113173@device_override Base. log1p (x:: Float32 ) = ccall (" extern __nv_log1pf" , llvmcall, Cfloat, (Cfloat,), x)
114174
115175@device_override Base. log2 (x:: Float64 ) = ccall (" extern __nv_log2" , llvmcall, Cdouble, (Cdouble,), x)
116176@device_override Base. log2 (x:: Float32 ) = ccall (" extern __nv_log2f" , llvmcall, Cfloat, (Cfloat,), x)
177+ @device_override function Base. log2 (x:: Float16 )
178+ log_x = @asmcall (""" {.reg.b16 h, r;
179+ .reg.b32 f;
180+ mov.b16 h, \$ 1;
181+ cvt.f32.f16 f, h;
182+ lg2.approx.ftz.f32 f, f;
183+ cvt.rn.f16.f32 r, f;
184+ .reg.b16 spc, ulp, p;
185+ mov.b16 spc, 0xA2E2U;
186+ mov.b16 ulp, 0x8080U;
187+ set.eq.f16.f16 p, r, spc;
188+ fma.rn.f16 r,p,ulp,r;
189+ mov.b16 spc, 0xBF46U;
190+ mov.b16 ulp, 0x9400U;
191+ set.eq.f16.f16 p, r, spc;
192+ fma.rn.f16 r,p,ulp,r;
193+ mov.b16 \$ 0, r;
194+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
195+ return log_x
196+ end
117197@device_override FastMath. log2_fast (x:: Float32 ) = ccall (" extern __nv_fast_log2f" , llvmcall, Cfloat, (Cfloat,), x)
118198
119199@device_function logb (x:: Float64 ) = ccall (" extern __nv_logb" , llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +207,95 @@ end
127207
128208@device_override Base. exp (x:: Float64 ) = ccall (" extern __nv_exp" , llvmcall, Cdouble, (Cdouble,), x)
129209@device_override Base. exp (x:: Float32 ) = ccall (" extern __nv_expf" , llvmcall, Cfloat, (Cfloat,), x)
210+ @device_override function Base. exp (x:: Float16 )
211+ exp_x = @asmcall (""" {
212+ .reg.b32 f, C, nZ;
213+ .reg.b16 h,r;
214+ mov.b16 h,\$ 1;
215+ cvt.f32.f16 f,h;
216+ mov.b32 C, 0x3fb8aa3bU;
217+ mov.b32 nZ, 0x80000000U;
218+ fma.rn.f32 f,f,C,nZ;
219+ ex2.approx.ftz.f32 f,f;
220+ cvt.rn.f16.f32 r,f;
221+ .reg.b16 spc, ulp, p;
222+ mov.b16 spc,0X1F79U;
223+ mov.b16 ulp,0x9400U;
224+ set.eq.f16.f16 p, h, spc;
225+ fma.rn.f16 r,p,ulp,r;
226+ mov.b16 spc,0X25CFU;
227+ mov.b16 ulp,0x9400U;
228+ set.eq.f16.f16 p, h, spc;
229+ fma.rn.f16 r,p,ulp,r;
230+ mov.b16 spc,0XC13BU;
231+ mov.b16 ulp,0x0400U;
232+ set.eq.f16.f16 p, h, spc;
233+ fma.rn.f16 r,p,ulp,r;
234+ mov.b16 spc,0XC1EFU;
235+ mov.b16 ulp,0x0200U;
236+ set.eq.f16.f16 p, h, spc;
237+ fma.rn.f16 r,p,ulp,r;
238+ mov.b16 \$ 0,r;
239+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
240+ return exp_x
241+ end
130242@device_override FastMath. exp_fast (x:: Float32 ) = ccall (" extern __nv_fast_expf" , llvmcall, Cfloat, (Cfloat,), x)
131243
132244@device_override Base. exp2 (x:: Float64 ) = ccall (" extern __nv_exp2" , llvmcall, Cdouble, (Cdouble,), x)
133245@device_override Base. exp2 (x:: Float32 ) = ccall (" extern __nv_exp2f" , llvmcall, Cfloat, (Cfloat,), x)
246+ @device_override function Base. exp2 (x:: Float16 )
247+ exp_x = @asmcall (""" {.reg.b32 f, ULP;
248+ .reg.b16 r;
249+ mov.b16 r,\$ 1;
250+ cvt.f32.f16 f,r;
251+ ex2.approx.ftz.f32 f,f;
252+ mov.b32 ULP, 0x33800000U;
253+ fma.rn.f32 f,f,ULP,f;
254+ cvt.rn.f16.f32 r,f;
255+ mov.b16 \$ 0,r;
256+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
257+ return exp_x
258+ end
134259@device_override FastMath. exp2_fast (x:: Union{Float32, Float64} ) = exp2 (x)
135- # TODO : enable once PTX > 7.0 is supported
136- # @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
137260
138261@device_override Base. exp10 (x:: Float64 ) = ccall (" extern __nv_exp10" , llvmcall, Cdouble, (Cdouble,), x)
139262@device_override Base. exp10 (x:: Float32 ) = ccall (" extern __nv_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
263+ @device_override function Base. exp10 (x:: Float16 )
264+
265+ exp_x = @asmcall (""" {.reg.b16 h,r;
266+ .reg.b32 f, C, nZ;
267+ mov.b16 h, \$ 1;
268+ cvt.f32.f16 f, h;
269+ mov.b32 C, 0x40549A78U;
270+ mov.b32 nZ, 0x80000000U;
271+ fma.rn.f32 f,f,C,nZ;
272+ ex2.approx.ftz.f32 f, f;
273+ cvt.rn.f16.f32 r, f;
274+ .reg.b16 spc, ulp, p;
275+ mov.b16 spc,0x34DEU;
276+ mov.b16 ulp,0x9800U;
277+ set.eq.f16.f16 p, h, spc;
278+ fma.rn.f16 r,p,ulp,r;
279+ mov.b16 spc,0x9766U;
280+ mov.b16 ulp,0x9000U;
281+ set.eq.f16.f16 p, h, spc;
282+ fma.rn.f16 r,p,ulp,r;
283+ mov.b16 spc,0x9972U;
284+ mov.b16 ulp,0x1000U;
285+ set.eq.f16.f16 p, h, spc;
286+ fma.rn.f16 r,p,ulp,r;
287+ mov.b16 spc,0xA5C4U;
288+ mov.b16 ulp,0x1000U;
289+ set.eq.f16.f16 p, h, spc;
290+ fma.rn.f16 r,p,ulp,r;
291+ mov.b16 spc,0xBF0AU;
292+ mov.b16 ulp,0x8100U;
293+ set.eq.f16.f16 p, h, spc;
294+ fma.rn.f16 r,p,ulp,r;
295+ mov.b16 \$ 0, r;
296+ }""" , " =h,h" , Float16, Tuple{Float16}, x)
297+ return exp_x
298+ end
140299@device_override FastMath. exp10_fast (x:: Float32 ) = ccall (" extern __nv_fast_exp10f" , llvmcall, Cfloat, (Cfloat,), x)
141300
142301@device_override Base. expm1 (x:: Float64 ) = ccall (" extern __nv_expm1" , llvmcall, Cdouble, (Cdouble,), x)
204363
205364@device_override Base. isnan (x:: Float64 ) = (ccall (" extern __nv_isnand" , llvmcall, Int32, (Cdouble,), x)) != 0
206365@device_override Base. isnan (x:: Float32 ) = (ccall (" extern __nv_isnanf" , llvmcall, Int32, (Cfloat,), x)) != 0
366+ @device_override function Base. isnan (x:: Float16 )
367+ if compute_capability () >= sv " 8.0"
368+ return (ccall (" extern __nv_hisnan" , llvmcall, Int32, (Float16,), x)) != 0
369+ else
370+ return isnan (Float32 (x))
371+ end
372+ end
207373
208374@device_function nearbyint (x:: Float64 ) = ccall (" extern __nv_nearbyint" , llvmcall, Cdouble, (Cdouble,), x)
209375@device_function nearbyint (x:: Float32 ) = ccall (" extern __nv_nearbyintf" , llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +389,20 @@ end
223389@device_override Base. abs (x:: Int32 ) = ccall (" extern __nv_abs" , llvmcall, Int32, (Int32,), x)
224390@device_override Base. abs (f:: Float64 ) = ccall (" extern __nv_fabs" , llvmcall, Cdouble, (Cdouble,), f)
225391@device_override Base. abs (f:: Float32 ) = ccall (" extern __nv_fabsf" , llvmcall, Cfloat, (Cfloat,), f)
226- # TODO : enable once PTX > 7.0 is supported
227- # @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
392+ @device_override Base. abs (f:: Float16 ) = Float16 (abs (Float32 (f)))
228393@device_override Base. abs (x:: Int64 ) = ccall (" extern __nv_llabs" , llvmcall, Int64, (Int64,), x)
229394
230395# # roots and powers
231396
232397@device_override Base. sqrt (x:: Float64 ) = ccall (" extern __nv_sqrt" , llvmcall, Cdouble, (Cdouble,), x)
233398@device_override Base. sqrt (x:: Float32 ) = ccall (" extern __nv_sqrtf" , llvmcall, Cfloat, (Cfloat,), x)
399+ @device_override function Base. sqrt (x:: Float16 )
400+ if compute_capability () >= sv " 8.0"
401+ ccall (" extern __nv_hsqrt" , llvmcall, Float16, (Float16,), x)
402+ else
403+ return Float16 (sqrt (Float32 (x)))
404+ end
405+ end
234406@device_override FastMath. sqrt_fast (x:: Union{Float32, Float64} ) = sqrt (x)
235407
236408@device_function rsqrt (x:: Float64 ) = ccall (" extern __nv_rsqrt" , llvmcall, Cdouble, (Cdouble,), x)
0 commit comments