diff --git a/jxl/src/api/decoder.rs b/jxl/src/api/decoder.rs index 68991bb4..b9602132 100644 --- a/jxl/src/api/decoder.rs +++ b/jxl/src/api/decoder.rs @@ -637,7 +637,6 @@ pub(crate) mod tests { #[test] fn test_premultiply_output_straight_alpha() { use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat}; - use crate::image::{Image, Rect}; // Use alpha_nonpremultiplied.jxl which has straight alpha (alpha_associated=false) let file = @@ -652,83 +651,12 @@ pub(crate) mod tests { extra_channel_format: vec![None], }; - // Helper function to decode with given options - fn decode_image( - file: &[u8], - rgba_format: &JxlPixelFormat, - premultiply: bool, - use_simple: bool, - ) -> (Image, usize, usize) { - let options = JxlDecoderOptions { - premultiply_output: premultiply, - ..Default::default() - }; - let decoder = JxlDecoder::::new(options); - let mut input = file; - - // Advance to image info - let mut decoder = decoder; - let mut decoder = loop { - match decoder.process(&mut input).unwrap() { - ProcessingResult::Complete { result } => break result, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - }; - decoder.set_use_simple_pipeline(use_simple); - decoder.set_pixel_format(rgba_format.clone()); - - let basic_info = decoder.basic_info().clone(); - let (width, height) = basic_info.size; - - // Advance to frame info - let mut decoder = loop { - match decoder.process(&mut input).unwrap() { - ProcessingResult::Complete { result } => break result, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - }; - - let mut buffer = Image::::new((width * 4, height)).unwrap(); - let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut( - buffer - .get_rect_mut(Rect { - origin: (0, 0), - size: (width * 4, height), - }) - .into_raw(), - )]; - - // Decode - loop { - match decoder.process(&mut input, &mut buffers).unwrap() { - ProcessingResult::Complete { .. } => break, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - } - - (buffer, width, height) - } - // Test both pipelines for use_simple in [true, false] { let (straight_buffer, width, height) = - decode_image(&file, &rgba_format, false, use_simple); - let (premul_buffer, _, _) = decode_image(&file, &rgba_format, true, use_simple); + decode_with_format::(&file, &rgba_format, use_simple, false); + let (premul_buffer, _, _) = + decode_with_format::(&file, &rgba_format, use_simple, true); // Verify premultiplied values: premul_rgb should equal straight_rgb * alpha let mut found_semitransparent = false; @@ -812,7 +740,6 @@ pub(crate) mod tests { #[test] fn test_premultiply_output_already_premultiplied() { use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat}; - use crate::image::{Image, Rect}; // Use alpha_premultiplied.jxl which has alpha_associated=true let file = std::fs::read("resources/test/conformance_test_images/alpha_premultiplied.jxl") @@ -825,83 +752,12 @@ pub(crate) mod tests { extra_channel_format: vec![None], }; - // Helper function to decode with given options - fn decode_image( - file: &[u8], - rgba_format: &JxlPixelFormat, - premultiply: bool, - use_simple: bool, - ) -> (Image, usize, usize) { - let options = JxlDecoderOptions { - premultiply_output: premultiply, - ..Default::default() - }; - let decoder = JxlDecoder::::new(options); - let mut input = file; - - // Advance to image info - let mut decoder = decoder; - let mut decoder = loop { - match decoder.process(&mut input).unwrap() { - ProcessingResult::Complete { result } => break result, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - }; - decoder.set_use_simple_pipeline(use_simple); - decoder.set_pixel_format(rgba_format.clone()); - - let basic_info = decoder.basic_info().clone(); - let (width, height) = basic_info.size; - - // Advance to frame info - let mut decoder = loop { - match decoder.process(&mut input).unwrap() { - ProcessingResult::Complete { result } => break result, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - }; - - let mut buffer = Image::::new((width * 4, height)).unwrap(); - let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut( - buffer - .get_rect_mut(Rect { - origin: (0, 0), - size: (width * 4, height), - }) - .into_raw(), - )]; - - // Decode - loop { - match decoder.process(&mut input, &mut buffers).unwrap() { - ProcessingResult::Complete { .. } => break, - ProcessingResult::NeedsMoreInput { fallback, .. } => { - if input.is_empty() { - panic!("Unexpected end of input"); - } - decoder = fallback; - } - } - } - - (buffer, width, height) - } - // Test both pipelines for use_simple in [true, false] { let (without_flag_buffer, width, height) = - decode_image(&file, &rgba_format, false, use_simple); - let (with_flag_buffer, _, _) = decode_image(&file, &rgba_format, true, use_simple); + decode_with_format::(&file, &rgba_format, use_simple, false); + let (with_flag_buffer, _, _) = + decode_with_format::(&file, &rgba_format, use_simple, true); // Both outputs should be identical since source is already premultiplied // and we shouldn't double-premultiply @@ -1017,4 +873,253 @@ pub(crate) mod tests { frame_count ); } + + /// Test that u8 output matches f32 output within quantization tolerance. + /// This test would catch bugs like the offset miscalculation in PR #586 + /// that caused black bars in u8 output. + #[test] + fn test_output_format_u8_matches_f32() { + use crate::api::{JxlColorType, JxlDataFormat, JxlPixelFormat}; + + // Use bicycles.jxl - a larger image that exercises offset calculations + let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap(); + + // Test both RGB and BGRA to catch channel reordering bugs + for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] { + let f32_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::f32()), + extra_channel_format: vec![], + }; + let u8_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::U8 { bit_depth: 8 }), + extra_channel_format: vec![], + }; + + // Test both pipelines + for use_simple in [true, false] { + let (f32_buffer, width, height) = + decode_with_format::(&file, &f32_format, use_simple, false); + let (u8_buffer, _, _) = + decode_with_format::(&file, &u8_format, use_simple, false); + + // Compare values: u8 / 255.0 should match f32 + // Tolerance: quantization error of ±0.5/255 ≈ 0.00196 plus small rounding + let tolerance = 0.003; + let mut max_error: f32 = 0.0; + + for y in 0..height { + let f32_row = f32_buffer.row(y); + let u8_row = u8_buffer.row(y); + for x in 0..(width * num_samples) { + let f32_val = f32_row[x].clamp(0.0, 1.0); + let u8_val = u8_row[x] as f32 / 255.0; + let error = (f32_val - u8_val).abs(); + max_error = max_error.max(error); + assert!( + error < tolerance, + "{:?} u8 mismatch at ({},{}): f32={}, u8={} (scaled={}), error={} (use_simple={})", + color_type, + x, + y, + f32_val, + u8_row[x], + u8_val, + error, + use_simple + ); + } + } + } + } + } + + /// Test that u16 output matches f32 output within quantization tolerance. + #[test] + fn test_output_format_u16_matches_f32() { + use crate::api::{Endianness, JxlColorType, JxlDataFormat, JxlPixelFormat}; + + let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap(); + + // Test both RGB and BGRA + for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] { + let f32_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::f32()), + extra_channel_format: vec![], + }; + let u16_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::U16 { + endianness: Endianness::native(), + bit_depth: 16, + }), + extra_channel_format: vec![], + }; + + for use_simple in [true, false] { + let (f32_buffer, width, height) = + decode_with_format::(&file, &f32_format, use_simple, false); + let (u16_buffer, _, _) = + decode_with_format::(&file, &u16_format, use_simple, false); + + // Tolerance: quantization error of ±0.5/65535 plus small rounding + let tolerance = 0.0001; + + for y in 0..height { + let f32_row = f32_buffer.row(y); + let u16_row = u16_buffer.row(y); + for x in 0..(width * num_samples) { + let f32_val = f32_row[x].clamp(0.0, 1.0); + let u16_val = u16_row[x] as f32 / 65535.0; + let error = (f32_val - u16_val).abs(); + assert!( + error < tolerance, + "{:?} u16 mismatch at ({},{}): f32={}, u16={} (scaled={}), error={} (use_simple={})", + color_type, + x, + y, + f32_val, + u16_row[x], + u16_val, + error, + use_simple + ); + } + } + } + } + } + + /// Test that f16 output matches f32 output within f16 precision tolerance. + #[test] + fn test_output_format_f16_matches_f32() { + use crate::api::{Endianness, JxlColorType, JxlDataFormat, JxlPixelFormat}; + use crate::util::f16; + + let file = std::fs::read("resources/test/conformance_test_images/bicycles.jxl").unwrap(); + + // Test both RGB and BGRA + for (color_type, num_samples) in [(JxlColorType::Rgb, 3), (JxlColorType::Bgra, 4)] { + let f32_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::f32()), + extra_channel_format: vec![], + }; + let f16_format = JxlPixelFormat { + color_type, + color_data_format: Some(JxlDataFormat::F16 { + endianness: Endianness::native(), + }), + extra_channel_format: vec![], + }; + + for use_simple in [true, false] { + let (f32_buffer, width, height) = + decode_with_format::(&file, &f32_format, use_simple, false); + let (f16_buffer, _, _) = + decode_with_format::(&file, &f16_format, use_simple, false); + + // f16 has about 3 decimal digits of precision + // For values in [0,1], the relative error is about 0.001 + let tolerance = 0.002; + + for y in 0..height { + let f32_row = f32_buffer.row(y); + let f16_row = f16_buffer.row(y); + for x in 0..(width * num_samples) { + let f32_val = f32_row[x]; + let f16_val = f16_row[x].to_f32(); + let error = (f32_val - f16_val).abs(); + assert!( + error < tolerance, + "{:?} f16 mismatch at ({},{}): f32={}, f16={}, error={} (use_simple={})", + color_type, + x, + y, + f32_val, + f16_val, + error, + use_simple + ); + } + } + } + } + } + + /// Helper function to decode an image with a specific format. + fn decode_with_format( + file: &[u8], + pixel_format: &JxlPixelFormat, + use_simple: bool, + premultiply: bool, + ) -> (Image, usize, usize) { + let options = JxlDecoderOptions { + premultiply_output: premultiply, + ..Default::default() + }; + let mut decoder = JxlDecoder::::new(options); + let mut input = file; + + // Advance to image info + let mut decoder = loop { + match decoder.process(&mut input).unwrap() { + ProcessingResult::Complete { result } => break result, + ProcessingResult::NeedsMoreInput { fallback, .. } => { + if input.is_empty() { + panic!("Unexpected end of input"); + } + decoder = fallback; + } + } + }; + decoder.set_use_simple_pipeline(use_simple); + decoder.set_pixel_format(pixel_format.clone()); + + let basic_info = decoder.basic_info().clone(); + let (width, height) = basic_info.size; + + let num_samples = pixel_format.color_type.samples_per_pixel(); + + // Advance to frame info + let decoder = loop { + match decoder.process(&mut input).unwrap() { + ProcessingResult::Complete { result } => break result, + ProcessingResult::NeedsMoreInput { fallback, .. } => { + if input.is_empty() { + panic!("Unexpected end of input"); + } + decoder = fallback; + } + } + }; + + let mut buffer = Image::::new((width * num_samples, height)).unwrap(); + let mut buffers: Vec<_> = vec![JxlOutputBuffer::from_image_rect_mut( + buffer + .get_rect_mut(Rect { + origin: (0, 0), + size: (width * num_samples, height), + }) + .into_raw(), + )]; + + // Decode + let mut decoder = decoder; + loop { + match decoder.process(&mut input, &mut buffers).unwrap() { + ProcessingResult::Complete { .. } => break, + ProcessingResult::NeedsMoreInput { fallback, .. } => { + if input.is_empty() { + panic!("Unexpected end of input"); + } + decoder = fallback; + } + } + } + + (buffer, width, height) + } } diff --git a/jxl_cli/src/dec/mod.rs b/jxl_cli/src/dec/mod.rs index cf3e60ef..15a37715 100644 --- a/jxl_cli/src/dec/mod.rs +++ b/jxl_cli/src/dec/mod.rs @@ -8,10 +8,12 @@ use std::time::{Duration, Instant}; use color_eyre::eyre::{Result, eyre}; use jxl::{ api::{ - JxlAnimation, JxlBitDepth, JxlBitstreamInput, JxlColorProfile, JxlColorType, JxlDecoder, - JxlDecoderOptions, JxlOutputBuffer, ProcessingResult, states::WithImageInfo, + Endianness, JxlAnimation, JxlBitDepth, JxlBitstreamInput, JxlColorProfile, JxlColorType, + JxlDataFormat, JxlDecoder, JxlDecoderOptions, JxlOutputBuffer, JxlPixelFormat, + ProcessingResult, states::WithImageInfo, }, image::{Image, ImageDataType, Rect}, + util::f16, }; pub struct ImageFrame { @@ -121,3 +123,336 @@ pub fn decode_frames( Ok((image_data, start.elapsed())) } + +/// Output data type for decoding. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum OutputDataType { + U8, + U16, + F16, + F32, +} + +impl OutputDataType { + /// Parse from string (case-insensitive). + pub fn parse(s: &str) -> Option { + match s.to_lowercase().as_str() { + "u8" => Some(Self::U8), + "u16" => Some(Self::U16), + "f16" => Some(Self::F16), + "f32" => Some(Self::F32), + _ => None, + } + } + + /// Get the JxlDataFormat for this type. + pub fn to_data_format(self) -> JxlDataFormat { + match self { + Self::U8 => JxlDataFormat::U8 { bit_depth: 8 }, + Self::U16 => JxlDataFormat::U16 { + endianness: Endianness::native(), + bit_depth: 16, + }, + Self::F16 => JxlDataFormat::F16 { + endianness: Endianness::native(), + }, + Self::F32 => JxlDataFormat::f32(), + } + } +} + +/// Typed decode output that preserves the original output type. +/// The caller is responsible for converting to f32 when needed for saving. +pub enum TypedDecodeOutput { + U8(DecodeOutput), + U16(DecodeOutput), + F16(DecodeOutput), + F32(DecodeOutput), +} + +impl TypedDecodeOutput { + /// Get the image size. + pub fn size(&self) -> (usize, usize) { + match self { + Self::U8(d) => d.size, + Self::U16(d) => d.size, + Self::F16(d) => d.size, + Self::F32(d) => d.size, + } + } + + /// Get the output color profile. + pub fn output_profile(&self) -> &JxlColorProfile { + match self { + Self::U8(d) => &d.output_profile, + Self::U16(d) => &d.output_profile, + Self::F16(d) => &d.output_profile, + Self::F32(d) => &d.output_profile, + } + } + + /// Get the embedded color profile. + pub fn embedded_profile(&self) -> &JxlColorProfile { + match self { + Self::U8(d) => &d.embedded_profile, + Self::U16(d) => &d.embedded_profile, + Self::F16(d) => &d.embedded_profile, + Self::F32(d) => &d.embedded_profile, + } + } + + /// Get the original bit depth. + pub fn original_bit_depth(&self) -> &JxlBitDepth { + match self { + Self::U8(d) => &d.original_bit_depth, + Self::U16(d) => &d.original_bit_depth, + Self::F16(d) => &d.original_bit_depth, + Self::F32(d) => &d.original_bit_depth, + } + } + + /// Convert to f32 output for saving to encoders. + pub fn to_f32(self) -> Result> { + match self { + Self::U8(d) => convert_decode_output_to_f32(d), + Self::U16(d) => convert_decode_output_to_f32(d), + Self::F16(d) => convert_decode_output_to_f32(d), + Self::F32(d) => Ok(d), + } + } + + /// Truncate to keep only first N frames. + pub fn truncate_frames(&mut self, len: usize) { + match self { + Self::U8(d) => d.frames.truncate(len), + Self::U16(d) => d.frames.truncate(len), + Self::F16(d) => d.frames.truncate(len), + Self::F32(d) => d.frames.truncate(len), + } + } + + /// Get the first frame's color type and channel size (for preview handling). + pub fn first_frame_info(&self) -> Option<(JxlColorType, (usize, usize))> { + match self { + Self::U8(d) => d + .frames + .first() + .map(|f| (f.color_type, f.channels[0].size())), + Self::U16(d) => d + .frames + .first() + .map(|f| (f.color_type, f.channels[0].size())), + Self::F16(d) => d + .frames + .first() + .map(|f| (f.color_type, f.channels[0].size())), + Self::F32(d) => d + .frames + .first() + .map(|f| (f.color_type, f.channels[0].size())), + } + } + + /// Update the size field. + pub fn set_size(&mut self, size: (usize, usize)) { + match self { + Self::U8(d) => d.size = size, + Self::U16(d) => d.size = size, + Self::F16(d) => d.size = size, + Self::F32(d) => d.size = size, + } + } +} + +/// Decode a JXL image with a specific output data type. +/// Returns the raw typed output without conversion, so benchmark timing is accurate. +pub fn decode_frames_with_type( + input: &mut In, + decoder_options: JxlDecoderOptions, + output_type: OutputDataType, +) -> Result<(TypedDecodeOutput, Duration)> { + match output_type { + OutputDataType::U8 => { + let (output, duration) = + decode_frames_typed::(input, decoder_options, output_type)?; + Ok((TypedDecodeOutput::U8(output), duration)) + } + OutputDataType::U16 => { + let (output, duration) = + decode_frames_typed::(input, decoder_options, output_type)?; + Ok((TypedDecodeOutput::U16(output), duration)) + } + OutputDataType::F16 => { + let (output, duration) = + decode_frames_typed::(input, decoder_options, output_type)?; + Ok((TypedDecodeOutput::F16(output), duration)) + } + OutputDataType::F32 => { + let (output, duration) = decode_frames(input, decoder_options)?; + Ok((TypedDecodeOutput::F32(output), duration)) + } + } +} + +/// Generic decoder that decodes to type T. +fn decode_frames_typed( + input: &mut In, + decoder_options: JxlDecoderOptions, + output_type: OutputDataType, +) -> Result<(DecodeOutput, Duration)> { + let start = Instant::now(); + + let mut decoder_with_image_info = decode_header(input, decoder_options)?; + + // Get info and clone what we need before mutating the decoder + let info = decoder_with_image_info.basic_info().clone(); + let embedded_profile = decoder_with_image_info.embedded_color_profile().clone(); + let output_profile = decoder_with_image_info.output_color_profile().clone(); + + // Set the pixel format to the requested data type + let current_format = decoder_with_image_info.current_pixel_format().clone(); + let new_format = JxlPixelFormat { + color_type: current_format.color_type, + color_data_format: Some(output_type.to_data_format()), + extra_channel_format: current_format + .extra_channel_format + .iter() + .map(|f| f.as_ref().map(|_| output_type.to_data_format())) + .collect(), + }; + decoder_with_image_info.set_pixel_format(new_format); + + let mut image_data = DecodeOutput { + size: info.size, + frames: Vec::new(), + original_bit_depth: info.bit_depth.clone(), + output_profile, + embedded_profile, + jxl_animation: info.animation.clone(), + }; + + let extra_channels = info.extra_channels.len(); + let pixel_format = decoder_with_image_info.current_pixel_format().clone(); + let color_type = pixel_format.color_type; + let samples_per_pixel = if color_type == JxlColorType::Grayscale { + 1 + } else { + 3 + }; + + loop { + let decoder_with_frame_info = match decoder_with_image_info.process(input)? { + ProcessingResult::Complete { result } => result, + ProcessingResult::NeedsMoreInput { .. } => return Err(eyre!("Source file truncated")), + }; + + let frame_header = decoder_with_frame_info.frame_header(); + let frame_size = frame_header.size; + + // Create typed output buffers + let mut typed_outputs = vec![Image::::new(( + frame_size.0 * samples_per_pixel, + frame_size.1, + ))?]; + + for _ in 0..extra_channels { + typed_outputs.push(Image::::new(frame_size)?); + } + + let mut output_bufs: Vec> = typed_outputs + .iter_mut() + .map(|x| { + let rect = Rect { + size: x.size(), + origin: (0, 0), + }; + JxlOutputBuffer::from_image_rect_mut(x.get_rect_mut(rect).into_raw()) + }) + .collect(); + + decoder_with_image_info = match decoder_with_frame_info.process(input, &mut output_bufs)? { + ProcessingResult::Complete { result } => result, + ProcessingResult::NeedsMoreInput { .. } => return Err(eyre!("Source file truncated")), + }; + + image_data.frames.push(ImageFrame { + duration: frame_header.duration.unwrap_or(0.0), + channels: typed_outputs, + color_type, + }); + + if !decoder_with_image_info.has_more_frames() { + break; + } + } + + Ok((image_data, start.elapsed())) +} + +/// Trait for converting a value to f32. +trait ConvertToF32: Copy { + fn to_f32_normalized(self) -> f32; +} + +impl ConvertToF32 for u8 { + fn to_f32_normalized(self) -> f32 { + self as f32 / 255.0 + } +} + +impl ConvertToF32 for u16 { + fn to_f32_normalized(self) -> f32 { + self as f32 / 65535.0 + } +} + +impl ConvertToF32 for f16 { + fn to_f32_normalized(self) -> f32 { + self.to_f32() + } +} + +/// Convert a DecodeOutput from type T to f32. +fn convert_decode_output_to_f32( + src: DecodeOutput, +) -> Result> { + let mut frames = Vec::with_capacity(src.frames.len()); + for frame in src.frames { + let channels: Vec> = frame + .channels + .into_iter() + .map(|img| convert_image_to_f32(img)) + .collect::>()?; + frames.push(ImageFrame { + channels, + duration: frame.duration, + color_type: frame.color_type, + }); + } + Ok(DecodeOutput { + size: src.size, + frames, + original_bit_depth: src.original_bit_depth, + output_profile: src.output_profile, + embedded_profile: src.embedded_profile, + jxl_animation: src.jxl_animation, + }) +} + +/// Convert an image from type T to f32. +fn convert_image_to_f32( + src: Image, +) -> std::result::Result, jxl::error::Error> { + let size = src.size(); + let mut dst = Image::::new(size)?; + + for y in 0..size.1 { + let src_row = src.row(y); + let dst_row = dst.row_mut(y); + for x in 0..size.0 { + dst_row[x] = src_row[x].to_f32_normalized(); + } + } + + Ok(dst) +} diff --git a/jxl_cli/src/lib.rs b/jxl_cli/src/lib.rs index 70d1828f..5d1a28eb 100644 --- a/jxl_cli/src/lib.rs +++ b/jxl_cli/src/lib.rs @@ -5,3 +5,148 @@ pub mod dec; pub mod enc; + +#[cfg(test)] +mod tests { + use crate::dec::{OutputDataType, decode_frames_with_type}; + use jxl::api::JxlDecoderOptions; + + /// Test that decoding with all output data types produces consistent results. + /// This catches bugs in the f32→u8, f32→u16, f32→f16 conversion stages. + #[test] + fn test_output_formats_consistency() { + let test_files = [ + "../jxl/resources/test/conformance_test_images/bicycles.jxl", + "../jxl/resources/test/conformance_test_images/bike.jxl", + "../jxl/resources/test/zoltan_tasi_unsplash.jxl", // Failed in PR #586 + ]; + + for test_file in test_files { + let file = match std::fs::read(test_file) { + Ok(f) => f, + Err(_) => continue, // Skip if file not found + }; + + // Decode as f32 (reference) + let mut input = file.as_slice(); + let (f32_typed_output, _) = decode_frames_with_type( + &mut input, + JxlDecoderOptions::default(), + OutputDataType::F32, + ) + .unwrap(); + let f32_output = f32_typed_output.to_f32().unwrap(); + + // Test each data type + // clamps_values: true for integer formats that clamp to [0,1] + for (data_type, tolerance, name, clamps_values) in [ + (OutputDataType::U8, 0.003, "u8", true), // ~0.5/255 + margin + (OutputDataType::U16, 0.0001, "u16", true), // ~0.5/65535 + margin + (OutputDataType::F16, 0.001, "f16", false), // f16 precision, no clamping + ] { + let mut input = file.as_slice(); + let (typed_output, _) = + decode_frames_with_type(&mut input, JxlDecoderOptions::default(), data_type) + .unwrap(); + // Convert to f32 for comparison + let converted_output = typed_output.to_f32().unwrap(); + + assert_eq!( + f32_output.frames.len(), + converted_output.frames.len(), + "{test_file}: frame count mismatch for {name}" + ); + + for (frame_idx, (f32_frame, typed_frame)) in f32_output + .frames + .iter() + .zip(converted_output.frames.iter()) + .enumerate() + { + assert_eq!( + f32_frame.channels.len(), + typed_frame.channels.len(), + "{test_file}: channel count mismatch for {name} frame {frame_idx}" + ); + + for (ch_idx, (f32_ch, typed_ch)) in f32_frame + .channels + .iter() + .zip(typed_frame.channels.iter()) + .enumerate() + { + assert_eq!( + f32_ch.size(), + typed_ch.size(), + "{test_file}: size mismatch for {name} frame {frame_idx} channel {ch_idx}" + ); + + let size = f32_ch.size(); + let mut max_diff: f32 = 0.0; + + for y in 0..size.1 { + let f32_row = f32_ch.row(y); + let typed_row = typed_ch.row(y); + for x in 0..size.0 { + // Clamp f32 to [0,1] for integer formats that clamp output + let f32_val = if clamps_values { + f32_row[x].clamp(0.0, 1.0) + } else { + f32_row[x] + }; + let typed_val = typed_row[x]; + let diff = (f32_val - typed_val).abs(); + max_diff = max_diff.max(diff); + assert!( + diff <= tolerance, + "{test_file}: {name} mismatch at ({x},{y}): \ + f32={f32_val}, {name}={typed_val}, diff={diff}" + ); + } + } + + // Verify we actually processed pixels + assert!(size.0 > 0 && size.1 > 0); + // Log max diff for informational purposes + let _ = max_diff; + } + } + } + } + } + + /// Test that the high precision mode works with all output formats. + #[test] + fn test_output_formats_high_precision() { + let test_file = "../jxl/resources/test/conformance_test_images/bicycles.jxl"; + let file = match std::fs::read(test_file) { + Ok(f) => f, + Err(_) => return, // Skip if file not found + }; + + let high_precision_options = || { + let mut options = JxlDecoderOptions::default(); + options.high_precision = true; + options + }; + + // Decode as f32 (reference) + let mut input = file.as_slice(); + let (f32_typed_output, _) = + decode_frames_with_type(&mut input, high_precision_options(), OutputDataType::F32) + .unwrap(); + let f32_output = f32_typed_output.to_f32().unwrap(); + + // Test u8 with high precision + let mut input = file.as_slice(); + let (u8_typed_output, _) = + decode_frames_with_type(&mut input, high_precision_options(), OutputDataType::U8) + .unwrap(); + let u8_output = u8_typed_output.to_f32().unwrap(); + + // Verify both produce valid output + assert!(!f32_output.frames.is_empty()); + assert!(!u8_output.frames.is_empty()); + assert_eq!(f32_output.frames.len(), u8_output.frames.len()); + } +} diff --git a/jxl_cli/src/main.rs b/jxl_cli/src/main.rs index aaad075a..ff4b78e4 100644 --- a/jxl_cli/src/main.rs +++ b/jxl_cli/src/main.rs @@ -97,6 +97,11 @@ struct Opt { /// Use high precision mode for decoding #[clap(long)] high_precision: bool, + + /// Output data type for decoder (u8, u16, f16, f32). Used for benchmarking + /// the decoder's conversion pipeline. Default: f32 + #[clap(long, default_value = "f32")] + data_type: String, } // Extract RGB channels from interleaved RGB buffer @@ -189,61 +194,72 @@ fn main() -> Result<()> { // When extracting preview, don't skip it; otherwise skip preview by default let skip_preview = !opt.preview; - let mut image_data = if reps > 1 { + // Parse the output data type + let output_type = dec::OutputDataType::parse(&opt.data_type).ok_or_else(|| { + eyre!( + "Invalid data type '{}'. Must be u8, u16, f16, or f32", + opt.data_type + ) + })?; + + // Decode to typed output (timing excludes f32 conversion) + let typed_output = if reps > 1 { // For multiple repetitions (benchmarking), read into memory to avoid I/O variability let mut input_bytes = Vec::::new(); file.read_to_end(&mut input_bytes)?; (0..reps) - .try_fold(None, |_, _| -> Result>> { + .try_fold(None, |_, _| -> Result> { let mut input = input_bytes.as_slice(); - let (mut iteration_image_data, iteration_duration) = - dec::decode_frames(&mut input, options(skip_preview))?; + let (mut iteration_output, iteration_duration) = + dec::decode_frames_with_type(&mut input, options(skip_preview), output_type)?; duration_sum += iteration_duration; // When extracting preview, only keep the first frame (the preview) if opt.preview { - iteration_image_data.frames.truncate(1); - if let Some(frame) = iteration_image_data.frames.first() { - let samples = if frame.color_type == JxlColorType::Grayscale { + iteration_output.truncate_frames(1); + if let Some((color_type, (w, h))) = iteration_output.first_frame_info() { + let samples = if color_type == JxlColorType::Grayscale { 1 } else { 3 }; - let (w, h) = frame.channels[0].size(); - iteration_image_data.size = (w / samples, h); + iteration_output.set_size((w / samples, h)); } } - Ok(Some(iteration_image_data)) + Ok(Some(iteration_output)) })? .unwrap() } else { // For single decode, stream from file let mut reader = BufReader::new(file); - let (mut image_data, duration) = dec::decode_frames(&mut reader, options(skip_preview))?; + let (mut typed_output, duration) = + dec::decode_frames_with_type(&mut reader, options(skip_preview), output_type)?; duration_sum = duration; // When extracting preview, only keep the first frame (the preview) if opt.preview { - image_data.frames.truncate(1); - if let Some(frame) = image_data.frames.first() { - let samples = if frame.color_type == JxlColorType::Grayscale { + typed_output.truncate_frames(1); + if let Some((color_type, (w, h))) = typed_output.first_frame_info() { + let samples = if color_type == JxlColorType::Grayscale { 1 } else { 3 }; - let (w, h) = frame.channels[0].size(); - image_data.size = (w / samples, h); + typed_output.set_size((w / samples, h)); } } - image_data + typed_output }; - let data_icc_result = save_icc( - image_data.output_profile.as_icc().as_slice(), - opt.icc_out.as_ref(), - ); - let original_icc_result = save_icc( - image_data.embedded_profile.as_icc().as_slice(), - opt.original_icc_out.as_ref(), - ); + // Get metadata from typed output before converting + let output_icc = typed_output.output_profile().as_icc().to_vec(); + let embedded_icc = typed_output.embedded_profile().as_icc().to_vec(); + let image_size = typed_output.size(); + let original_bit_depth = typed_output.original_bit_depth().clone(); + + // Convert to f32 for saving (this happens AFTER timing is captured) + let mut image_data = typed_output.to_f32()?; + + let data_icc_result = save_icc(&output_icc, opt.icc_out.as_ref()); + let original_icc_result = save_icc(&embedded_icc, opt.original_icc_out.as_ref()); for frame in image_data.frames.iter_mut() { if frame.color_type != JxlColorType::Grayscale { @@ -254,7 +270,7 @@ fn main() -> Result<()> { } if opt.speedtest { - let num_pixels = image_data.size.0 * image_data.size.1; + let num_pixels = image_size.0 * image_size.1; let duration_seconds = duration_sum.as_nanos() as f64 / 1e9; let avg_seconds = duration_seconds / reps as f64; println!( @@ -267,7 +283,7 @@ fn main() -> Result<()> { let image_result: Option> = opt.output.map(|path| { let output_bit_depth = match opt.override_bitdepth { - None => image_data.original_bit_depth.bits_per_sample(), + None => original_bit_depth.bits_per_sample(), Some(num_bits) => num_bits, }; let image_result = save_image(&image_data, output_bit_depth, &path);