diff --git a/jxl/src/frame/render.rs b/jxl/src/frame/render.rs index 89280063..591370a1 100644 --- a/jxl/src/frame/render.rs +++ b/jxl/src/frame/render.rs @@ -237,8 +237,8 @@ impl Frame { } } for i in 3..num_channels { - pipeline = - pipeline.add_inout_stage(ConvertModularToF32Stage::new(i, metadata.bit_depth))?; + let ec_bit_depth = metadata.extra_channel_info[i - 3].bit_depth(); + pipeline = pipeline.add_inout_stage(ConvertModularToF32Stage::new(i, ec_bit_depth))?; } for c in 0..3 { diff --git a/jxl/src/headers/extra_channels.rs b/jxl/src/headers/extra_channels.rs index 78aa235e..e6bd5a33 100644 --- a/jxl/src/headers/extra_channels.rs +++ b/jxl/src/headers/extra_channels.rs @@ -89,6 +89,9 @@ impl ExtraChannelInfo { pub fn alpha_associated(&self) -> bool { self.alpha_associated } + pub fn bit_depth(&self) -> BitDepth { + self.bit_depth + } fn check(&self, _: &Empty) -> Result<(), Error> { if self.dim_shift > 3 { Err(Error::DimShiftTooLarge(self.dim_shift)) @@ -97,3 +100,119 @@ impl ExtraChannelInfo { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::headers::bit_depth::BitDepth; + + /// Test that extra channels can have their own bit depth independent of the image. + /// + /// This is important because extra channels (like alpha, depth maps, etc.) may + /// have different precision requirements than the main color channels. For example, + /// an image might be 8-bit RGB with a 16-bit alpha channel. + /// + /// Previously the render pipeline incorrectly used the image's metadata bit depth + /// for all extra channels, causing incorrect conversion for channels with different + /// bit depths. + #[test] + fn test_extra_channel_bit_depth() { + // Create an 8-bit extra channel + let ec_8bit = ExtraChannelInfo::new( + false, + ExtraChannel::Alpha, + BitDepth::integer_samples(8), + 0, + "alpha".to_string(), + false, + None, + None, + ); + assert_eq!(ec_8bit.bit_depth().bits_per_sample(), 8); + + // Create a 16-bit extra channel + let ec_16bit = ExtraChannelInfo::new( + false, + ExtraChannel::Depth, + BitDepth::integer_samples(16), + 0, + "depth".to_string(), + false, + None, + None, + ); + assert_eq!(ec_16bit.bit_depth().bits_per_sample(), 16); + + // Verify they are independent + assert_ne!( + ec_8bit.bit_depth().bits_per_sample(), + ec_16bit.bit_depth().bits_per_sample() + ); + } + + /// Test that the bit_depth getter returns the correct value for float samples. + #[test] + fn test_extra_channel_float_bit_depth() { + let ec_float = ExtraChannelInfo::new( + false, + ExtraChannel::Depth, + BitDepth::f32(), + 0, + "depth_float".to_string(), + false, + None, + None, + ); + assert!(ec_float.bit_depth().floating_point_sample()); + assert_eq!(ec_float.bit_depth().bits_per_sample(), 32); + } + + /// Test that using the wrong bit depth for conversion produces incorrect values. + /// + /// This test demonstrates why the render pipeline MUST use each extra channel's + /// own bit_depth rather than the image's global bit_depth. + /// + /// The modular-to-f32 conversion scale is: 1.0 / ((1 << bits) - 1) + /// - 8-bit: scale = 1/255, so value 255 → 1.0 + /// - 16-bit: scale = 1/65535, so value 255 → ~0.00389 + /// + /// If an 8-bit extra channel is decoded using a 16-bit scale (the image's bit depth), + /// the maximum value (255) would map to 0.00389 instead of 1.0 - completely wrong! + #[test] + fn test_wrong_bit_depth_produces_wrong_conversion() { + // Simulate the conversion scale calculation from ConvertModularToF32Stage + fn conversion_scale(bits: u32) -> f32 { + 1.0 / ((1u64 << bits) - 1) as f32 + } + + let scale_8bit = conversion_scale(8); + let scale_16bit = conversion_scale(16); + + // Max 8-bit value + let max_8bit_value = 255i32; + + // Correct conversion: 8-bit channel with 8-bit scale + let correct_result = max_8bit_value as f32 * scale_8bit; + assert!( + (correct_result - 1.0).abs() < 1e-6, + "8-bit max value should convert to 1.0, got {}", + correct_result + ); + + // WRONG conversion: 8-bit channel with 16-bit scale (the bug!) + let wrong_result = max_8bit_value as f32 * scale_16bit; + assert!( + (wrong_result - 0.00389).abs() < 0.0001, + "Using wrong scale, 255 converts to ~0.00389, got {}", + wrong_result + ); + + // The difference is catastrophic - values would be ~257x too small + let ratio = correct_result / wrong_result; + assert!( + ratio > 250.0, + "Using wrong bit depth causes ~257x error, ratio was {}", + ratio + ); + } +} diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs index 582fee4b..4b0859ca 100644 --- a/jxl/src/render/stages/convert.rs +++ b/jxl/src/render/stages/convert.rs @@ -448,4 +448,89 @@ mod test { fn f32_to_f16_consistency() -> Result<()> { crate::render::test::test_stage_consistency(|| ConvertF32ToF16Stage::new(0), (500, 500), 1) } + + /// Test that modular-to-f32 conversion scale depends on bit depth. + /// + /// This test verifies the core math that ConvertModularToF32Stage uses: + /// - 8-bit: scale = 1/255, so max value 255 → 1.0 + /// - 16-bit: scale = 1/65535, so value 255 → ~0.00389 + /// + /// If the render pipeline passes the wrong bit_depth to ConvertModularToF32Stage + /// (e.g., using the image's bit_depth instead of the extra channel's bit_depth), + /// the output values would be catastrophically wrong (~257x too small). + #[test] + fn test_modular_to_f32_bit_depth_matters() { + // This tests the same scale calculation used in ConvertModularToF32Stage::process_row_chunk + fn conversion_scale(bits: u32) -> f32 { + 1.0 / ((1u64 << bits) - 1) as f32 + } + + let scale_8bit = conversion_scale(8); + let scale_16bit = conversion_scale(16); + + // Test values + let test_values: Vec = vec![0, 127, 255]; + + // 8-bit conversion + let results_8bit: Vec = test_values + .iter() + .map(|&v| v as f32 * scale_8bit) + .collect(); + + assert!( + (results_8bit[0] - 0.0).abs() < 1e-6, + "8-bit: 0 should convert to 0.0, got {}", + results_8bit[0] + ); + assert!( + (results_8bit[1] - 0.498).abs() < 0.01, + "8-bit: 127 should convert to ~0.498, got {}", + results_8bit[1] + ); + assert!( + (results_8bit[2] - 1.0).abs() < 1e-6, + "8-bit: 255 should convert to 1.0, got {}", + results_8bit[2] + ); + + // 16-bit conversion of same values (demonstrates the bug impact) + let results_16bit: Vec = test_values + .iter() + .map(|&v| v as f32 * scale_16bit) + .collect(); + + // With wrong (16-bit) scale, 255 converts to 255/65535 ≈ 0.00389 + assert!( + (results_16bit[2] - 0.00389).abs() < 0.0001, + "16-bit scale: 255 should convert to ~0.00389, got {}", + results_16bit[2] + ); + + // CRITICAL: Using wrong bit depth causes ~257x error! + let ratio = results_8bit[2] / results_16bit[2]; + assert!( + ratio > 250.0, + "Using wrong bit depth causes ~257x error, ratio was {}", + ratio + ); + } + + /// Test ConvertModularToF32Stage consistency with different bit depths. + #[test] + fn modular_to_f32_8bit_consistency() -> Result<()> { + crate::render::test::test_stage_consistency( + || ConvertModularToF32Stage::new(0, BitDepth::integer_samples(8)), + (500, 500), + 1, + ) + } + + #[test] + fn modular_to_f32_16bit_consistency() -> Result<()> { + crate::render::test::test_stage_consistency( + || ConvertModularToF32Stage::new(0, BitDepth::integer_samples(16)), + (500, 500), + 1, + ) + } }