
Commit 32a6a99

Merge pull request #216 from Lokathor/sse-sqrt
Improve sqrt/sqrtf if stable intrinsics allow
2 parents: 2cc2589 + b0f666e
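
The choice between the two paths in this change is made at compile time: cfg(target_feature = "sse2") only holds when the target already enables SSE2 (the default for x86_64 targets, or opted into with -C target-feature=+sse2), so no runtime CPU detection is involved. A minimal sketch of that compile-time check, separate from the commit's own code:

// Illustrative sketch, not part of this commit: the dispatch is a
// compile-time cfg check, not a runtime CPUID probe.
fn main() {
    if cfg!(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "sse2"
    )) {
        println!("this build would take the SSE2 sqrt path");
    } else {
        println!("this build would use the bit-by-bit software sqrt");
    }
}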

File tree

2 files changed: +212, -175 lines changed


src/math/sqrt.rs

+127, -108 lines changed
@@ -77,9 +77,6 @@
 */
 
 use core::f64;
-use core::num::Wrapping;
-
-const TINY: f64 = 1.0e-300;
 
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn sqrt(x: f64) -> f64 {
@@ -95,128 +92,150 @@ pub fn sqrt(x: f64) -> f64 {
             }
         }
     }
-    let mut z: f64;
-    let sign: Wrapping<u32> = Wrapping(0x80000000);
-    let mut ix0: i32;
-    let mut s0: i32;
-    let mut q: i32;
-    let mut m: i32;
-    let mut t: i32;
-    let mut i: i32;
-    let mut r: Wrapping<u32>;
-    let mut t1: Wrapping<u32>;
-    let mut s1: Wrapping<u32>;
-    let mut ix1: Wrapping<u32>;
-    let mut q1: Wrapping<u32>;
+    #[cfg(target_feature = "sse2")]
+    {
+        // Note: This path is unlikely since LLVM will usually have already
+        // optimized sqrt calls into hardware instructions if sse2 is available,
+        // but if someone does end up here they'll apprected the speed increase.
+        #[cfg(target_arch = "x86")]
+        use core::arch::x86::*;
+        #[cfg(target_arch = "x86_64")]
+        use core::arch::x86_64::*;
+        unsafe {
+            let m = _mm_set_sd(x);
+            let m_sqrt = _mm_sqrt_pd(m);
+            _mm_cvtsd_f64(m_sqrt)
+        }
+    }
+    #[cfg(not(target_feature = "sse2"))]
+    {
+        use core::num::Wrapping;
 
-    ix0 = (x.to_bits() >> 32) as i32;
-    ix1 = Wrapping(x.to_bits() as u32);
+        const TINY: f64 = 1.0e-300;
 
-    /* take care of Inf and NaN */
-    if (ix0 & 0x7ff00000) == 0x7ff00000 {
-        return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
-    }
-    /* take care of zero */
-    if ix0 <= 0 {
-        if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
-            return x; /* sqrt(+-0) = +-0 */
+        let mut z: f64;
+        let sign: Wrapping<u32> = Wrapping(0x80000000);
+        let mut ix0: i32;
+        let mut s0: i32;
+        let mut q: i32;
+        let mut m: i32;
+        let mut t: i32;
+        let mut i: i32;
+        let mut r: Wrapping<u32>;
+        let mut t1: Wrapping<u32>;
+        let mut s1: Wrapping<u32>;
+        let mut ix1: Wrapping<u32>;
+        let mut q1: Wrapping<u32>;
+
+        ix0 = (x.to_bits() >> 32) as i32;
+        ix1 = Wrapping(x.to_bits() as u32);
+
+        /* take care of Inf and NaN */
+        if (ix0 & 0x7ff00000) == 0x7ff00000 {
+            return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
        }
-        if ix0 < 0 {
-            return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
+        /* take care of zero */
+        if ix0 <= 0 {
+            if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
+                return x; /* sqrt(+-0) = +-0 */
+            }
+            if ix0 < 0 {
+                return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
+            }
        }
-    }
-    /* normalize x */
-    m = ix0 >> 20;
-    if m == 0 {
-        /* subnormal x */
-        while ix0 == 0 {
-            m -= 21;
-            ix0 |= (ix1 >> 11).0 as i32;
-            ix1 <<= 21;
+        /* normalize x */
+        m = ix0 >> 20;
+        if m == 0 {
+            /* subnormal x */
+            while ix0 == 0 {
+                m -= 21;
+                ix0 |= (ix1 >> 11).0 as i32;
+                ix1 <<= 21;
+            }
+            i = 0;
+            while (ix0 & 0x00100000) == 0 {
+                i += 1;
+                ix0 <<= 1;
+            }
+            m -= i - 1;
+            ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
+            ix1 = ix1 << i as usize;
        }
-        i = 0;
-        while (ix0 & 0x00100000) == 0 {
-            i += 1;
-            ix0 <<= 1;
+        m -= 1023; /* unbias exponent */
+        ix0 = (ix0 & 0x000fffff) | 0x00100000;
+        if (m & 1) == 1 {
+            /* odd m, double x to make it even */
+            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
+            ix1 += ix1;
        }
-        m -= i - 1;
-        ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
-        ix1 = ix1 << i as usize;
-    }
-    m -= 1023; /* unbias exponent */
-    ix0 = (ix0 & 0x000fffff) | 0x00100000;
-    if (m & 1) == 1 {
-        /* odd m, double x to make it even */
-        ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
-        ix1 += ix1;
-    }
-    m >>= 1; /* m = [m/2] */
-
-    /* generate sqrt(x) bit by bit */
-    ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
-    ix1 += ix1;
-    q = 0; /* [q,q1] = sqrt(x) */
-    q1 = Wrapping(0);
-    s0 = 0;
-    s1 = Wrapping(0);
-    r = Wrapping(0x00200000); /* r = moving bit from right to left */
+        m >>= 1; /* m = [m/2] */
 
-    while r != Wrapping(0) {
-        t = s0 + r.0 as i32;
-        if t <= ix0 {
-            s0 = t + r.0 as i32;
-            ix0 -= t;
-            q += r.0 as i32;
-        }
+        /* generate sqrt(x) bit by bit */
         ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
         ix1 += ix1;
-        r >>= 1;
-    }
+        q = 0; /* [q,q1] = sqrt(x) */
+        q1 = Wrapping(0);
+        s0 = 0;
+        s1 = Wrapping(0);
+        r = Wrapping(0x00200000); /* r = moving bit from right to left */
 
-    r = sign;
-    while r != Wrapping(0) {
-        t1 = s1 + r;
-        t = s0;
-        if t < ix0 || (t == ix0 && t1 <= ix1) {
-            s1 = t1 + r;
-            if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
-                s0 += 1;
+        while r != Wrapping(0) {
+            t = s0 + r.0 as i32;
+            if t <= ix0 {
+                s0 = t + r.0 as i32;
+                ix0 -= t;
+                q += r.0 as i32;
            }
-            ix0 -= t;
-            if ix1 < t1 {
-                ix0 -= 1;
+            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
+            ix1 += ix1;
+            r >>= 1;
+        }
+
+        r = sign;
+        while r != Wrapping(0) {
+            t1 = s1 + r;
+            t = s0;
+            if t < ix0 || (t == ix0 && t1 <= ix1) {
+                s1 = t1 + r;
+                if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
+                    s0 += 1;
+                }
+                ix0 -= t;
+                if ix1 < t1 {
+                    ix0 -= 1;
+                }
+                ix1 -= t1;
+                q1 += r;
            }
-            ix1 -= t1;
-            q1 += r;
+            ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
+            ix1 += ix1;
+            r >>= 1;
        }
-        ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
-        ix1 += ix1;
-        r >>= 1;
-    }
 
-    /* use floating add to find out rounding direction */
-    if (ix0 as u32 | ix1.0) != 0 {
-        z = 1.0 - TINY; /* raise inexact flag */
-        if z >= 1.0 {
-            z = 1.0 + TINY;
-            if q1.0 == 0xffffffff {
-                q1 = Wrapping(0);
-                q += 1;
-            } else if z > 1.0 {
-                if q1.0 == 0xfffffffe {
+        /* use floating add to find out rounding direction */
+        if (ix0 as u32 | ix1.0) != 0 {
+            z = 1.0 - TINY; /* raise inexact flag */
+            if z >= 1.0 {
+                z = 1.0 + TINY;
+                if q1.0 == 0xffffffff {
+                    q1 = Wrapping(0);
                     q += 1;
+                } else if z > 1.0 {
+                    if q1.0 == 0xfffffffe {
+                        q += 1;
+                    }
+                    q1 += Wrapping(2);
+                } else {
+                    q1 += q1 & Wrapping(1);
                }
-                q1 += Wrapping(2);
-            } else {
-                q1 += q1 & Wrapping(1);
            }
        }
+        ix0 = (q >> 1) + 0x3fe00000;
+        ix1 = q1 >> 1;
+        if (q & 1) == 1 {
+            ix1 |= sign;
+        }
+        ix0 += m << 20;
+        f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
    }
-    ix0 = (q >> 1) + 0x3fe00000;
-    ix1 = q1 >> 1;
-    if (q & 1) == 1 {
-        ix1 |= sign;
-    }
-    ix0 += m << 20;
-    f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
 }
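
The intrinsic sequence added above can be exercised on its own. Below is a standalone sketch, not part of the commit, that runs the same _mm_set_sd / _mm_sqrt_pd / _mm_cvtsd_f64 sequence and compares it bit-for-bit against the standard library's sqrt; the helper name sse2_sqrt is made up for the example.

// Standalone sketch (not from this repository): same intrinsic sequence as the
// new fast path, checked against std's sqrt. Needs an x86_64 target with SSE2
// enabled, which is the default for x86_64.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
fn sse2_sqrt(x: f64) -> f64 {
    use core::arch::x86_64::*;
    unsafe {
        // Put x in the low lane (the high lane is zeroed), take the square
        // root of both lanes, then read the low lane back out.
        _mm_cvtsd_f64(_mm_sqrt_pd(_mm_set_sd(x)))
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
fn main() {
    for &x in &[0.0, 0.5, 1.0, 2.0, 1e-300, 1e300] {
        // Both the hardware instruction and std's sqrt are correctly rounded,
        // so the bit patterns should agree.
        assert_eq!(sse2_sqrt(x).to_bits(), x.sqrt().to_bits());
    }
    println!("SSE2 sqrt matches std sqrt on the sampled inputs");
}

#[cfg(not(all(target_arch = "x86_64", target_feature = "sse2")))]
fn main() {}

The fallback branch in the diff keeps the original bit-by-bit algorithm unchanged apart from the extra nesting, so targets without SSE2 are unaffected.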

0 commit comments
