Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit 9ee4137

Browse files
committed
WIP f16 fma
1 parent 0f6b1bb commit 9ee4137

File tree

15 files changed

+193
-21
lines changed

15 files changed

+193
-21
lines changed

crates/libm-macros/src/shared.rs

+7
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ const ALL_OPERATIONS_NESTED: &[(FloatTy, Signature, Option<Signature>, &[&str])]
9292
None,
9393
&["copysignf128"],
9494
),
95+
(
96+
// `(f16, f16, f16) -> f16`
97+
FloatTy::F16,
98+
Signature { args: &[Ty::F16, Ty::F16, Ty::F16], returns: &[Ty::F16] },
99+
None,
100+
&["fmaf16"],
101+
),
95102
(
96103
// `(f32, f32, f32) -> f32`
97104
FloatTy::F32,

crates/libm-test/src/f8_impl.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ pub struct f8(u8);
2020
impl Float for f8 {
2121
type Int = u8;
2222
type SignedInt = i8;
23-
type ExpInt = i8;
2423

2524
const ZERO: Self = Self(0b0_0000_000);
2625
const NEG_ZERO: Self = Self(0b1_0000_000);
@@ -62,8 +61,8 @@ impl Float for f8 {
6261
self.0 & Self::SIGN_MASK != 0
6362
}
6463

65-
fn exp(self) -> Self::ExpInt {
66-
unimplemented!()
64+
fn exp(self) -> i32 {
65+
((self.to_bits() & Self::EXP_MASK) >> Self::SIG_BITS) as i32
6766
}
6867

6968
fn from_bits(a: Self::Int) -> Self {

crates/libm-test/src/mpfloat.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ libm_macros::for_each_function! {
147147
expm1 | expm1f => exp_m1,
148148
fabs | fabsf => abs,
149149
fdim | fdimf => positive_diff,
150-
fma | fmaf => mul_add,
150+
fma | fmaf | fmaf16 => mul_add,
151151
fmax | fmaxf => max,
152152
fmin | fminf => min,
153153
lgamma | lgammaf => ln_gamma,

crates/libm-test/src/precision.rs

+5
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,12 @@ fn bessel_prec_dropoff<F: Float>(
485485
None
486486
}
487487

488+
#[cfg(f16_enabled)]
489+
impl MaybeOverride<(f16, f16, f16)> for SpecialCase {}
488490
impl MaybeOverride<(f32, f32, f32)> for SpecialCase {}
489491
impl MaybeOverride<(f64, f64, f64)> for SpecialCase {}
492+
#[cfg(f128_enabled)]
493+
impl MaybeOverride<(f128, f128, f128)> for SpecialCase {}
494+
490495
impl MaybeOverride<(f32, i32)> for SpecialCase {}
491496
impl MaybeOverride<(f64, i32)> for SpecialCase {}

crates/libm-test/tests/multiprecision.rs

+1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ libm_macros::for_each_function! {
122122
fdimf,
123123
fma,
124124
fmaf,
125+
fmaf16,
125126
fmax,
126127
fmaxf,
127128
fmin,

etc/function-definitions.json

+6
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@
328328
],
329329
"type": "f32"
330330
},
331+
"fmaf16": {
332+
"sources": [
333+
"src/math/fmaf16.rs"
334+
],
335+
"type": "f16"
336+
},
331337
"fmax": {
332338
"sources": [
333339
"src/libm_helper.rs",

etc/function-list.txt

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ floor
4747
floorf
4848
fma
4949
fmaf
50+
fmaf16
5051
fmax
5152
fmaxf
5253
fmin

src/math/fmaf.rs

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ use super::fenv::{
4747
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
4848
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
4949
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
50+
if true {
51+
return super::generic::fma_big::<f32, f64>(x, y, z);
52+
}
53+
5054
let xy: f64;
5155
let mut result: f64;
5256
let mut ui: u64;

src/math/fmaf16.rs

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
2+
pub fn fmaf16(x: f16, y: f16, z: f16) -> f16 {
3+
super::generic::fma_big::<f16, f32>(x, y, z)
4+
}

src/math/generic/fma.rs

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use super::super::fenv::{
2+
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
3+
};
4+
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, IntTy, MinInt};
5+
6+
/// FMA implementation when there is a larger float type available.
7+
pub fn fma_big<F, B>(x: F, y: F, z: F) -> F
8+
where
9+
F: Float + HFloat<D = B>,
10+
B: Float + DFloat<H = F>,
11+
// F: Float + CastInto<B>,
12+
// B: Float + CastInto<F> + CastFrom<F>,
13+
B::Int: CastInto<i32>,
14+
i32: CastFrom<i32>,
15+
{
16+
let one = IntTy::<B>::ONE;
17+
18+
let xy: B;
19+
let mut result: B;
20+
let mut ui: B::Int;
21+
let e: i32;
22+
23+
xy = x.widen() * y.widen();
24+
result = xy + z.widen();
25+
ui = result.to_bits();
26+
e = i32::cast_from(ui >> F::SIG_BITS) & F::EXP_MAX as i32;
27+
let zb: B = z.widen();
28+
29+
let prec_diff = B::SIG_BITS - F::SIG_BITS;
30+
let excess_prec = ui & ((one << prec_diff) - one);
31+
let x = one << (prec_diff - 1);
32+
33+
// Common case: the larger precision is fine
34+
if excess_prec != x
35+
|| e == i32::cast_from(F::EXP_MAX)
36+
|| (result - xy == zb && result - zb == xy)
37+
|| fegetround() != FE_TONEAREST
38+
{
39+
// TODO: feclearexcept
40+
41+
return result.narrow();
42+
}
43+
44+
let neg = ui & B::SIGN_MASK > IntTy::<B>::ZERO;
45+
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
46+
if neg == (err < B::ZERO) {
47+
ui += one;
48+
} else {
49+
ui -= one;
50+
}
51+
52+
B::from_bits(ui).narrow()
53+
}

src/math/generic/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod copysign;
22
mod fabs;
3+
mod fma;
34

45
pub use copysign::copysign;
56
pub use fabs::fabs;
7+
pub use fma::fma_big;

src/math/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ use self::rem_pio2::rem_pio2;
121121
use self::rem_pio2_large::rem_pio2_large;
122122
use self::rem_pio2f::rem_pio2f;
123123
#[allow(unused_imports)]
124-
use self::support::{CastFrom, CastInto, DInt, Float, HInt, Int, MinInt};
124+
use self::support::{CastFrom, CastInto, DFloat, DInt, Float, HFloat, HInt, Int, IntTy, MinInt};
125125

126126
// Public modules
127127
mod acos;
@@ -343,9 +343,11 @@ cfg_if! {
343343
if #[cfg(f16_enabled)] {
344344
mod copysignf16;
345345
mod fabsf16;
346+
mod fmaf16;
346347

347348
pub use self::copysignf16::copysignf16;
348349
pub use self::fabsf16::fabsf16;
350+
pub use self::fmaf16::fmaf16;
349351
}
350352
}
351353

src/math/support/float_traits.rs

+72-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use core::{fmt, mem, ops};
1+
use core::ops::{self, Neg};
2+
use core::{fmt, mem};
23

34
use super::int_traits::{Int, MinInt};
45

@@ -23,10 +24,9 @@ pub trait Float:
2324
type Int: Int<OtherSign = Self::SignedInt, Unsigned = Self::Int>;
2425

2526
/// A int of the same width as the float
26-
type SignedInt: Int + MinInt<OtherSign = Self::Int, Unsigned = Self::Int>;
27-
28-
/// An int capable of containing the exponent bits plus a sign bit. This is signed.
29-
type ExpInt: Int;
27+
type SignedInt: Int
28+
+ MinInt<OtherSign = Self::Int, Unsigned = Self::Int>
29+
+ Neg<Output = Self::SignedInt>;
3030

3131
const ZERO: Self;
3232
const NEG_ZERO: Self;
@@ -98,7 +98,7 @@ pub trait Float:
9898
}
9999

100100
/// Returns the exponent, not adjusting for bias.
101-
fn exp(self) -> Self::ExpInt;
101+
fn exp(self) -> i32;
102102

103103
/// Returns the significand with no implicit bit (or the "fractional" part)
104104
fn frac(self) -> Self::Int {
@@ -138,23 +138,20 @@ pub trait Float:
138138
}
139139

140140
/// Access the associated `Int` type from a float (helper to avoid ambiguous associated types).
141-
#[allow(dead_code)]
142141
pub type IntTy<F> = <F as Float>::Int;
143142

144143
macro_rules! float_impl {
145144
(
146145
$ty:ident,
147146
$ity:ident,
148147
$sity:ident,
149-
$expty:ident,
150148
$bits:expr,
151149
$significand_bits:expr,
152150
$from_bits:path
153151
) => {
154152
impl Float for $ty {
155153
type Int = $ity;
156154
type SignedInt = $sity;
157-
type ExpInt = $expty;
158155

159156
const ZERO: Self = 0.0;
160157
const NEG_ZERO: Self = -0.0;
@@ -191,8 +188,8 @@ macro_rules! float_impl {
191188
fn is_sign_negative(self) -> bool {
192189
self.is_sign_negative()
193190
}
194-
fn exp(self) -> Self::ExpInt {
195-
((self.to_bits() & Self::EXP_MASK) >> Self::SIG_BITS) as Self::ExpInt
191+
fn exp(self) -> i32 {
192+
((self.to_bits() & Self::EXP_MASK) >> Self::SIG_BITS) as i32
196193
}
197194
fn from_bits(a: Self::Int) -> Self {
198195
Self::from_bits(a)
@@ -226,11 +223,11 @@ macro_rules! float_impl {
226223
}
227224

228225
#[cfg(f16_enabled)]
229-
float_impl!(f16, u16, i16, i8, 16, 10, f16::from_bits);
230-
float_impl!(f32, u32, i32, i16, 32, 23, f32_from_bits);
231-
float_impl!(f64, u64, i64, i16, 64, 52, f64_from_bits);
226+
float_impl!(f16, u16, i16, 16, 10, f16::from_bits);
227+
float_impl!(f32, u32, i32, 32, 23, f32_from_bits);
228+
float_impl!(f64, u64, i64, 64, 52, f64_from_bits);
232229
#[cfg(f128_enabled)]
233-
float_impl!(f128, u128, i128, i16, 128, 112, f128::from_bits);
230+
float_impl!(f128, u128, i128, 128, 112, f128::from_bits);
234231

235232
/* FIXME(msrv): vendor some things that are not const stable at our MSRV */
236233

@@ -245,3 +242,63 @@ pub const fn f64_from_bits(bits: u64) -> f64 {
245242
// SAFETY: POD cast with no preconditions
246243
unsafe { mem::transmute::<u64, f64>(bits) }
247244
}
245+
246+
/// Trait for floats twice the bit width of another integer.
247+
#[allow(unused)]
248+
pub trait DFloat: Float {
249+
/// Float that is half the bit width of the floatthis trait is implemented for.
250+
type H: HFloat<D = Self>;
251+
252+
/// Narrow the float type.
253+
fn narrow(self) -> Self::H;
254+
}
255+
256+
/// Trait for floats half the bit width of another float.
257+
#[allow(unused)]
258+
pub trait HFloat: Float {
259+
/// Float that is double the bit width of the float this trait is implemented for.
260+
type D: DFloat<H = Self>;
261+
262+
/// Widen the float type.
263+
fn widen(self) -> Self::D;
264+
}
265+
266+
macro_rules! impl_d_float {
267+
($($X:ident $D:ident),*) => {
268+
$(
269+
impl DFloat for $D {
270+
type H = $X;
271+
272+
fn narrow(self) -> Self::H {
273+
self as $X
274+
}
275+
}
276+
)*
277+
};
278+
}
279+
280+
macro_rules! impl_h_float {
281+
($($H:ident $X:ident),*) => {
282+
$(
283+
impl HFloat for $H {
284+
type D = $X;
285+
286+
fn widen(self) -> Self::D {
287+
self as $X
288+
}
289+
}
290+
)*
291+
};
292+
}
293+
294+
impl_d_float!(f32 f64);
295+
#[cfg(f16_enabled)]
296+
impl_d_float!(f16 f32);
297+
#[cfg(f128_enabled)]
298+
impl_d_float!(f64 f128);
299+
300+
impl_h_float!(f32 f64);
301+
#[cfg(f16_enabled)]
302+
impl_h_float!(f16 f32);
303+
#[cfg(f128_enabled)]
304+
impl_h_float!(f64 f128);

src/math/support/int_traits.rs

+31
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ pub trait Int:
8282
fn wrapping_shr(self, other: u32) -> Self;
8383
fn rotate_left(self, other: u32) -> Self;
8484
fn overflowing_add(self, other: Self) -> (Self, bool);
85+
fn overflowing_sub(self, other: Self) -> (Self, bool);
8586
fn leading_zeros(self) -> u32;
8687
fn ilog2(self) -> u32;
8788
}
@@ -140,6 +141,10 @@ macro_rules! int_impl_common {
140141
<Self>::overflowing_add(self, other)
141142
}
142143

144+
fn overflowing_sub(self, other: Self) -> (Self, bool) {
145+
<Self>::overflowing_sub(self, other)
146+
}
147+
143148
fn leading_zeros(self) -> u32 {
144149
<Self>::leading_zeros(self)
145150
}
@@ -382,3 +387,29 @@ cast_into!(u64);
382387
cast_into!(i64);
383388
cast_into!(u128);
384389
cast_into!(i128);
390+
391+
cast_into!(i64; f32);
392+
cast_into!(i64; f64);
393+
cast_into!(f32; f64);
394+
cast_into!(f64; f32);
395+
396+
cast_into!(bool; u16);
397+
cast_into!(bool; u32);
398+
cast_into!(bool; u64);
399+
cast_into!(bool; u128);
400+
401+
cfg_if! {
402+
if #[cfg(f16_enabled)] {
403+
cast_into!(f16; f32, f64);
404+
cast_into!(f32; f16);
405+
cast_into!(f64; f16);
406+
}
407+
}
408+
409+
cfg_if! {
410+
if #[cfg(f128_enabled)] {
411+
cast_into!(f128; f32, f64);
412+
cast_into!(f32; f128);
413+
cast_into!(f64; f128);
414+
}
415+
}

src/math/support/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod hex_float;
55
mod int_traits;
66

77
#[allow(unused_imports)]
8-
pub use float_traits::{Float, IntTy};
8+
pub use float_traits::{DFloat, Float, HFloat, IntTy};
99
pub(crate) use float_traits::{f32_from_bits, f64_from_bits};
1010
#[allow(unused_imports)]
1111
pub use hex_float::{hf32, hf64};

0 commit comments

Comments
 (0)