8 changes: 8 additions & 0 deletions jxl/src/headers/bit_depth.rs
@@ -65,6 +65,14 @@ impl BitDepth {
            exponent_bits_per_sample: 8,
        }
    }
    #[cfg(test)]
    pub fn f16() -> BitDepth {
        BitDepth {
            floating_point_sample: true,
            bits_per_sample: 16,
            exponent_bits_per_sample: 5,
        }
    }
    pub fn bits_per_sample(&self) -> u32 {
        self.bits_per_sample
    }
336 changes: 329 additions & 7 deletions jxl/src/render/stages/convert.rs
@@ -8,7 +8,7 @@ use crate::{
    headers::bit_depth::BitDepth,
    render::{Channels, ChannelsMut, RenderPipelineInOutStage},
};
use jxl_simd::{F32SimdVec, I32SimdVec, SimdMask, shl, simd_function};

pub struct ConvertU8F32Stage {
    channel: usize,
@@ -135,20 +135,184 @@ impl std::fmt::Display for ConvertModularToF32Stage {
    }
}

// SIMD 32-bit float passthrough (bitcast i32 to f32)
simd_function!(
    int_to_float_32bit_simd_dispatch,
    d: D,
    fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
        let simd_width = D::I32Vec::LEN;
        let num_full_chunks = xsize / simd_width;

        // Process full SIMD chunks
        for (in_chunk, out_chunk) in input
            .chunks_exact(simd_width)
            .zip(output.chunks_exact_mut(simd_width))
            .take(num_full_chunks)
        {
            let val = D::I32Vec::load(d, in_chunk);
            val.bitcast_to_f32().store(out_chunk);
        }

        // Handle remainder with scalar
        let remainder_start = num_full_chunks * simd_width;
        for i in remainder_start..xsize {
            output[i] = f32::from_bits(input[i] as u32);
        }
    }
);

Review comment from @veluca93 (Member), Dec 22, 2025, on the function below:

    I think I would prefer to have a pair of functions I32Vec::store_u16() and F32Vec::load_f16_bits() instead. Those functions can use _mm256_cvtph_ps on AVX2 (by also requiring the F16C target feature, which is common) and vcvt_f32_f16 on NEON (although the Rust definition erroneously requires the f16 target feature, so we'd have to use inline assembly for now -- fixed in rust-lang/stdarch#1978), and fall back to scalar on SSE4.2.

    We could then add store_f16() -- implemented in a similar way -- and use that to speed up the f16 conversion code.
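
For context, a minimal sketch of what an F16C-based F32Vec::load_f16_bits() backend could look like on x86-64. The function name and shape follow the reviewer's suggestion; the signature and everything else here are assumptions, not the crate's actual API:

// Illustrative only: widen eight packed binary16 values to eight f32 lanes,
// roughly what an F16C-based load_f16_bits() backend could do.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,f16c")]
unsafe fn load_f16_bits_avx2(src: &[u16; 8]) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::{__m128i, _mm_loadu_si128, _mm256_cvtph_ps};
    // Load 8 x u16 (128 bits) with no alignment requirement...
    let halves = _mm_loadu_si128(src.as_ptr() as *const __m128i);
    // ...then convert each half-precision lane to a single-precision lane.
    _mm256_cvtph_ps(halves)
}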

// SIMD 16-bit float (half-precision) to 32-bit float conversion
// This handles IEEE 754 binary16 format: 1 sign bit, 5 exponent bits, 10 mantissa bits
simd_function!(
    int_to_float_16bit_simd_dispatch,
    d: D,
    fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
        let simd_width = D::I32Vec::LEN;
        let num_full_chunks = xsize / simd_width;

        // Constants for 16-bit float (exp_bits=5, mant_bits=10)
        let abs_mask = D::I32Vec::splat(d, 0x7FFF); // Mask for absolute value
        let exp_mask = D::I32Vec::splat(d, 0x7C00); // Exponent bits in f16
        let mant_mask = D::I32Vec::splat(d, 0x03FF); // Mantissa bits in f16
        let exp_max = D::I32Vec::splat(d, 0x7C00); // Max exponent (inf/nan)
        let exp_bias_adjust = D::I32Vec::splat(d, (127 - 15) << 23); // Bias delta, pre-shifted to the f32 exponent field
        let f32_inf_exp = D::I32Vec::splat(d, 0x7F80_0000_u32 as i32);

        for (in_chunk, out_chunk) in input
            .chunks_exact(simd_width)
            .zip(output.chunks_exact_mut(simd_width))
            .take(num_full_chunks)
        {
            let val = D::I32Vec::load(d, in_chunk);

            // Extract components
            let abs_val = val & abs_mask; // Absolute value (exp + mantissa)
            let exp_bits = val & exp_mask; // Exponent bits
            let mant_bits = val & mant_mask; // Mantissa bits

            // Check for zero
            let is_zero = abs_val.eq_zero();

            // Check for inf/nan (exponent all 1s)
            let is_inf_nan = exp_bits.eq(exp_max);

            // Check for subnormal (exponent is 0 but mantissa is non-zero):
            // is_subnormal = exp_is_zero AND NOT mant_is_zero
            let exp_is_zero = exp_bits.eq_zero();
            let mant_is_zero = mant_bits.eq_zero();
            let is_subnormal = mant_is_zero.andnot(exp_is_zero);

            // Normal case: shift exponent and mantissa into place, adjust the bias.
            // Sign bit at position 15 goes to position 31: shift left by 16.
            // f16 exponent at bits 10-14 goes to f32 exponent at bits 23-30.
            // f16 mantissa at bits 0-9 goes to f32 mantissa at bits 13-22 (shift left by 13).
            let sign_shifted = shl!(val, 16) & D::I32Vec::splat(d, 0x8000_0000_u32 as i32);
            let normal_exp = shl!(exp_bits, 13);
            let normal_mant = shl!(mant_bits, 13);
            let normal_result = sign_shifted | (normal_exp + exp_bias_adjust) | normal_mant;

            // Inf/NaN case: preserve the mantissa pattern, set the f32 inf exponent
            let inf_nan_result = sign_shifted | f32_inf_exp | normal_mant;

            // Zero case: just the sign bit
            let zero_result = sign_shifted;

            // Select the result: start from the normal case, then override the special cases
            let result = is_inf_nan.if_then_else_i32(inf_nan_result, normal_result);
            let result = is_zero.if_then_else_i32(zero_result, result);

            // Subnormals are rare, so any chunk containing one falls back to scalar.
            // maskz_i32 returns 0 where the mask is true, so subnormal_check holds 0
            // in subnormal lanes and 1 everywhere else; the chunk is subnormal-free
            // iff every lane still equals 1.
            let subnormal_check = is_subnormal.maskz_i32(D::I32Vec::splat(d, 1));
            let no_subnormals = subnormal_check.eq(D::I32Vec::splat(d, 1));
            if no_subnormals.all() {
                // No subnormals - use the SIMD result
                result.bitcast_to_f32().store(out_chunk);
            } else {
                // At least one subnormal - process this chunk with scalar code
                for (&in_val, out_val) in in_chunk.iter().zip(out_chunk.iter_mut()) {
                    *out_val = int_to_float_16bit_scalar(in_val);
                }
            }
        }

        // Handle remainder with scalar
        let remainder_start = num_full_chunks * simd_width;
        for i in remainder_start..xsize {
            output[i] = int_to_float_16bit_scalar(input[i]);
        }
    }
);
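
To make the normal-path shifts concrete, here is a worked example for 1.0 in binary16 (0x3C00). This is illustrative scalar arithmetic mirroring the vector code above, not part of the diff:

fn main() {
    // 1.0 in binary16 is 0x3C00: sign 0, exponent 0b01111 (= 15), mantissa 0.
    let val: i32 = 0x3C00;
    let sign = (val << 16) & 0x8000_0000_u32 as i32; // sign bit 15 -> bit 31
    let exp = ((val & 0x7C00) << 13) + ((127 - 15) << 23); // rebias 15 -> 127
    let mant = (val & 0x03FF) << 13; // mantissa bits 0-9 -> bits 13-22
    assert_eq!((sign | exp | mant) as u32, 0x3F80_0000); // f32 bit pattern of 1.0
    assert_eq!(f32::from_bits(0x3F80_0000), 1.0);
}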

// Scalar fallback for 16-bit float conversion (handles subnormals)
#[inline]
fn int_to_float_16bit_scalar(in_val: i32) -> f32 {
    let mut f = in_val as u32;
    let signbit = (f >> 15) != 0;
    f &= 0x7FFF;
    if f == 0 {
        return if signbit { -0.0 } else { 0.0 };
    }
    let mut exp = (f >> 10) as i32;
    let mut mantissa = f & 0x3FF;
    if exp == 31 {
        // NaN or infinity
        f = if signbit { 0x80000000 } else { 0 };
        f |= 0xFF << 23;
        f |= mantissa << 13;
        return f32::from_bits(f);
    }
    mantissa <<= 13;
    if exp == 0 {
        // subnormal number - normalize
        while (mantissa & 0x800000) == 0 {
            mantissa <<= 1;
            exp -= 1;
        }
        exp += 1;
        mantissa &= 0x7fffff;
    }
    exp = exp - 15 + 127;
    f = if signbit { 0x80000000 } else { 0 };
    f |= (exp as u32) << 23;
    f |= mantissa;
    f32::from_bits(f)
}
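
For intuition on the normalization loop: a binary16 subnormal encodes mantissa × 2^(1 − 15 − 10), and the loop rewrites it in the implicit-leading-one form that binary32 can represent directly. A quick illustrative check (a hypothetical test, not part of the diff):

#[test]
fn subnormal_worked_example() {
    // 0x0001 is the smallest positive binary16 subnormal: 1 * 2^(1-15-10) = 2^-24.
    assert_eq!(int_to_float_16bit_scalar(0x0001), 2.0_f32.powi(-24));
    // 0x03FF is the largest binary16 subnormal: 1023 * 2^-24.
    assert_eq!(int_to_float_16bit_scalar(0x03FF), 1023.0 * 2.0_f32.powi(-24));
}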

// Converts a custom [bits]-bit float (with [exp_bits] exponent bits) stored as
// an int back to a binary32 float.
fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) {
    assert_eq!(input.len(), output.len());
    let bits = bit_depth.bits_per_sample();
    let exp_bits = bit_depth.exponent_bits_per_sample();
    let xsize = input.len();

    // Use SIMD fast paths for common formats
    if bits == 32 && exp_bits == 8 {
        // 32-bit float passthrough
        int_to_float_32bit_simd_dispatch(input, output, xsize);
        return;
    }

    if bits == 16 && exp_bits == 5 {
        // IEEE 754 half-precision (f16) - common HDR format
        int_to_float_16bit_simd_dispatch(input, output, xsize);
        return;
    }

    // Generic scalar path for other custom float formats
    int_to_float_generic(input, output, bits, exp_bits);
}
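
For reference, the generic path below interprets a custom float as (-1)^sign · 2^(exp − bias) · (1 + mant / 2^mant_bits) with bias = 2^(exp_bits − 1) − 1. A hypothetical spot-check of that formula for binary16, illustrative only:

fn main() {
    // bias = 2^(5-1) - 1 = 15 for a 5-bit exponent.
    let exp_bias = (1i32 << (5 - 1)) - 1;
    assert_eq!(exp_bias, 15);
    // 0x3800 has exp = 14 and mantissa = 0, so its value is 2^(14-15) = 0.5.
    let value = 2.0_f32.powi(14 - exp_bias) * (1.0 + 0.0);
    assert_eq!(value, 0.5);
}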

// Generic scalar conversion for arbitrary bit-depth floats
fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) {
    let exp_bias = (1 << (exp_bits - 1)) - 1;
    let sign_shift = bits - 1;
    let mant_bits = bits - exp_bits - 1;
@@ -419,6 +583,7 @@ impl RenderPipelineInOutStage for ConvertF32ToF16Stage {
mod test {
    use super::*;
    use crate::error::Result;
    use crate::headers::bit_depth::BitDepth;
    use test_log::test;

    #[test]
@@ -448,4 +613,161 @@ mod test {
    fn f32_to_f16_consistency() -> Result<()> {
        crate::render::test::test_stage_consistency(|| ConvertF32ToF16Stage::new(0), (500, 500), 1)
    }

    #[test]
    fn test_int_to_float_32bit() {
        // Test 32-bit float passthrough, including NaN and the infinities
        let bit_depth = BitDepth::f32();
        let test_values: Vec<f32> = vec![
            0.0,
            1.0,
            -1.0,
            0.5,
            -0.5,
            f32::INFINITY,
            f32::NEG_INFINITY,
            f32::NAN,
            1e-30,
            1e30,
        ];
        let input: Vec<i32> = test_values.iter().map(|&f| f.to_bits() as i32).collect();
        let mut output = vec![0.0f32; input.len()];

        int_to_float(&input, &mut output, &bit_depth);

        for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() {
            if expected.is_nan() {
                assert!(actual.is_nan(), "index {}: expected NaN, got {}", i, actual);
            } else {
                assert_eq!(expected, actual, "index {}: mismatch", i);
            }
        }
    }

    #[test]
    fn test_int_to_float_16bit_normal() {
        // Test 16-bit float (f16) conversion for normal values
        let bit_depth = BitDepth::f16();

        // f16 format: 1 sign, 5 exp, 10 mantissa
        // Test cases: (f16_bits, expected_f32)
        let test_cases: Vec<(u16, f32)> = vec![
            (0x0000, 0.0),     // +0
            (0x8000, -0.0),    // -0
            (0x3C00, 1.0),     // 1.0
            (0xBC00, -1.0),    // -1.0
            (0x3800, 0.5),     // 0.5
            (0x4000, 2.0),     // 2.0
            (0x4400, 4.0),     // 4.0
            (0x7BFF, 65504.0), // max normal f16
        ];

        let input: Vec<i32> = test_cases.iter().map(|(bits, _)| *bits as i32).collect();
        let mut output = vec![0.0f32; input.len()];

        int_to_float(&input, &mut output, &bit_depth);

        for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() {
            if *expected == 0.0 {
                // Compare bit patterns so that +0.0 and -0.0 are distinguished.
                assert_eq!(
                    expected.to_bits(),
                    actual.to_bits(),
                    "index {}: zero sign mismatch",
                    i
                );
            } else {
                assert!(
                    (expected - actual).abs() < 1e-6,
                    "index {}: expected {}, got {}",
                    i,
                    expected,
                    actual
                );
            }
        }
    }

    #[test]
    fn test_int_to_float_16bit_special() {
        // Test 16-bit float conversion for the infinities
        let bit_depth = BitDepth::f16();

        let test_cases: Vec<(u16, f32)> = vec![
            (0x7C00, f32::INFINITY),     // +inf
            (0xFC00, f32::NEG_INFINITY), // -inf
        ];

        let input: Vec<i32> = test_cases.iter().map(|(bits, _)| *bits as i32).collect();
        let mut output = vec![0.0f32; input.len()];

        int_to_float(&input, &mut output, &bit_depth);

        for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() {
            assert_eq!(
                *expected, actual,
                "index {}: expected {}, got {}",
                i, expected, actual
            );
        }
    }

    #[test]
    fn test_int_to_float_16bit_subnormal() {
        // Test 16-bit float conversion for subnormal values
        let bit_depth = BitDepth::f16();

        // Verify bit_depth is set correctly
        assert_eq!(bit_depth.bits_per_sample(), 16);
        assert_eq!(bit_depth.exponent_bits_per_sample(), 5);
        assert!(bit_depth.floating_point_sample());

        // Smallest subnormal: 2^-24 ≈ 5.96e-8
        // Largest subnormal: (2^10 - 1) * 2^-24 ≈ 6.10e-5
        let test_cases: Vec<(u16, f32)> = vec![
            (0x0001, 5.960_464_5e-8),  // smallest positive subnormal
            (0x03FF, 6.097_555e-5),    // largest positive subnormal
            (0x8001, -5.960_464_5e-8), // smallest negative subnormal
        ];

        // First test the scalar function directly
        for (bits, expected) in &test_cases {
            let scalar_result = int_to_float_16bit_scalar(*bits as i32);
            let rel_err = ((expected - scalar_result) / expected).abs();
            assert!(
                rel_err < 1e-6,
                "scalar: bits=0x{:04X}, expected {}, got {}, rel_err {}",
                bits,
                expected,
                scalar_result,
                rel_err
            );
        }

        // Test through int_to_float_generic (which should match scalar)
        let input: Vec<i32> = test_cases.iter().map(|(bits, _)| *bits as i32).collect();
        let mut generic_output = vec![0.0f32; input.len()];
        int_to_float_generic(&input, &mut generic_output, 16, 5);
        for (i, ((_, expected), &actual)) in
            test_cases.iter().zip(generic_output.iter()).enumerate()
        {
            let rel_err = ((expected - actual) / expected).abs();
            assert!(
                rel_err < 1e-6,
                "generic index {}: expected {}, got {}, rel_err {}",
                i,
                expected,
                actual,
                rel_err
            );
        }

        // Now test through the main function (uses SIMD dispatch)
        let mut output = vec![0.0f32; input.len()];
        int_to_float(&input, &mut output, &bit_depth);

        for (i, ((_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() {
            let rel_err = ((expected - actual) / expected).abs();
            assert!(
                rel_err < 1e-6,
                "simd index {}: expected {}, got {}, rel_err {}",
                i,
                expected,
                actual,
                rel_err
            );
        }
    }
}