Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve sqrt/sqrtf if stable intrinsics allow #216

Merged
merged 5 commits into from
Aug 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 127 additions & 108 deletions src/math/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@
*/

use core::f64;
use core::num::Wrapping;

const TINY: f64 = 1.0e-300;

#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn sqrt(x: f64) -> f64 {
Expand All @@ -95,128 +92,150 @@ pub fn sqrt(x: f64) -> f64 {
}
}
}
let mut z: f64;
let sign: Wrapping<u32> = Wrapping(0x80000000);
let mut ix0: i32;
let mut s0: i32;
let mut q: i32;
let mut m: i32;
let mut t: i32;
let mut i: i32;
let mut r: Wrapping<u32>;
let mut t1: Wrapping<u32>;
let mut s1: Wrapping<u32>;
let mut ix1: Wrapping<u32>;
let mut q1: Wrapping<u32>;
#[cfg(target_feature = "sse2")]
{
// Note: This path is unlikely since LLVM will usually have already
// optimized sqrt calls into hardware instructions if sse2 is available,
// but if someone does end up here they'll apprected the speed increase.
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
unsafe {
let m = _mm_set_sd(x);
let m_sqrt = _mm_sqrt_pd(m);
_mm_cvtsd_f64(m_sqrt)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't using the same code here as for wasm32 produce the exact same effect for targets with SSE2? If not, why not ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I don't understand the question. sse2 is an x86 and x86_64 feature. It will never be available during a wasm32 build.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that instead of using the core::arch intrinsics here, you can just write unsafe { ::core::intrinsics::sqrtf64(x) } just like wasm32 does, and that should generate this exact same code if SSE2 is enabled.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, but that's not stable. this code is. that's why i did not also add a neon version, I'm only out to help out stable rust today.

you're right that we should probably also reconsider calling the core intrinsic in more cases if Nightly is in use.

}
#[cfg(not(target_feature = "sse2"))]
{
use core::num::Wrapping;

ix0 = (x.to_bits() >> 32) as i32;
ix1 = Wrapping(x.to_bits() as u32);
const TINY: f64 = 1.0e-300;

/* take care of Inf and NaN */
if (ix0 & 0x7ff00000) == 0x7ff00000 {
return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
}
/* take care of zero */
if ix0 <= 0 {
if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
return x; /* sqrt(+-0) = +-0 */
let mut z: f64;
let sign: Wrapping<u32> = Wrapping(0x80000000);
let mut ix0: i32;
let mut s0: i32;
let mut q: i32;
let mut m: i32;
let mut t: i32;
let mut i: i32;
let mut r: Wrapping<u32>;
let mut t1: Wrapping<u32>;
let mut s1: Wrapping<u32>;
let mut ix1: Wrapping<u32>;
let mut q1: Wrapping<u32>;

ix0 = (x.to_bits() >> 32) as i32;
ix1 = Wrapping(x.to_bits() as u32);

/* take care of Inf and NaN */
if (ix0 & 0x7ff00000) == 0x7ff00000 {
return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
}
if ix0 < 0 {
return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
/* take care of zero */
if ix0 <= 0 {
if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
return x; /* sqrt(+-0) = +-0 */
}
if ix0 < 0 {
return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
}
}
}
/* normalize x */
m = ix0 >> 20;
if m == 0 {
/* subnormal x */
while ix0 == 0 {
m -= 21;
ix0 |= (ix1 >> 11).0 as i32;
ix1 <<= 21;
/* normalize x */
m = ix0 >> 20;
if m == 0 {
/* subnormal x */
while ix0 == 0 {
m -= 21;
ix0 |= (ix1 >> 11).0 as i32;
ix1 <<= 21;
}
i = 0;
while (ix0 & 0x00100000) == 0 {
i += 1;
ix0 <<= 1;
}
m -= i - 1;
ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
ix1 = ix1 << i as usize;
}
i = 0;
while (ix0 & 0x00100000) == 0 {
i += 1;
ix0 <<= 1;
m -= 1023; /* unbias exponent */
ix0 = (ix0 & 0x000fffff) | 0x00100000;
if (m & 1) == 1 {
/* odd m, double x to make it even */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
}
m -= i - 1;
ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
ix1 = ix1 << i as usize;
}
m -= 1023; /* unbias exponent */
ix0 = (ix0 & 0x000fffff) | 0x00100000;
if (m & 1) == 1 {
/* odd m, double x to make it even */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
}
m >>= 1; /* m = [m/2] */

/* generate sqrt(x) bit by bit */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
q = 0; /* [q,q1] = sqrt(x) */
q1 = Wrapping(0);
s0 = 0;
s1 = Wrapping(0);
r = Wrapping(0x00200000); /* r = moving bit from right to left */
m >>= 1; /* m = [m/2] */

while r != Wrapping(0) {
t = s0 + r.0 as i32;
if t <= ix0 {
s0 = t + r.0 as i32;
ix0 -= t;
q += r.0 as i32;
}
/* generate sqrt(x) bit by bit */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}
q = 0; /* [q,q1] = sqrt(x) */
q1 = Wrapping(0);
s0 = 0;
s1 = Wrapping(0);
r = Wrapping(0x00200000); /* r = moving bit from right to left */

r = sign;
while r != Wrapping(0) {
t1 = s1 + r;
t = s0;
if t < ix0 || (t == ix0 && t1 <= ix1) {
s1 = t1 + r;
if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
s0 += 1;
while r != Wrapping(0) {
t = s0 + r.0 as i32;
if t <= ix0 {
s0 = t + r.0 as i32;
ix0 -= t;
q += r.0 as i32;
}
ix0 -= t;
if ix1 < t1 {
ix0 -= 1;
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}

r = sign;
while r != Wrapping(0) {
t1 = s1 + r;
t = s0;
if t < ix0 || (t == ix0 && t1 <= ix1) {
s1 = t1 + r;
if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
s0 += 1;
}
ix0 -= t;
if ix1 < t1 {
ix0 -= 1;
}
ix1 -= t1;
q1 += r;
}
ix1 -= t1;
q1 += r;
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}

/* use floating add to find out rounding direction */
if (ix0 as u32 | ix1.0) != 0 {
z = 1.0 - TINY; /* raise inexact flag */
if z >= 1.0 {
z = 1.0 + TINY;
if q1.0 == 0xffffffff {
q1 = Wrapping(0);
q += 1;
} else if z > 1.0 {
if q1.0 == 0xfffffffe {
/* use floating add to find out rounding direction */
if (ix0 as u32 | ix1.0) != 0 {
z = 1.0 - TINY; /* raise inexact flag */
if z >= 1.0 {
z = 1.0 + TINY;
if q1.0 == 0xffffffff {
q1 = Wrapping(0);
q += 1;
} else if z > 1.0 {
if q1.0 == 0xfffffffe {
q += 1;
}
q1 += Wrapping(2);
} else {
q1 += q1 & Wrapping(1);
}
q1 += Wrapping(2);
} else {
q1 += q1 & Wrapping(1);
}
}
ix0 = (q >> 1) + 0x3fe00000;
ix1 = q1 >> 1;
if (q & 1) == 1 {
ix1 |= sign;
}
ix0 += m << 20;
f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
}
ix0 = (q >> 1) + 0x3fe00000;
ix1 = q1 >> 1;
if (q & 1) == 1 {
ix1 |= sign;
}
ix0 += m << 20;
f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
}
Loading