Skip to content

Commit

Permalink
Add ldexp to float modules. (diku-dk#2047)
Browse files Browse the repository at this point in the history
  • Loading branch information
athas authored and 0scar committed Nov 27, 2023
1 parent 9a2a87d commit f6bfbf8
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
that have been fused with `map`s that internally produce arrays.
Work by Anders Holst and Christian Påbøl Jacobsen.

* `f16.ldexp`, `f32.ldexp`, `f64.ldexp`, corresponding to the
functions in the C math library.

### Removed

### Changed
Expand Down
12 changes: 12 additions & 0 deletions prelude/math.fut
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,21 @@ module type integral = {
-- one of the operands is negative. May be more efficient.
val %%: t -> t -> t

-- | Bitwise and.
val &: t -> t -> t
-- | Bitwise or.
val |: t -> t -> t
-- | Bitwise xor.
val ^: t -> t -> t

-- | Bitwise negation.
val not: t -> t

-- | Left shift; inserting zeroes.
val <<: t -> t -> t
-- | Arithmetic right shift, using sign extension for the leftmost bits.
val >>: t -> t -> t
-- | Logical right shift, inserting zeroes for the leftmost bits.
val >>>: t -> t -> t

val num_bits: i32
Expand Down Expand Up @@ -248,6 +254,9 @@ module type float = {
-- | Produces the next representable number from `x` in the
-- direction of `y`.
val nextafter : (x: t) -> (y: t) -> t

-- | Multiplies floating-point value by 2 raised to an integer power.
val ldexp : t -> i32 -> t
}

-- | Boolean numbers. When converting from a number to `bool`, 0 is
Expand Down Expand Up @@ -961,6 +970,7 @@ module f64: (float with t = f64 with int_t = u64) = {
def round = intrinsics.round64

def nextafter x y = intrinsics.nextafter64 (x,y)
def ldexp x y = intrinsics.ldexp64 (x,y)

def to_bits (x: f64): u64 = u64m.i64 (intrinsics.to_bits64 x)
def from_bits (x: u64): f64 = intrinsics.from_bits64 (intrinsics.sign_i64 x)
Expand Down Expand Up @@ -1076,6 +1086,7 @@ module f32: (float with t = f32 with int_t = u32) = {
def round = intrinsics.round32

def nextafter x y = intrinsics.nextafter32 (x,y)
def ldexp x y = intrinsics.ldexp32 (x,y)

def to_bits (x: f32): u32 = u32m.i32 (intrinsics.to_bits32 x)
def from_bits (x: u32): f32 = intrinsics.from_bits32 (intrinsics.sign_i32 x)
Expand Down Expand Up @@ -1195,6 +1206,7 @@ module f16: (float with t = f16 with int_t = u16) = {
def round = intrinsics.round16

def nextafter x y = intrinsics.nextafter16 (x,y)
def ldexp x y = intrinsics.ldexp16 (x,y)

def to_bits (x: f16): u16 = u16m.i16 (intrinsics.to_bits16 x)
def from_bits (x: u16): f16 = intrinsics.from_bits16 (intrinsics.sign_i16 x)
Expand Down
20 changes: 20 additions & 0 deletions rts/c/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,10 @@ SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) {
return mix(v0, v1, t);
}

SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) {
return ldexp(x, y);
}

SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) {
return mad(a, b, c);
}
Expand Down Expand Up @@ -2095,6 +2099,10 @@ SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) {
return v0 + (v1 - v0) * t;
}

SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) {
return x * pow((double)2.0, (double)y);
}

SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) {
return a * b + c;
}
Expand Down Expand Up @@ -2229,6 +2237,10 @@ SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) {
return v0 + (v1 - v0) * t;
}

SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) {
return ldexpf(x, y);
}

SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) {
return a * b + c;
}
Expand Down Expand Up @@ -2640,6 +2652,10 @@ SCALAR_FUN_ATTR double futrts_lerp64(double v0, double v1, double t) {
return v0 + (v1 - v0) * t;
}

SCALAR_FUN_ATTR float futrts_ldexp64(double x, int32_t y) {
return x * pow((double)2.0, (double)y);
}

SCALAR_FUN_ATTR double futrts_mad64(double a, double b, double c) {
return a * b + c;
}
Expand Down Expand Up @@ -2970,6 +2986,10 @@ SCALAR_FUN_ATTR double futrts_lerp64(double v0, double v1, double t) {
#endif
}

SCALAR_FUN_ATTR double futrts_ldexp64(double x, int32_t y) {
return ldexp(x, y);
}

SCALAR_FUN_ATTR double futrts_mad64(double a, double b, double c) {
#ifdef __OPENCL_VERSION__
return mad(a, b, c);
Expand Down
17 changes: 17 additions & 0 deletions rts/c/scalar_f16.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ SCALAR_FUN_ATTR f16 fmin16(f16 x, f16 y) {
SCALAR_FUN_ATTR f16 fpow16(f16 x, f16 y) {
return pow(x, y);
}

#else // Assuming CUDA.

SCALAR_FUN_ATTR f16 fabs16(f16 x) {
Expand Down Expand Up @@ -329,6 +330,10 @@ SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) {
return mix(v0, v1, t);
}

SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) {
return ldexp(x, y);
}

SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) {
return mad(a, b, c);
}
Expand Down Expand Up @@ -486,6 +491,10 @@ SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) {
return v0 + (v1 - v0) * t;
}

SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) {
return futrts_ldexp32((float)x, y);
}

SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) {
return a * b + c;
}
Expand Down Expand Up @@ -620,6 +629,10 @@ SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) {
return v0 + (v1 - v0) * t;
}

SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) {
return futrts_ldexp32((float)x, y);
}

SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) {
return a * b + c;
}
Expand Down Expand Up @@ -822,6 +835,10 @@ SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) {
return futrts_lerp32(v0, v1, t);
}

SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) {
return futrts_ldexp32(x, y);
}

SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) {
return futrts_mad32(a, b, c);
}
Expand Down
15 changes: 15 additions & 0 deletions src/Futhark/AD/Derivatives.hs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,21 @@ pdBuiltin "lerp64" [v0, v1, t] =
untyped $ fMax64 0 (fMin64 1 (isF64 t)),
untyped $ isF64 v1 - isF64 v0
]
pdBuiltin "ldexp16" [x, y] =
Just
[ untyped $ 2 ** isF16 x,
untyped $ log 2 * (2 ** isF16 y) * isF16 x
]
pdBuiltin "ldexp32" [x, y] =
Just
[ untyped $ 2 ** isF32 x,
untyped $ log 2 * (2 ** isF32 y) * isF32 x
]
pdBuiltin "ldexp64" [x, y] =
Just
[ untyped $ 2 ** isF64 x,
untyped $ log 2 * (2 ** isF64 y) * isF64 x
]
pdBuiltin "erf16" [z] =
Just [untyped $ (2 / sqrt pi) * exp (negate (isF16 z * isF16 z))]
pdBuiltin "erf32" [z] =
Expand Down
16 changes: 16 additions & 0 deletions src/Futhark/Util/CMath.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ module Futhark.Util.CMath
cbrtf,
hypot,
hypotf,
ldexp,
ldexpf,
)
where

import Foreign.C.Types (CInt (..))

foreign import ccall "nearbyint" c_nearbyint :: Double -> Double

foreign import ccall "nearbyintf" c_nearbyintf :: Float -> Float
Expand Down Expand Up @@ -146,3 +150,15 @@ cbrt = c_cbrt
-- | The system-level @cbrtf@ function.
cbrtf :: Float -> Float
cbrtf = c_cbrtf

foreign import ccall "ldexp" c_ldexp :: Double -> CInt -> Double

foreign import ccall "ldexpf" c_ldexpf :: Float -> CInt -> Float

-- | The system-level @ldexp@ function.
ldexp :: Double -> CInt -> Double
ldexp = c_ldexp

-- | The system-level @ldexpf@ function.
ldexpf :: Float -> CInt -> Float
ldexpf = c_ldexpf
28 changes: 28 additions & 0 deletions src/Language/Futhark/Primitive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,34 @@ primFuns =
f32_2 "nextafter32" nextafterf,
f64_2 "nextafter64" nextafter,
--
( "ldexp16",
( [FloatType Float16, IntType Int32],
FloatType Float16,
\case
[FloatValue (Float16Value x), IntValue (Int32Value y)] ->
Just $ FloatValue $ Float16Value $ x * (2 ** fromIntegral y)
_ -> Nothing
)
),
( "ldexp32",
( [FloatType Float32, IntType Int32],
FloatType Float32,
\case
[FloatValue (Float32Value x), IntValue (Int32Value y)] ->
Just $ FloatValue $ Float32Value $ x * (2 ** fromIntegral y)
_ -> Nothing
)
),
( "ldexp64",
( [FloatType Float64, IntType Int32],
FloatType Float64,
\case
[FloatValue (Float64Value x), IntValue (Int32Value y)] ->
Just $ FloatValue $ Float64Value $ x * (2 ** fromIntegral y)
_ -> Nothing
)
),
--
f16 "gamma16" $ convFloat . tgammaf . convFloat,
f32 "gamma32" tgammaf,
f64 "gamma64" tgamma,
Expand Down
18 changes: 18 additions & 0 deletions tests/primitive/ldexp.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- ==
-- entry: test_f16
-- input { [7f16, 7f16, -0f16, f16.inf, 1f16] [-4, 4, 10, -1, 1000] }
-- output { [0.437500f16, 112f16, -0f16, f16.inf, f16.inf] }

-- ==
-- entry: test_f32
-- input { [7f32, 7f32, -0f32, f32.inf, 1f32] [-4, 4, 10, -1, 1000] }
-- output { [0.437500f32, 112f32, -0f32, f32.inf, f32.inf] }

-- ==
-- entry: test_f64
-- input { [7f64, 7f64, -0f64, f64.inf, 1f64] [-4, 4, 10, -1, 10000] }
-- output { [0.437500f64, 112f64, -0f64, f64.inf, f64.inf] }

entry test_f16 = map2 f16.ldexp
entry test_f32 = map2 f32.ldexp
entry test_f64 = map2 f64.ldexp

0 comments on commit f6bfbf8

Please sign in to comment.