jxl/src/frame/render.rs: 4 changes (2 additions, 2 deletions)

@@ -237,8 +237,8 @@ impl Frame {
             }
         }
         for i in 3..num_channels {
-            pipeline =
-                pipeline.add_inout_stage(ConvertModularToF32Stage::new(i, metadata.bit_depth))?;
+            let ec_bit_depth = metadata.extra_channel_info[i - 3].bit_depth();
+            pipeline = pipeline.add_inout_stage(ConvertModularToF32Stage::new(i, ec_bit_depth))?;
         }

         for c in 0..3 {
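The tests added below pin down why this one-line change matters: ConvertModularToF32Stage scales an integer sample by 1.0 / ((1 << bits) - 1), so the scale must come from the extra channel's own bit depth, not the image's. A minimal standalone sketch of that math (conversion_scale here is an illustrative stand-in for the stage's internal scale computation, not the crate's API):

    // Minimal sketch, assuming the scale formula described in the tests below.
    fn conversion_scale(bits: u32) -> f32 {
        1.0 / ((1u64 << bits) - 1) as f32
    }

    fn main() {
        let max_8bit_sample = 255.0_f32;
        // Correct: the channel's own 8-bit depth maps 255 to 1.0.
        assert!((max_8bit_sample * conversion_scale(8) - 1.0).abs() < 1e-6);
        // The bug: the image's 16-bit depth maps 255 to 255/65535 ≈ 0.00389,
        // roughly 257x too small.
        assert!((max_8bit_sample * conversion_scale(16) - 0.00389).abs() < 1e-4);
    }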
jxl/src/headers/extra_channels.rs: 119 changes (119 additions, 0 deletions)

@@ -89,6 +89,9 @@ impl ExtraChannelInfo {
     pub fn alpha_associated(&self) -> bool {
         self.alpha_associated
     }
+    pub fn bit_depth(&self) -> BitDepth {
+        self.bit_depth
+    }
     fn check(&self, _: &Empty) -> Result<(), Error> {
         if self.dim_shift > 3 {
             Err(Error::DimShiftTooLarge(self.dim_shift))
@@ -97,3 +100,119 @@
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::headers::bit_depth::BitDepth;
+
+    /// Test that extra channels can have their own bit depth, independent of the image.
+    ///
+    /// This is important because extra channels (alpha, depth maps, etc.) may have
+    /// different precision requirements than the main color channels. For example,
+    /// an image might be 8-bit RGB with a 16-bit alpha channel.
+    ///
+    /// Previously the render pipeline incorrectly used the image's metadata bit depth
+    /// for all extra channels, producing incorrect conversions for channels whose
+    /// bit depth differs from the image's.
+    #[test]
+    fn test_extra_channel_bit_depth() {
+        // Create an 8-bit extra channel.
+        let ec_8bit = ExtraChannelInfo::new(
+            false,
+            ExtraChannel::Alpha,
+            BitDepth::integer_samples(8),
+            0,
+            "alpha".to_string(),
+            false,
+            None,
+            None,
+        );
+        assert_eq!(ec_8bit.bit_depth().bits_per_sample(), 8);
+
+        // Create a 16-bit extra channel.
+        let ec_16bit = ExtraChannelInfo::new(
+            false,
+            ExtraChannel::Depth,
+            BitDepth::integer_samples(16),
+            0,
+            "depth".to_string(),
+            false,
+            None,
+            None,
+        );
+        assert_eq!(ec_16bit.bit_depth().bits_per_sample(), 16);
+
+        // Verify they are independent.
+        assert_ne!(
+            ec_8bit.bit_depth().bits_per_sample(),
+            ec_16bit.bit_depth().bits_per_sample()
+        );
+    }
+
+    /// Test that the bit_depth getter returns the correct value for float samples.
+    #[test]
+    fn test_extra_channel_float_bit_depth() {
+        let ec_float = ExtraChannelInfo::new(
+            false,
+            ExtraChannel::Depth,
+            BitDepth::f32(),
+            0,
+            "depth_float".to_string(),
+            false,
+            None,
+            None,
+        );
+        assert!(ec_float.bit_depth().floating_point_sample());
+        assert_eq!(ec_float.bit_depth().bits_per_sample(), 32);
+    }
+
+    /// Test that using the wrong bit depth for conversion produces incorrect values.
+    ///
+    /// This test demonstrates why the render pipeline MUST use each extra channel's
+    /// own bit_depth rather than the image's global bit_depth.
+    ///
+    /// The modular-to-f32 conversion scale is 1.0 / ((1 << bits) - 1):
+    /// - 8-bit: scale = 1/255, so value 255 → 1.0
+    /// - 16-bit: scale = 1/65535, so value 255 → ~0.00389
+    ///
+    /// If an 8-bit extra channel is decoded using a 16-bit scale (the image's bit
+    /// depth), the maximum value (255) maps to 0.00389 instead of 1.0.
+    #[test]
+    fn test_wrong_bit_depth_produces_wrong_conversion() {
+        // Simulate the conversion scale calculation from ConvertModularToF32Stage.
+        fn conversion_scale(bits: u32) -> f32 {
+            1.0 / ((1u64 << bits) - 1) as f32
+        }
+
+        let scale_8bit = conversion_scale(8);
+        let scale_16bit = conversion_scale(16);
+
+        // Max 8-bit value.
+        let max_8bit_value = 255i32;
+
+        // Correct conversion: 8-bit channel with 8-bit scale.
+        let correct_result = max_8bit_value as f32 * scale_8bit;
+        assert!(
+            (correct_result - 1.0).abs() < 1e-6,
+            "8-bit max value should convert to 1.0, got {}",
+            correct_result
+        );
+
+        // Wrong conversion: 8-bit channel with 16-bit scale (the bug).
+        let wrong_result = max_8bit_value as f32 * scale_16bit;
+        assert!(
+            (wrong_result - 0.00389).abs() < 0.0001,
+            "Using the wrong scale, 255 converts to ~0.00389, got {}",
+            wrong_result
+        );
+
+        // The difference is catastrophic: values would be ~257x too small.
+        let ratio = correct_result / wrong_result;
+        assert!(
+            ratio > 250.0,
+            "Using the wrong bit depth causes a ~257x error, ratio was {}",
+            ratio
+        );
+    }
+}
jxl/src/render/stages/convert.rs: 85 changes (85 additions, 0 deletions)

@@ -448,4 +448,89 @@ mod test {
     fn f32_to_f16_consistency() -> Result<()> {
         crate::render::test::test_stage_consistency(|| ConvertF32ToF16Stage::new(0), (500, 500), 1)
     }
+
+    /// Test that the modular-to-f32 conversion scale depends on bit depth.
+    ///
+    /// This test verifies the core math that ConvertModularToF32Stage uses:
+    /// - 8-bit: scale = 1/255, so max value 255 → 1.0
+    /// - 16-bit: scale = 1/65535, so value 255 → ~0.00389
+    ///
+    /// If the render pipeline passes the wrong bit_depth to ConvertModularToF32Stage
+    /// (e.g., the image's bit_depth instead of the extra channel's), the output
+    /// values are catastrophically wrong (~257x too small).
+    #[test]
+    fn test_modular_to_f32_bit_depth_matters() {
+        // The same scale calculation used in ConvertModularToF32Stage::process_row_chunk.
+        fn conversion_scale(bits: u32) -> f32 {
+            1.0 / ((1u64 << bits) - 1) as f32
+        }
+
+        let scale_8bit = conversion_scale(8);
+        let scale_16bit = conversion_scale(16);
+
+        // Test values.
+        let test_values: Vec<i32> = vec![0, 127, 255];
+
+        // 8-bit conversion.
+        let results_8bit: Vec<f32> = test_values
+            .iter()
+            .map(|&v| v as f32 * scale_8bit)
+            .collect();
+
+        assert!(
+            (results_8bit[0] - 0.0).abs() < 1e-6,
+            "8-bit: 0 should convert to 0.0, got {}",
+            results_8bit[0]
+        );
+        assert!(
+            (results_8bit[1] - 0.498).abs() < 0.01,
+            "8-bit: 127 should convert to ~0.498, got {}",
+            results_8bit[1]
+        );
+        assert!(
+            (results_8bit[2] - 1.0).abs() < 1e-6,
+            "8-bit: 255 should convert to 1.0, got {}",
+            results_8bit[2]
+        );
+
+        // 16-bit conversion of the same values (demonstrates the bug's impact).
+        let results_16bit: Vec<f32> = test_values
+            .iter()
+            .map(|&v| v as f32 * scale_16bit)
+            .collect();
+
+        // With the wrong (16-bit) scale, 255 converts to 255/65535 ≈ 0.00389.
+        assert!(
+            (results_16bit[2] - 0.00389).abs() < 0.0001,
+            "16-bit scale: 255 should convert to ~0.00389, got {}",
+            results_16bit[2]
+        );
+
+        // CRITICAL: using the wrong bit depth causes a ~257x error.
+        let ratio = results_8bit[2] / results_16bit[2];
+        assert!(
+            ratio > 250.0,
+            "Using the wrong bit depth causes a ~257x error, ratio was {}",
+            ratio
+        );
+    }
+
[Review comment from a project member, on the two tests below]
I am not sure I see the point of all the other tests, except these last two. Do you mind removing them?

+    /// Test ConvertModularToF32Stage consistency with different bit depths.
+    #[test]
+    fn modular_to_f32_8bit_consistency() -> Result<()> {
+        crate::render::test::test_stage_consistency(
+            || ConvertModularToF32Stage::new(0, BitDepth::integer_samples(8)),
+            (500, 500),
+            1,
+        )
+    }
+
+    #[test]
+    fn modular_to_f32_16bit_consistency() -> Result<()> {
+        crate::render::test::test_stage_consistency(
+            || ConvertModularToF32Stage::new(0, BitDepth::integer_samples(16)),
+            (500, 500),
+            1,
+        )
+    }
 }