diff --git a/jxl/src/headers/bit_depth.rs b/jxl/src/headers/bit_depth.rs index 898e721e..e83217da 100644 --- a/jxl/src/headers/bit_depth.rs +++ b/jxl/src/headers/bit_depth.rs @@ -65,6 +65,14 @@ impl BitDepth { exponent_bits_per_sample: 8, } } + #[cfg(test)] + pub fn f16() -> BitDepth { + BitDepth { + floating_point_sample: true, + bits_per_sample: 16, + exponent_bits_per_sample: 5, + } + } pub fn bits_per_sample(&self) -> u32 { self.bits_per_sample } diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 582fee4b..73f91ecf 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -8,7 +8,7 @@ use crate::{ headers::bit_depth::BitDepth, render::{Channels, ChannelsMut, RenderPipelineInOutStage}, }; -use jxl_simd::{F32SimdVec, simd_function}; +use jxl_simd::{F32SimdVec, I32SimdVec, SimdMask, shl, simd_function}; pub struct ConvertU8F32Stage { channel: usize, @@ -135,20 +135,184 @@ impl std::fmt::Display for ConvertModularToF32Stage { } } +// SIMD 32-bit float passthrough (bitcast i32 to f32) +simd_function!( + int_to_float_32bit_simd_dispatch, + d: D, + fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { + let simd_width = D::I32Vec::LEN; + let num_full_chunks = xsize / simd_width; + + // Process full SIMD chunks + for (in_chunk, out_chunk) in input + .chunks_exact(simd_width) + .zip(output.chunks_exact_mut(simd_width)) + .take(num_full_chunks) + { + let val = D::I32Vec::load(d, in_chunk); + val.bitcast_to_f32().store(out_chunk); + } + + // Handle remainder with scalar + let remainder_start = num_full_chunks * simd_width; + for i in remainder_start..xsize { + output[i] = f32::from_bits(input[i] as u32); + } + } +); + +// SIMD 16-bit float (half-precision) to 32-bit float conversion +// This handles IEEE 754 binary16 format: 1 sign bit, 5 exponent bits, 10 mantissa bits +simd_function!( + int_to_float_16bit_simd_dispatch, + d: D, + fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) { + let simd_width = D::I32Vec::LEN; + let num_full_chunks = xsize / simd_width; + + // Constants for 16-bit float (exp_bits=5, mant_bits=10) + let abs_mask = D::I32Vec::splat(d, 0x7FFF); // Mask for absolute value + let exp_mask = D::I32Vec::splat(d, 0x7C00); // Exponent bits in f16 + let mant_mask = D::I32Vec::splat(d, 0x03FF); // Mantissa bits in f16 + let exp_max = D::I32Vec::splat(d, 0x7C00); // Max exponent (inf/nan) + let exp_bias_adjust = D::I32Vec::splat(d, (127 - 15) << 23); // Bias adjustment shifted + let f32_inf_exp = D::I32Vec::splat(d, 0x7F80_0000_u32 as i32); + + for (in_chunk, out_chunk) in input + .chunks_exact(simd_width) + .zip(output.chunks_exact_mut(simd_width)) + .take(num_full_chunks) + { + let val = D::I32Vec::load(d, in_chunk); + + // Extract components + let abs_val = val & abs_mask; // Absolute value (exp + mantissa) + let exp_bits = val & exp_mask; // Exponent bits + let mant_bits = val & mant_mask; // Mantissa bits + + // Check for zero + let is_zero = abs_val.eq_zero(); + + // Check for inf/nan (exponent all 1s) + let is_inf_nan = exp_bits.eq(exp_max); + + // Check for subnormal (exponent is 0 but mantissa non-zero) + // Use andnot: !mant_is_zero & exp_is_zero + let exp_is_zero = exp_bits.eq_zero(); + let mant_is_zero = mant_bits.eq_zero(); + // is_subnormal = exp_is_zero AND NOT mant_is_zero + let is_subnormal = mant_is_zero.andnot(exp_is_zero); + + // Normal case: shift exponent and mantissa, adjust bias + // Sign bit at position 15 goes to position 31: shift left by 16 + // f16 exponent at bits 10-14 goes to f32 exponent at bits 23-30 + // f16 mantissa at bits 0-9 goes to f32 mantissa at bits 13-22 (shift left by 13) + let sign_shifted = shl!(val, 16) & D::I32Vec::splat(d, 0x8000_0000_u32 as i32); + let normal_exp = shl!(exp_bits, 13); + let normal_mant = shl!(mant_bits, 13); + let normal_result = sign_shifted | (normal_exp + exp_bias_adjust) | normal_mant; + + // Inf/NaN case: preserve mantissa pattern, set f32 inf exponent + let inf_nan_result = sign_shifted | f32_inf_exp | normal_mant; + + // Zero case: just the sign bit + let zero_result = sign_shifted; + + // Select result based on conditions + // Start with normal result, then override special cases + let result = is_inf_nan.if_then_else_i32(inf_nan_result, normal_result); + let result = is_zero.if_then_else_i32(zero_result, result); + + // For subnormals, fall back to scalar (rare case) + // maskz_i32 returns 0 where mask is true, so if any subnormal exists, + // there will be a 0 in subnormal_check, meaning eq_zero().all() would be true + // only if ALL elements are subnormal. We want to check if ANY are subnormal. + // So we check the inverse: if NOT eq_zero for all (meaning no subnormals), use SIMD. + let subnormal_check = is_subnormal.maskz_i32(D::I32Vec::splat(d, 1)); + // subnormal_check is 0 where is_subnormal=true, 1 where is_subnormal=false + // If all elements are 1 (no subnormals), eq(splat(1)).all() is true + let no_subnormals = subnormal_check.eq(D::I32Vec::splat(d, 1)); + if no_subnormals.all() { + // No subnormals - use SIMD result + result.bitcast_to_f32().store(out_chunk); + } else { + // At least one subnormal - process this chunk scalar + for (&in_val, out_val) in in_chunk.iter().zip(out_chunk.iter_mut()) { + *out_val = int_to_float_16bit_scalar(in_val); + } + } + } + + // Handle remainder with scalar + let remainder_start = num_full_chunks * simd_width; + for i in remainder_start..xsize { + output[i] = int_to_float_16bit_scalar(input[i]); + } + } +); + +// Scalar fallback for 16-bit float conversion (handles subnormals) +#[inline] +fn int_to_float_16bit_scalar(in_val: i32) -> f32 { + let mut f = in_val as u32; + let signbit = (f >> 15) != 0; + f &= 0x7FFF; + if f == 0 { + return if signbit { -0.0 } else { 0.0 }; + } + let mut exp = (f >> 10) as i32; + let mut mantissa = f & 0x3FF; + if exp == 31 { + // NaN or infinity + f = if signbit { 0x80000000 } else { 0 }; + f |= 0xFF << 23; + f |= mantissa << 13; + return f32::from_bits(f); + } + mantissa <<= 13; + if exp == 0 { + // subnormal number - normalize + while (mantissa & 0x800000) == 0 { + mantissa <<= 1; + exp -= 1; + } + exp += 1; + mantissa &= 0x7fffff; + } + exp = exp - 15 + 127; + f = if signbit { 0x80000000 } else { 0 }; + f |= (exp as u32) << 23; + f |= mantissa; + f32::from_bits(f) +} + // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as // int back to binary32 float. -// TODO(sboukortt): SIMD fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) { assert_eq!(input.len(), output.len()); let bits = bit_depth.bits_per_sample(); let exp_bits = bit_depth.exponent_bits_per_sample(); - if bits == 32 { - assert_eq!(exp_bits, 8); - for (&in_val, out_val) in input.iter().zip(output) { - *out_val = f32::from_bits(in_val as u32); - } + let xsize = input.len(); + + // Use SIMD fast paths for common formats + if bits == 32 && exp_bits == 8 { + // 32-bit float passthrough + int_to_float_32bit_simd_dispatch(input, output, xsize); return; } + + if bits == 16 && exp_bits == 5 { + // IEEE 754 half-precision (f16) - common HDR format + int_to_float_16bit_simd_dispatch(input, output, xsize); + return; + } + + // Generic scalar path for other custom float formats + int_to_float_generic(input, output, bits, exp_bits); +} + +// Generic scalar conversion for arbitrary bit-depth floats +fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) { let exp_bias = (1 << (exp_bits - 1)) - 1; let sign_shift = bits - 1; let mant_bits = bits - exp_bits - 1; @@ -419,6 +583,7 @@ impl RenderPipelineInOutStage for ConvertF32ToF16Stage { mod test { use super::*; use crate::error::Result; + use crate::headers::bit_depth::BitDepth; use test_log::test; #[test] @@ -448,4 +613,161 @@ mod test { fn f32_to_f16_consistency() -> Result<()> { crate::render::test::test_stage_consistency(|| ConvertF32ToF16Stage::new(0), (500, 500), 1) } + + #[test] + fn test_int_to_float_32bit() { + // Test 32-bit float passthrough + let bit_depth = BitDepth::f32(); + let test_values: Vec = vec![ + 0.0, + 1.0, + -1.0, + 0.5, + -0.5, + f32::INFINITY, + f32::NEG_INFINITY, + 1e-30, + 1e30, + ]; + let input: Vec = test_values.iter().map(|&f| f.to_bits() as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() { + if expected.is_nan() { + assert!(actual.is_nan(), "index {}: expected NaN, got {}", i, actual); + } else { + assert_eq!(expected, actual, "index {}: mismatch", i); + } + } + } + + #[test] + fn test_int_to_float_16bit_normal() { + // Test 16-bit float (f16) conversion for normal values + let bit_depth = BitDepth::f16(); + + // f16 format: 1 sign, 5 exp, 10 mantissa + // Test cases: (f16_bits, expected_f32) + let test_cases: Vec<(u16, f32)> = vec![ + (0x0000, 0.0), // +0 + (0x8000, -0.0), // -0 + (0x3C00, 1.0), // 1.0 + (0xBC00, -1.0), // -1.0 + (0x3800, 0.5), // 0.5 + (0x4000, 2.0), // 2.0 + (0x4400, 4.0), // 4.0 + (0x7BFF, 65504.0), // max normal f16 + ]; + + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + assert!( + (expected - actual).abs() < 1e-6 + || (expected.is_sign_negative() == actual.is_sign_negative() + && *expected == 0.0 + && actual == 0.0), + "index {}: expected {}, got {}", + i, + expected, + actual + ); + } + } + + #[test] + fn test_int_to_float_16bit_special() { + // Test 16-bit float conversion for special values (inf, nan) + let bit_depth = BitDepth::f16(); + + let test_cases: Vec<(u16, f32)> = vec![ + (0x7C00, f32::INFINITY), // +inf + (0xFC00, f32::NEG_INFINITY), // -inf + ]; + + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut output = vec![0.0f32; input.len()]; + + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + assert_eq!( + *expected, actual, + "index {}: expected {}, got {}", + i, expected, actual + ); + } + } + + #[test] + fn test_int_to_float_16bit_subnormal() { + // Test 16-bit float conversion for subnormal values + let bit_depth = BitDepth::f16(); + + // Verify bit_depth is set correctly + assert_eq!(bit_depth.bits_per_sample(), 16); + assert_eq!(bit_depth.exponent_bits_per_sample(), 5); + assert!(bit_depth.floating_point_sample()); + + // Smallest subnormal: 2^-24 ≈ 5.96e-8 + // Largest subnormal: (2^10 - 1) * 2^-24 ≈ 6.10e-5 + let test_cases: Vec<(u16, f32)> = vec![ + (0x0001, 5.960_464_5e-8), // smallest positive subnormal + (0x03FF, 6.097_555e-5), // largest positive subnormal + (0x8001, -5.960_464_5e-8), // smallest negative subnormal + ]; + + // First test the scalar function directly + for (bits, expected) in &test_cases { + let scalar_result = int_to_float_16bit_scalar(*bits as i32); + let rel_err = ((expected - scalar_result) / expected).abs(); + assert!( + rel_err < 1e-6, + "scalar: bits=0x{:04X}, expected {}, got {}, rel_err {}", + bits, + expected, + scalar_result, + rel_err + ); + } + + // Test through int_to_float_generic (which should match scalar) + let input: Vec = test_cases.iter().map(|(bits, _)| *bits as i32).collect(); + let mut generic_output = vec![0.0f32; input.len()]; + int_to_float_generic(&input, &mut generic_output, 16, 5); + for (i, ((_, expected), &actual)) in + test_cases.iter().zip(generic_output.iter()).enumerate() + { + let rel_err = ((expected - actual) / expected).abs(); + assert!( + rel_err < 1e-6, + "generic index {}: expected {}, got {}, rel_err {}", + i, + expected, + actual, + rel_err + ); + } + + // Now test through the main function (uses SIMD dispatch) + let mut output = vec![0.0f32; input.len()]; + int_to_float(&input, &mut output, &bit_depth); + + for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() { + let rel_err = ((expected - actual) / expected).abs(); + assert!( + rel_err < 1e-6, + "simd index {}: expected {}, got {}, rel_err {}", + i, + expected, + actual, + rel_err + ); + } + } }