diff --git a/Project.toml b/Project.toml index de04aeb..15e4b3e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BFloat16s" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -authors = ["Keno Fischer "] +authors = ["Keno Fischer ", "Jeffrey Sarnoff "] version = "0.5.1" [deps] diff --git a/README.md b/README.md index 1177cce..7cf0f50 100644 --- a/README.md +++ b/README.md @@ -36,34 +36,27 @@ a method for `BFloat16` although it should. In this case, please raise an issue close the gap in support compared to other low-precision types like `Float16`. The usage of `BFloat16` should be as smooth as the following example, solving a linear equation system -```julia -julia> A = randn(BFloat16,3,3) -3×3 Matrix{BFloat16}: - 1.46875 -1.20312 -1.0 - 0.257812 -0.671875 -0.929688 - -0.410156 -1.75 -0.0162354 - -julia> b = randn(BFloat16,3) -3-element Vector{BFloat16}: - -0.26367188 - -0.14160156 - 0.77734375 - -julia> A\b -3-element Vector{BFloat16}: - -0.24902344 - -0.38671875 - 0.36328125 +This package exports the BFloat16 data type. This datatype should behave +just like any builtin floating point type (e.g. you can construct it from +other floating point types - e.g. `BFloat16(1.0)`). 
Many predicates, +conversion, structural and mathematical functions are supported: +``` + Int16, Int32, Int64, Float16, Float32, Float64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, + isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, + sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, + signbit, exponent, significand, frexp, ldexp, exponent_one, exponent_half, + exp, exp2, exp10, expm1, log, log2, log10, log1p, + sin, cos, tan, csc, sec, cot, asin, acos, atan, acsc, asec, acot, + sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, + round, trunc, floor, ceil, abs, abs2, sqrt, cbrt, clamp, hypot, bitstring ``` -## `LowPrecArray` for mixed-precision Float32/BFloat16 matrix multiplications +In addition, this package provides the `LowPrecArray` type. This array is +supposed to emulate the kind of matmul operation that TPUs do well +(BFloat16 multiply with Float32 accumulate). Broadcasts and scalar operations +are performed in Float32 (as they would be on a TPU) while matrix multiplies +are performed in BFloat16 with Float32 accumulates, e.g. -In addition, this package provides the `LowPrecArray` type. -This array is supposed to emulate the kind -of matrix multiplications that TPUs do well (BFloat16 multiply with Float32 -accumulate). Broadcasts and scalar operations are peformed in Float32 (as -they would be on a TPU) while matrix multiplies are performed in BFloat16 with -Float32 accumulates, e.g. 
```julia julia> A = LowPrecArray(rand(Float32, 5, 5)) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 0b3e8e0..1b234a2 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -13,6 +13,7 @@ import Base: isfinite, isnan, precision, iszero, eps, asin, acos, atan, acsc, asec, acot, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, + clamp, hypot, bitstring, isinteger import Printf @@ -222,6 +223,13 @@ else end end +# accept Irrational +BFloat16s.BFloat16(x::Irrational) = BFloat16(Float32(x)) + +# Truncation to integer types +Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) +Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) + # BigFloat conversion BFloat16(x::BigFloat) = BFloat16(Float32(x)) Base.BigFloat(x::BFloat16) = BigFloat(Float32(x)) @@ -424,3 +432,10 @@ for F in (:abs, :abs2, :sqrt, :cbrt, Base.$F(x::BFloat16) = BFloat16($F(Float32(x))) end end + +Base.atan(y::BFloat16, x::BFloat16) = BFloat16(atan(Float32(y), Float32(x))) +Base.hypot(x::BFloat16, y::BFloat16) = BFloat16(hypot(Float32(x), Float32(y))) +Base.hypot(x::BFloat16, y::BFloat16, z::BFloat16) = BFloat16(hypot(Float32(x), Float32(y), Float32(z))) +Base.clamp(x::BFloat16, lo::BFloat16, hi::BFloat16) = BFloat16(clamp(Float32(x), Float32(lo), Float32(hi))) + +Base.bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x))