Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add midpoint function for all integers and floating numbers #92048

Merged
merged 3 commits into from
May 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions library/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
#![feature(const_maybe_uninit_assume_init)]
#![feature(const_maybe_uninit_uninit_array)]
#![feature(const_nonnull_new)]
#![feature(const_num_midpoint)]
#![feature(const_option)]
#![feature(const_option_ext)]
#![feature(const_pin)]
Expand Down
36 changes: 36 additions & 0 deletions library/core/src/num/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,42 @@ impl f32 {
}
}

/// Calculates the middle point of `self` and `rhs`.
///
/// This returns NaN when *either* argument is NaN or if a combination of
/// +inf and -inf is provided as arguments.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
/// assert_eq!(1f32.midpoint(4.0), 2.5);
/// assert_eq!((-5.5f32).midpoint(8.0), 1.25);
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
pub fn midpoint(self, other: f32) -> f32 {
const LO: f32 = f32::MIN_POSITIVE * 2.;
const HI: f32 = f32::MAX / 2.;

let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();

if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
}
}

/// Rounds toward zero and converts to any primitive integer type,
/// assuming that the value is finite and fits in that type.
///
Expand Down
36 changes: 36 additions & 0 deletions library/core/src/num/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,42 @@ impl f64 {
}
}

/// Calculates the middle point of `self` and `rhs`.
///
/// This returns NaN when *either* argument is NaN or if a combination of
/// +inf and -inf is provided as arguments.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
/// assert_eq!(1f64.midpoint(4.0), 2.5);
/// assert_eq!((-5.5f64).midpoint(8.0), 1.25);
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
pub fn midpoint(self, other: f64) -> f64 {
const LO: f64 = f64::MIN_POSITIVE * 2.;
const HI: f64 = f64::MAX / 2.;

let (a, b) = (self, other);
let abs_a = a.abs_private();
let abs_b = b.abs_private();

if abs_a <= HI && abs_b <= HI {
// Overflow is impossible
(a + b) / 2.
} else if abs_a < LO {
// Not safe to halve a
a + (b / 2.)
} else if abs_b < LO {
// Not safe to halve b
(a / 2.) + b
} else {
// Not safe to halve a and b
(a / 2.) + (b / 2.)
}
}

/// Rounds toward zero and converts to any primitive integer type,
/// assuming that the value is finite and fits in that type.
///
Expand Down
38 changes: 38 additions & 0 deletions library/core/src/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2332,6 +2332,44 @@ macro_rules! int_impl {
}
}

/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")]
#[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: Self) -> Self {
const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs();

// Map an $SelfT to an $UnsignedT
// ex: i8 [-128; 127] to [0; 255]
const fn map(a: $SelfT) -> $UnsignedT {
(a as $UnsignedT) ^ U
}

// Map an $UnsignedT to an $SelfT
// ex: u8 [0; 255] to [-128; 127]
const fn demap(a: $UnsignedT) -> $SelfT {
(a ^ U) as $SelfT
}

demap(<$UnsignedT>::midpoint(map(self), map(rhs)))
}

/// Returns the logarithm of the number with respect to an arbitrary base,
/// rounded down.
///
Expand Down
59 changes: 59 additions & 0 deletions library/core/src/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,57 @@ depending on the target pointer size.
};
}

macro_rules! midpoint_impl {
($SelfT:ty, unsigned) => {
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
// Use the well known branchless algorthim from Hacker's Delight to compute
// `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`.
((self ^ rhs) >> 1) + (self & rhs)
}
};
($SelfT:ty, $WideT:ty, unsigned) => {
/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
((self as $WideT + rhs as $WideT) / 2) as $SelfT
}
};
}

macro_rules! widening_impl {
($SelfT:ty, $WideT:ty, $BITS:literal, unsigned) => {
/// Calculates the complete product `self * rhs` without the possibility to overflow.
Expand Down Expand Up @@ -455,6 +506,7 @@ impl u8 {
bound_condition = "",
}
widening_impl! { u8, u16, 8, unsigned }
midpoint_impl! { u8, u16, unsigned }

/// Checks if the value is within the ASCII range.
///
Expand Down Expand Up @@ -1057,6 +1109,7 @@ impl u16 {
bound_condition = "",
}
widening_impl! { u16, u32, 16, unsigned }
midpoint_impl! { u16, u32, unsigned }

/// Checks if the value is a Unicode surrogate code point, which are disallowed values for [`char`].
///
Expand Down Expand Up @@ -1105,6 +1158,7 @@ impl u32 {
bound_condition = "",
}
widening_impl! { u32, u64, 32, unsigned }
midpoint_impl! { u32, u64, unsigned }
}

impl u64 {
Expand All @@ -1128,6 +1182,7 @@ impl u64 {
bound_condition = "",
}
widening_impl! { u64, u128, 64, unsigned }
midpoint_impl! { u64, u128, unsigned }
}

impl u128 {
Expand All @@ -1152,6 +1207,7 @@ impl u128 {
from_xe_bytes_doc = "",
bound_condition = "",
}
midpoint_impl! { u128, unsigned }
}

#[cfg(target_pointer_width = "16")]
Expand All @@ -1176,6 +1232,7 @@ impl usize {
bound_condition = " on 16-bit targets",
}
widening_impl! { usize, u32, 16, unsigned }
midpoint_impl! { usize, u32, unsigned }
}

#[cfg(target_pointer_width = "32")]
Expand All @@ -1200,6 +1257,7 @@ impl usize {
bound_condition = " on 32-bit targets",
}
widening_impl! { usize, u64, 32, unsigned }
midpoint_impl! { usize, u64, unsigned }
}

#[cfg(target_pointer_width = "64")]
Expand All @@ -1224,6 +1282,7 @@ impl usize {
bound_condition = " on 64-bit targets",
}
widening_impl! { usize, u128, 64, unsigned }
midpoint_impl! { usize, u128, unsigned }
}

impl usize {
Expand Down
37 changes: 37 additions & 0 deletions library/core/src/num/nonzero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,43 @@ macro_rules! nonzero_unsigned_operations {
pub const fn ilog10(self) -> u32 {
super::int_log10::$Int(self.0)
}

/// Calculates the middle point of `self` and `rhs`.
///
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
/// sufficiently-large signed integral type. This implies that the result is
/// always rounded towards negative infinity and that no overflow will ever occur.
///
/// # Examples
///
/// ```
/// #![feature(num_midpoint)]
#[doc = concat!("# use std::num::", stringify!($Ty), ";")]
///
/// # fn main() { test().unwrap(); }
/// # fn test() -> Option<()> {
#[doc = concat!("let one = ", stringify!($Ty), "::new(1)?;")]
#[doc = concat!("let two = ", stringify!($Ty), "::new(2)?;")]
#[doc = concat!("let four = ", stringify!($Ty), "::new(4)?;")]
///
/// assert_eq!(one.midpoint(four), two);
/// assert_eq!(four.midpoint(one), two);
/// # Some(())
/// # }
/// ```
#[unstable(feature = "num_midpoint", issue = "110840")]
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn midpoint(self, rhs: Self) -> Self {
// SAFETY: The only way to get `0` with midpoint is to have two opposite or
// near opposite numbers: (-5, 5), (0, 1), (0, 0) which is impossible because
// of the unsignedness of this number and also because $Ty is guaranteed to
// never being 0.
unsafe { $Ty::new_unchecked(self.get().midpoint(rhs.get())) }
}
}
)+
}
Expand Down
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#![feature(maybe_uninit_uninit_array_transpose)]
#![feature(min_specialization)]
#![feature(numfmt)]
#![feature(num_midpoint)]
#![feature(step_trait)]
#![feature(str_internals)]
#![feature(std_internals)]
Expand Down
26 changes: 26 additions & 0 deletions library/core/tests/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,32 @@ macro_rules! int_module {
assert_eq!((0 as $T).borrowing_sub($T::MIN, false), ($T::MIN, true));
assert_eq!((0 as $T).borrowing_sub($T::MIN, true), ($T::MAX, false));
}

#[test]
fn test_midpoint() {
assert_eq!(<$T>::midpoint(1, 3), 2);
assert_eq!(<$T>::midpoint(3, 1), 2);

assert_eq!(<$T>::midpoint(0, 0), 0);
assert_eq!(<$T>::midpoint(0, 2), 1);
assert_eq!(<$T>::midpoint(2, 0), 1);
assert_eq!(<$T>::midpoint(2, 2), 2);

assert_eq!(<$T>::midpoint(1, 4), 2);
assert_eq!(<$T>::midpoint(4, 1), 2);
assert_eq!(<$T>::midpoint(3, 4), 3);
assert_eq!(<$T>::midpoint(4, 3), 3);

assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1);
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);

assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
assert_eq!(<$T>::midpoint(<$T>::MAX, 6), <$T>::MAX / 2 + 3);
assert_eq!(<$T>::midpoint(6, <$T>::MAX), <$T>::MAX / 2 + 3);
}
}
};
}
Loading