diff --git a/Cargo.lock b/Cargo.lock index a6ae1394..b00d31b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,17 +51,20 @@ dependencies = [ "anyhow", "apache-avro-derive", "apache-avro-test-helper", + "async-stream", "bigdecimal", "bon", "bzip2", "crc32fast", "criterion", "digest", + "futures", "hex-literal", "log", "md-5", "miniz_oxide", "num-bigint", + "oval", "paste", "pretty_assertions", "quad-rand", @@ -108,6 +111,28 @@ dependencies = [ "log", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -205,6 +230,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + [[package]] name = "bzip2" version = "0.6.0" @@ -543,6 +574,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -564,6 +606,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -815,6 +858,15 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "oval" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135cef32720c6746450d910890b0b69bcba2bbf6f85c9f4583df13fe415de828" +dependencies = [ + "bytes", +] + [[package]] name = "parking_lot" version = "0.12.3" diff --git a/avro/Cargo.toml b/avro/Cargo.toml index db2b0fa5..7ffc0f3b 100644 --- a/avro/Cargo.toml +++ b/avro/Cargo.toml @@ -29,11 +29,14 @@ categories.workspace = true documentation.workspace = true [features] +default = ["futures", "sync"] bzip = ["dep:bzip2"] derive = ["dep:apache-avro-derive"] snappy = ["dep:crc32fast", "dep:snap"] xz = ["dep:xz2"] zstandard = ["dep:zstd"] +futures = [] +sync = [] [lib] # disable benchmarks to allow passing criterion arguments to `cargo bench` @@ -73,6 +76,9 @@ thiserror = { default-features = false, version = "2.0.16" } uuid = { default-features = false, version = "1.18.0", features = ["serde", "std"] } xz2 = { default-features = false, version = "0.1.7", optional = true } zstd = { default-features = false, version = "0.13.3", optional = true } +oval = { version = "2.0.0", features = ["bytes"] } +futures = "0.3.31" +async-stream = "0.3.6" [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git 
a/avro/src/bigdecimal.rs b/avro/src/bigdecimal.rs
index 0f4e1647..09d88ea2 100644
--- a/avro/src/bigdecimal.rs
+++ b/avro/src/bigdecimal.rs
@@ -17,10 +17,9 @@
 use crate::{
     AvroResult,
-    decode::{decode_len, decode_long},
     encode::{encode_bytes, encode_long},
     error::Details,
-    types::Value,
+    util::{decode_len_simple, decode_variable},
 };
 pub use bigdecimal::BigDecimal;
 use num_bigint::BigInt;
@@ -47,10 +46,12 @@ pub(crate) fn serialize_big_decimal(decimal: &BigDecimal) -> AvroResult<Vec<u8>>
     Ok(final_buffer)
 }
 
-pub(crate) fn deserialize_big_decimal(bytes: &Vec<u8>) -> AvroResult<BigDecimal> {
-    let mut bytes: &[u8] = bytes.as_slice();
-    let mut big_decimal_buffer = match decode_len(&mut bytes) {
-        Ok(size) => vec![0u8; size],
+pub(crate) fn deserialize_big_decimal(mut bytes: &[u8]) -> AvroResult<BigDecimal> {
+    let mut big_decimal_buffer = match decode_len_simple(bytes) {
+        Ok((size, bytes_read)) => {
+            bytes = &bytes[bytes_read..];
+            vec![0u8; size]
+        }
         Err(err) => return Err(Details::BigDecimalLen(Box::new(err)).into()),
     };
 
@@ -58,8 +59,8 @@ pub(crate) fn deserialize_big_decimal(bytes: &Vec<u8>) -> AvroResult<BigDecimal>
         .read_exact(&mut big_decimal_buffer[..])
         .map_err(Details::ReadDouble)?;
 
-    match decode_long(&mut bytes) {
-        Ok(Value::Long(scale_value)) => {
+    match decode_variable(bytes) {
+        Ok(Some((scale_value, _))) => {
             let big_int: BigInt = BigInt::from_signed_bytes_be(&big_decimal_buffer);
             let decimal = BigDecimal::new(big_int, scale_value);
             Ok(decimal)
@@ -71,7 +72,11 @@
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{Codec, Reader, Schema, Writer, error::Error, types::Record};
+    use crate::{
+        Codec, Reader, Schema, Writer,
+        error::Error,
+        types::{Record, Value},
+    };
     use apache_avro_test_helper::TestResult;
     use bigdecimal::{One, Zero};
     use pretty_assertions::assert_eq;
@@ -92,7 +97,8 @@
             let buffer: Vec<u8> = serialize_big_decimal(&current)?;
 
             let mut as_slice = buffer.as_slice();
-            decode_long(&mut as_slice)?;
+            let (_, bytes_read) = decode_variable(as_slice)?.unwrap();
+            as_slice = &as_slice[bytes_read..];
 
             let mut result: Vec<u8> = Vec::new();
             result.extend_from_slice(as_slice);
@@ -109,7 +115,8 @@
         let buffer: Vec<u8> = serialize_big_decimal(&BigDecimal::zero())?;
 
         let mut as_slice = buffer.as_slice();
-        decode_long(&mut as_slice)?;
+        let (_, bytes_read) = decode_variable(as_slice)?.unwrap();
+        as_slice = &as_slice[bytes_read..];
 
         let mut result: Vec<u8> = Vec::new();
         result.extend_from_slice(as_slice);
diff --git a/avro/src/decode.rs b/avro/src/decode.rs
deleted file mode 100644
index 78fefbd9..00000000
--- a/avro/src/decode.rs
+++ /dev/null
@@ -1,875 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
- -use crate::{ - AvroResult, Error, - bigdecimal::deserialize_big_decimal, - decimal::Decimal, - duration::Duration, - error::Details, - schema::{ - DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, RecordSchema, ResolvedSchema, - Schema, - }, - types::Value, - util::{safe_len, zag_i32, zag_i64}, -}; -use std::{ - borrow::Borrow, - collections::HashMap, - io::{ErrorKind, Read}, -}; -use uuid::Uuid; - -#[inline] -pub(crate) fn decode_long(reader: &mut R) -> AvroResult { - zag_i64(reader).map(Value::Long) -} - -#[inline] -fn decode_int(reader: &mut R) -> AvroResult { - zag_i32(reader).map(Value::Int) -} - -#[inline] -pub(crate) fn decode_len(reader: &mut R) -> AvroResult { - let len = zag_i64(reader)?; - safe_len(usize::try_from(len).map_err(|e| Details::ConvertI64ToUsize(e, len))?) -} - -/// Decode the length of a sequence. -/// -/// Maps and arrays are 0-terminated, 0i64 is also encoded as 0 in Avro reading a length of 0 means -/// the end of the map or array. -fn decode_seq_len(reader: &mut R) -> AvroResult { - let raw_len = zag_i64(reader)?; - safe_len( - usize::try_from(match raw_len.cmp(&0) { - std::cmp::Ordering::Equal => return Ok(0), - std::cmp::Ordering::Less => { - let _size = zag_i64(reader)?; - raw_len.checked_neg().ok_or(Details::IntegerOverflow)? - } - std::cmp::Ordering::Greater => raw_len, - }) - .map_err(|e| Details::ConvertI64ToUsize(e, raw_len))?, - ) -} - -/// Decode a `Value` from avro format given its `Schema`. -pub fn decode(schema: &Schema, reader: &mut R) -> AvroResult { - let rs = ResolvedSchema::try_from(schema)?; - decode_internal(schema, rs.get_names(), &None, reader) -} - -pub(crate) fn decode_internal>( - schema: &Schema, - names: &HashMap, - enclosing_namespace: &Namespace, - reader: &mut R, -) -> AvroResult { - match *schema { - Schema::Null => Ok(Value::Null), - Schema::Boolean => { - let mut buf = [0u8; 1]; - match reader.read_exact(&mut buf[..]) { - Ok(_) => match buf[0] { - 0u8 => Ok(Value::Boolean(false)), - 1u8 => Ok(Value::Boolean(true)), - _ => Err(Details::BoolValue(buf[0]).into()), - }, - Err(io_err) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Null) - } else { - Err(Details::ReadBoolean(io_err).into()) - } - } - } - } - Schema::Decimal(DecimalSchema { ref inner, .. }) => match &**inner { - Schema::Fixed { .. } => { - match decode_internal(inner, names, enclosing_namespace, reader)? { - Value::Fixed(_, bytes) => Ok(Value::Decimal(Decimal::from(bytes))), - value => Err(Details::FixedValue(value).into()), - } - } - Schema::Bytes => match decode_internal(inner, names, enclosing_namespace, reader)? { - Value::Bytes(bytes) => Ok(Value::Decimal(Decimal::from(bytes))), - value => Err(Details::BytesValue(value).into()), - }, - schema => Err(Details::ResolveDecimalSchema(schema.into()).into()), - }, - Schema::BigDecimal => { - match decode_internal(&Schema::Bytes, names, enclosing_namespace, reader)? { - Value::Bytes(bytes) => deserialize_big_decimal(&bytes).map(Value::BigDecimal), - value => Err(Details::BytesValue(value).into()), - } - } - Schema::Uuid => { - let Value::Bytes(bytes) = - decode_internal(&Schema::Bytes, names, enclosing_namespace, reader)? - else { - // Calling decode_internal with Schema::Bytes can only return a Value::Bytes or an error - unreachable!(); - }; - - let uuid = if bytes.len() == 16 { - Uuid::from_slice(&bytes).map_err(Details::ConvertSliceToUuid)? 
- } else { - let string = std::str::from_utf8(&bytes).map_err(Details::ConvertToUtf8Error)?; - Uuid::parse_str(string).map_err(Details::ConvertStrToUuid)? - }; - Ok(Value::Uuid(uuid)) - } - Schema::Int => decode_int(reader), - Schema::Date => zag_i32(reader).map(Value::Date), - Schema::TimeMillis => zag_i32(reader).map(Value::TimeMillis), - Schema::Long => decode_long(reader), - Schema::TimeMicros => zag_i64(reader).map(Value::TimeMicros), - Schema::TimestampMillis => zag_i64(reader).map(Value::TimestampMillis), - Schema::TimestampMicros => zag_i64(reader).map(Value::TimestampMicros), - Schema::TimestampNanos => zag_i64(reader).map(Value::TimestampNanos), - Schema::LocalTimestampMillis => zag_i64(reader).map(Value::LocalTimestampMillis), - Schema::LocalTimestampMicros => zag_i64(reader).map(Value::LocalTimestampMicros), - Schema::LocalTimestampNanos => zag_i64(reader).map(Value::LocalTimestampNanos), - Schema::Duration => { - let mut buf = [0u8; 12]; - reader.read_exact(&mut buf).map_err(Details::ReadDuration)?; - Ok(Value::Duration(Duration::from(buf))) - } - Schema::Float => { - let mut buf = [0u8; std::mem::size_of::()]; - reader - .read_exact(&mut buf[..]) - .map_err(Details::ReadFloat)?; - Ok(Value::Float(f32::from_le_bytes(buf))) - } - Schema::Double => { - let mut buf = [0u8; std::mem::size_of::()]; - reader - .read_exact(&mut buf[..]) - .map_err(Details::ReadDouble)?; - Ok(Value::Double(f64::from_le_bytes(buf))) - } - Schema::Bytes => { - let len = decode_len(reader)?; - let mut buf = vec![0u8; len]; - reader.read_exact(&mut buf).map_err(Details::ReadBytes)?; - Ok(Value::Bytes(buf)) - } - Schema::String => { - let len = decode_len(reader)?; - let mut buf = vec![0u8; len]; - match reader.read_exact(&mut buf) { - Ok(_) => Ok(Value::String( - String::from_utf8(buf).map_err(Details::ConvertToUtf8)?, - )), - Err(io_err) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Null) - } else { - Err(Details::ReadString(io_err).into()) - } - } - } - } - Schema::Fixed(FixedSchema { size, .. }) => { - let mut buf = vec![0u8; size]; - reader - .read_exact(&mut buf) - .map_err(|e| Details::ReadFixed(e, size))?; - Ok(Value::Fixed(size, buf)) - } - Schema::Array(ref inner) => { - let mut items = Vec::new(); - - loop { - let len = decode_seq_len(reader)?; - if len == 0 { - break; - } - - items.reserve(len); - for _ in 0..len { - items.push(decode_internal( - &inner.items, - names, - enclosing_namespace, - reader, - )?); - } - } - - Ok(Value::Array(items)) - } - Schema::Map(ref inner) => { - let mut items = HashMap::new(); - - loop { - let len = decode_seq_len(reader)?; - if len == 0 { - break; - } - - items.reserve(len); - for _ in 0..len { - match decode_internal(&Schema::String, names, enclosing_namespace, reader)? { - Value::String(key) => { - let value = - decode_internal(&inner.types, names, enclosing_namespace, reader)?; - items.insert(key, value); - } - value => return Err(Details::MapKeyType(value.into()).into()), - } - } - } - - Ok(Value::Map(items)) - } - Schema::Union(ref inner) => match zag_i64(reader).map_err(Error::into_details) { - Ok(index) => { - let variants = inner.variants(); - let variant = variants - .get(usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?) 
- .ok_or(Details::GetUnionVariant { - index, - num_variants: variants.len(), - })?; - let value = decode_internal(variant, names, enclosing_namespace, reader)?; - Ok(Value::Union(index as u32, Box::new(value))) - } - Err(Details::ReadVariableIntegerBytes(io_err)) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Union(0, Box::new(Value::Null))) - } else { - Err(Details::ReadVariableIntegerBytes(io_err).into()) - } - } - Err(io_err) => Err(Error::new(io_err)), - }, - Schema::Record(RecordSchema { - ref name, - ref fields, - .. - }) => { - let fully_qualified_name = name.fully_qualified_name(enclosing_namespace); - // Benchmarks indicate ~10% improvement using this method. - let mut items = Vec::with_capacity(fields.len()); - for field in fields { - // TODO: This clone is also expensive. See if we can do away with it... - items.push(( - field.name.clone(), - decode_internal( - &field.schema, - names, - &fully_qualified_name.namespace, - reader, - )?, - )); - } - Ok(Value::Record(items)) - } - Schema::Enum(EnumSchema { ref symbols, .. }) => { - Ok(if let Value::Int(raw_index) = decode_int(reader)? { - let index = usize::try_from(raw_index) - .map_err(|e| Details::ConvertI32ToUsize(e, raw_index))?; - if (0..symbols.len()).contains(&index) { - let symbol = symbols[index].clone(); - Value::Enum(raw_index as u32, symbol) - } else { - return Err(Details::GetEnumValue { - index, - nsymbols: symbols.len(), - } - .into()); - } - } else { - return Err(Details::GetEnumUnknownIndexValue.into()); - }) - } - Schema::Ref { ref name } => { - let fully_qualified_name = name.fully_qualified_name(enclosing_namespace); - if let Some(resolved) = names.get(&fully_qualified_name) { - decode_internal( - resolved.borrow(), - names, - &fully_qualified_name.namespace, - reader, - ) - } else { - Err(Details::SchemaResolutionError(fully_qualified_name).into()) - } - } - } -} - -#[cfg(test)] -#[allow(clippy::expect_fun_call)] -mod tests { - use crate::{ - Decimal, - decode::decode, - encode::{encode, tests::success}, - schema::{DecimalSchema, FixedSchema, Schema}, - types::{ - Value, - Value::{Array, Int, Map}, - }, - }; - use apache_avro_test_helper::TestResult; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use uuid::Uuid; - - #[test] - fn test_decode_array_without_size() -> TestResult { - let mut input: &[u8] = &[6, 2, 4, 6, 0]; - let result = decode(&Schema::array(Schema::Int), &mut input); - assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); - - Ok(()) - } - - #[test] - fn test_decode_array_with_size() -> TestResult { - let mut input: &[u8] = &[5, 6, 2, 4, 6, 0]; - let result = decode(&Schema::array(Schema::Int), &mut input); - assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); - - Ok(()) - } - - #[test] - fn test_decode_map_without_size() -> TestResult { - let mut input: &[u8] = &[0x02, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::map(Schema::Int), &mut input); - let mut expected = HashMap::new(); - expected.insert(String::from("test"), Int(1)); - assert_eq!(Map(expected), result?); - - Ok(()) - } - - #[test] - fn test_decode_map_with_size() -> TestResult { - let mut input: &[u8] = &[0x01, 0x0C, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::map(Schema::Int), &mut input); - let mut expected = HashMap::new(); - expected.insert(String::from("test"), Int(1)); - assert_eq!(Map(expected), result?); - - Ok(()) - } - - #[test] - fn test_negative_decimal_value() -> TestResult { - use crate::{encode::encode, 
schema::Name}; - use num_bigint::ToBigInt; - let inner = Box::new(Schema::Fixed( - FixedSchema::builder() - .name(Name::new("decimal")?) - .size(2) - .build(), - )); - let schema = Schema::Decimal(DecimalSchema { - inner, - precision: 4, - scale: 2, - }); - let bigint = (-423).to_bigint().unwrap(); - let value = Value::Decimal(Decimal::from(bigint.to_signed_bytes_be())); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let mut bytes = &buffer[..]; - let result = decode(&schema, &mut bytes)?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn test_decode_decimal_with_bigger_than_necessary_size() -> TestResult { - use crate::{encode::encode, schema::Name}; - use num_bigint::ToBigInt; - let inner = Box::new(Schema::Fixed(FixedSchema { - size: 13, - name: Name::new("decimal")?, - aliases: None, - doc: None, - default: None, - attributes: Default::default(), - })); - let schema = Schema::Decimal(DecimalSchema { - inner, - precision: 4, - scale: 2, - }); - let value = Value::Decimal(Decimal::from( - ((-423).to_bigint().unwrap()).to_signed_bytes_be(), - )); - let mut buffer = Vec::::new(); - - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - let mut bytes: &[u8] = &buffer[..]; - let result = decode(&schema, &mut bytes)?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_union() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":[ "null", { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - }] - }, - { - "name":"b", - "type":"Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value1 = Value::Record(vec![ - ("a".into(), Value::Union(1, Box::new(inner_value1))), - ("b".into(), inner_value2.clone()), - ]); - let mut buf = Vec::new(); - encode(&outer_value1, &schema, &mut buf).expect(&success(&outer_value1, &schema)); - assert!(!buf.is_empty()); - let mut bytes = &buf[..]; - assert_eq!( - outer_value1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - let outer_value2 = Value::Record(vec![ - ("a".into(), Value::Union(0, Box::new(Value::Null))), - ("b".into(), inner_value2), - ]); - encode(&outer_value2, &schema, &mut buf).expect(&success(&outer_value2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_array() -> TestResult { - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":{ - "type":"array", - "items": { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - } - } - }, - { - "name":"b", - "type": "Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value = Value::Record(vec![ - ("a".into(), Value::Array(vec![inner_value1])), - 
("b".into(), inner_value2), - ]); - let mut buf = Vec::new(); - encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_map() -> TestResult { - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":{ - "type":"map", - "values": { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - } - } - }, - { - "name":"b", - "type": "Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value = Value::Record(vec![ - ( - "a".into(), - Value::Map(vec![("akey".into(), inner_value1)].into_iter().collect()), - ), - ("b".into(), inner_value2), - ]); - let mut buf = Vec::new(); - encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_proper_multi_level_decoding_middle_namespace() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = r#" - { - "name": "record_name", - "namespace": "space", - "type": "record", - "fields": [ - { - "name": "outer_field_1", - "type": [ - "null", - { - "type": "record", - "name": "middle_record_name", - "namespace":"middle_namespace", - "fields":[ - { - "name":"middle_field_1", - "type":[ - "null", - { - "type":"record", - "name":"inner_record_name", - "fields":[ - { - "name":"inner_field_1", - "type":"double" - } - ] - } - ] - } - ] - } - ] - }, - { - "name": "outer_field_2", - "type" : "middle_namespace.inner_record_name" - } - ] - } - "#; - let schema = Schema::parse_str(schema)?; - let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); - let middle_record_variation_1 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - )]); - let middle_record_variation_2 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(1, Box::new(inner_record.clone())), - )]); - let outer_record_variation_1 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_2 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_1)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_3 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_2)), - ), - ("outer_field_2".into(), inner_record), - ]); - - let mut buf = Vec::new(); - encode(&outer_record_variation_1, &schema, &mut buf) - .expect(&success(&outer_record_variation_1, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_2, &schema, &mut buf) - 
.expect(&success(&outer_record_variation_2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_3, &schema, &mut buf) - .expect(&success(&outer_record_variation_3, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_3, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_proper_multi_level_decoding_inner_namespace() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = r#" - { - "name": "record_name", - "namespace": "space", - "type": "record", - "fields": [ - { - "name": "outer_field_1", - "type": [ - "null", - { - "type": "record", - "name": "middle_record_name", - "namespace":"middle_namespace", - "fields":[ - { - "name":"middle_field_1", - "type":[ - "null", - { - "type":"record", - "name":"inner_record_name", - "namespace":"inner_namespace", - "fields":[ - { - "name":"inner_field_1", - "type":"double" - } - ] - } - ] - } - ] - } - ] - }, - { - "name": "outer_field_2", - "type" : "inner_namespace.inner_record_name" - } - ] - } - "#; - let schema = Schema::parse_str(schema)?; - let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); - let middle_record_variation_1 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - )]); - let middle_record_variation_2 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(1, Box::new(inner_record.clone())), - )]); - let outer_record_variation_1 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_2 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_1)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_3 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_2)), - ), - ("outer_field_2".into(), inner_record), - ]); - - let mut buf = Vec::new(); - encode(&outer_record_variation_1, &schema, &mut buf) - .expect(&success(&outer_record_variation_1, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_2, &schema, &mut buf) - .expect(&success(&outer_record_variation_2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_3, &schema, &mut buf) - .expect(&success(&outer_record_variation_3, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_3, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn avro_3926_encode_decode_uuid_to_string() -> TestResult { - 
use crate::encode::encode; - - let schema = Schema::String; - let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let result = decode(&Schema::Uuid, &mut &buffer[..])?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn avro_3926_encode_decode_uuid_to_fixed() -> TestResult { - use crate::encode::encode; - - let schema = Schema::Fixed(FixedSchema { - size: 16, - name: "uuid".into(), - aliases: None, - doc: None, - default: None, - attributes: Default::default(), - }); - let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let result = decode(&Schema::Uuid, &mut &buffer[..])?; - assert_eq!(result, value); - - Ok(()) - } -} diff --git a/avro/src/error.rs b/avro/src/error.rs index 95aeb2b9..2163d050 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -17,6 +17,7 @@ use crate::{ schema::{Name, Schema, SchemaKind, UnionSchema}, + state_machines::reading::error::ValueFromTapeError, types::{Value, ValueKind}, }; use std::{error::Error as _, fmt}; @@ -56,6 +57,12 @@ impl From
for Error {
     }
 }
 
+impl From<ValueFromTapeError> for Error {
+    fn from(value: ValueFromTapeError) -> Self {
+        Self::new(value.into())
+    }
+}
+
 impl serde::ser::Error for Error {
     fn custom<T: fmt::Display>(msg: T) -> Self {
         Self::new(<Details as serde::ser::Error>
::custom(msg)) @@ -576,6 +583,9 @@ pub enum Details { #[error("Cannot convert a slice to Uuid: {0}")] UuidFromSlice(#[source] uuid::Error), + + #[error(transparent)] + ValueFromTapeError(#[from] ValueFromTapeError), } #[derive(thiserror::Error, PartialEq)] diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 36a31e6e..b807437c 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -864,12 +864,12 @@ mod bytes; mod codec; mod de; mod decimal; -mod decode; mod duration; mod encode; mod reader; mod ser; mod ser_schema; +pub mod state_machines; mod util; mod writer; @@ -919,6 +919,12 @@ pub use apache_avro_derive::*; /// A convenience type alias for `Result`s with `Error`s. pub type AvroResult = Result; +/// Async versions of the types and functions. +pub mod not_sync { + #[doc(inline)] + pub use crate::reader::async_reader::*; +} + #[cfg(test)] mod tests { use crate::{ diff --git a/avro/src/reader.rs b/avro/src/reader.rs index e2f7570f..763e3c8b 100644 --- a/avro/src/reader.rs +++ b/avro/src/reader.rs @@ -16,504 +16,49 @@ // under the License. //! Logic handling reading from Avro format at user level. + +pub use crate::state_machines::reading::sync::{ + Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, +}; use crate::{ - AvroResult, Codec, Error, - decode::{decode, decode_internal}, + AvroResult, error::Details, from_value, headers::{HeaderBuilder, RabinFingerprintHeader}, - schema::{ - AvroSchema, Names, ResolvedOwnedSchema, ResolvedSchema, Schema, resolve_names, - resolve_names_with_schemata, - }, + schema::{AvroSchema, ResolvedOwnedSchema, Schema}, types::Value, - util, }; -use log::warn; +use futures::AsyncRead; use serde::de::DeserializeOwned; -use serde_json::from_slice; -use std::{ - collections::HashMap, - io::{ErrorKind, Read}, - marker::PhantomData, - str::FromStr, -}; +use std::{io::Read, marker::PhantomData}; -/// Internal Block reader. -#[derive(Debug, Clone)] -struct Block<'r, R> { - reader: R, - /// Internal buffering to reduce allocation. - buf: Vec, - buf_idx: usize, - /// Number of elements expected to exist within this block. - message_count: usize, - marker: [u8; 16], - codec: Codec, - writer_schema: Schema, - schemata: Vec<&'r Schema>, - user_metadata: HashMap>, - names_refs: Names, +pub mod async_reader { + #[doc(inline)] + pub use crate::state_machines::reading::async_impl::{ + Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, + }; } -impl<'r, R: Read> Block<'r, R> { - fn new(reader: R, schemata: Vec<&'r Schema>) -> AvroResult> { - let mut block = Block { - reader, - codec: Codec::Null, - writer_schema: Schema::Null, - schemata, - buf: vec![], - buf_idx: 0, - message_count: 0, - marker: [0; 16], - user_metadata: Default::default(), - names_refs: Default::default(), - }; - - block.read_header()?; - Ok(block) - } - - /// Try to read the header and to set the writer `Schema`, the `Codec` and the marker based on - /// its content. - fn read_header(&mut self) -> AvroResult<()> { - let mut buf = [0u8; 4]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadHeader)?; - - if buf != [b'O', b'b', b'j', 1u8] { - return Err(Details::HeaderMagic.into()); - } - - let meta_schema = Schema::map(Schema::Bytes); - match decode(&meta_schema, &mut self.reader)? 
{ - Value::Map(metadata) => { - self.read_writer_schema(&metadata)?; - self.codec = read_codec(&metadata)?; - - for (key, value) in metadata { - if key == "avro.schema" - || key == "avro.codec" - || key == "avro.codec.compression_level" - { - // already processed - } else if key.starts_with("avro.") { - warn!("Ignoring unknown metadata key: {key}"); - } else { - self.read_user_metadata(key, value); - } - } - } - _ => { - return Err(Details::GetHeaderMetadata.into()); - } - } - - self.reader - .read_exact(&mut self.marker) - .map_err(|e| Details::ReadMarker(e).into()) - } - - fn fill_buf(&mut self, n: usize) -> AvroResult<()> { - // The buffer needs to contain exactly `n` elements, otherwise codecs will potentially read - // invalid bytes. - // - // The are two cases to handle here: - // - // 1. `n > self.buf.len()`: - // In this case we call `Vec::resize`, which guarantees that `self.buf.len() == n`. - // 2. `n < self.buf.len()`: - // We need to resize to ensure that the buffer len is safe to read `n` elements. - // - // TODO: Figure out a way to avoid having to truncate for the second case. - self.buf.resize(util::safe_len(n)?, 0); - self.reader - .read_exact(&mut self.buf) - .map_err(Details::ReadIntoBuf)?; - self.buf_idx = 0; - Ok(()) - } - - /// Try to read a data block, also performing schema resolution for the objects contained in - /// the block. The objects are stored in an internal buffer to the `Reader`. - fn read_block_next(&mut self) -> AvroResult<()> { - assert!(self.is_empty(), "Expected self to be empty!"); - match util::read_long(&mut self.reader).map_err(Error::into_details) { - Ok(block_len) => { - self.message_count = block_len as usize; - let block_bytes = util::read_long(&mut self.reader)?; - self.fill_buf(block_bytes as usize)?; - let mut marker = [0u8; 16]; - self.reader - .read_exact(&mut marker) - .map_err(Details::ReadBlockMarker)?; - - if marker != self.marker { - return Err(Details::GetBlockMarker.into()); - } - - // NOTE (JAB): This doesn't fit this Reader pattern very well. - // `self.buf` is a growable buffer that is reused as the reader is iterated. - // For non `Codec::Null` variants, `decompress` will allocate a new `Vec` - // and replace `buf` with the new one, instead of reusing the same buffer. - // We can address this by using some "limited read" type to decode directly - // into the buffer. But this is fine, for now. 
- self.codec.decompress(&mut self.buf) - } - Err(Details::ReadVariableIntegerBytes(io_err)) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - // to not return any error in case we only finished to read cleanly from the stream - Ok(()) - } else { - Err(Details::ReadVariableIntegerBytes(io_err).into()) - } - } - Err(e) => Err(Error::new(e)), - } - } - - fn len(&self) -> usize { - self.message_count - } - - fn is_empty(&self) -> bool { - self.len() == 0 - } - - fn read_next(&mut self, read_schema: Option<&Schema>) -> AvroResult> { - if self.is_empty() { - self.read_block_next()?; - if self.is_empty() { - return Ok(None); - } - } - - let mut block_bytes = &self.buf[self.buf_idx..]; - let b_original = block_bytes.len(); - - let item = decode_internal( - &self.writer_schema, - &self.names_refs, - &None, - &mut block_bytes, - )?; - let item = match read_schema { - Some(schema) => item.resolve(schema)?, - None => item, - }; - - if b_original != 0 && b_original == block_bytes.len() { - // from_avro_datum did not consume any bytes, so return an error to avoid an infinite loop - return Err(Details::ReadBlock.into()); - } - self.buf_idx += b_original - block_bytes.len(); - self.message_count -= 1; - Ok(Some(item)) - } - - fn read_writer_schema(&mut self, metadata: &HashMap) -> AvroResult<()> { - let json: serde_json::Value = metadata - .get("avro.schema") - .and_then(|bytes| { - if let Value::Bytes(ref bytes) = *bytes { - from_slice(bytes.as_ref()).ok() - } else { - None - } - }) - .ok_or(Details::GetAvroSchemaFromMap)?; - if !self.schemata.is_empty() { - let rs = ResolvedSchema::try_from(self.schemata.clone())?; - let names: Names = rs - .get_names() - .iter() - .map(|(name, schema)| (name.clone(), (*schema).clone())) - .collect(); - self.writer_schema = Schema::parse_with_names(&json, names)?; - resolve_names_with_schemata(&self.schemata, &mut self.names_refs, &None)?; - } else { - self.writer_schema = Schema::parse(&json)?; - resolve_names(&self.writer_schema, &mut self.names_refs, &None)?; - } - Ok(()) - } - - fn read_user_metadata(&mut self, key: String, value: Value) { - match value { - Value::Bytes(ref vec) => { - self.user_metadata.insert(key, vec.clone()); - } - wrong => { - warn!("User metadata values must be Value::Bytes, found {wrong:?}"); - } - } - } -} - -fn read_codec(metadata: &HashMap) -> AvroResult { - let result = metadata - .get("avro.codec") - .map(|codec| { - if let Value::Bytes(ref bytes) = *codec { - match std::str::from_utf8(bytes.as_ref()) { - Ok(utf8) => Ok(utf8), - Err(utf8_error) => Err(Details::ConvertToUtf8Error(utf8_error).into()), - } - } else { - Err(Details::BadCodecMetadata.into()) - } - }) - .map(|codec_res| match codec_res { - Ok(codec) => match Codec::from_str(codec) { - Ok(codec) => match codec { - #[cfg(feature = "bzip")] - Codec::Bzip2(_) => { - use crate::Bzip2Settings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Bzip2(Bzip2Settings::new(bytes[0]))) - } else { - Ok(codec) - } - } - #[cfg(feature = "xz")] - Codec::Xz(_) => { - use crate::XzSettings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Xz(XzSettings::new(bytes[0]))) - } else { - Ok(codec) - } - } - #[cfg(feature = "zstandard")] - Codec::Zstandard(_) => { - use crate::ZstandardSettings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Zstandard(ZstandardSettings::new(bytes[0]))) - } else { - Ok(codec) - } - } - _ => Ok(codec), - }, - 
Err(_) => Err(Details::CodecNotSupported(codec.to_owned()).into()), - }, - Err(err) => Err(err), - }); - - result.unwrap_or(Ok(Codec::Null)) -} - -/// Main interface for reading Avro formatted values. -/// -/// To be used as an iterator: +/// Reader for Avro objects created using the [single-object encoding]. /// -/// ```no_run -/// # use apache_avro::Reader; -/// # use std::io::Cursor; -/// # let input = Cursor::new(Vec::::new()); -/// for value in Reader::new(input).unwrap() { -/// match value { -/// Ok(v) => println!("{:?}", v), -/// Err(e) => println!("Error: {}", e), -/// }; -/// } -/// ``` -pub struct Reader<'a, R> { - block: Block<'a, R>, - reader_schema: Option<&'a Schema>, - errored: bool, - should_resolve_schema: bool, -} - -impl<'a, R: Read> Reader<'a, R> { - /// Creates a `Reader` given something implementing the `io::Read` trait to read from. - /// No reader `Schema` will be set. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. - pub fn new(reader: R) -> AvroResult> { - let block = Block::new(reader, vec![])?; - let reader = Reader { - block, - reader_schema: None, - errored: false, - should_resolve_schema: false, - }; - Ok(reader) - } - - /// Creates a `Reader` given a reader `Schema` and something implementing the `io::Read` trait - /// to read from. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. - pub fn with_schema(schema: &'a Schema, reader: R) -> AvroResult> { - let block = Block::new(reader, vec![schema])?; - let mut reader = Reader { - block, - reader_schema: Some(schema), - errored: false, - should_resolve_schema: false, - }; - // Check if the reader and writer schemas disagree. - reader.should_resolve_schema = reader.writer_schema() != schema; - Ok(reader) - } - - /// Creates a `Reader` given a reader `Schema` and something implementing the `io::Read` trait - /// to read from. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. - pub fn with_schemata( - schema: &'a Schema, - schemata: Vec<&'a Schema>, - reader: R, - ) -> AvroResult> { - let block = Block::new(reader, schemata)?; - let mut reader = Reader { - block, - reader_schema: Some(schema), - errored: false, - should_resolve_schema: false, - }; - // Check if the reader and writer schemas disagree. - reader.should_resolve_schema = reader.writer_schema() != schema; - Ok(reader) - } - - /// Get a reference to the writer `Schema`. - #[inline] - pub fn writer_schema(&self) -> &Schema { - &self.block.writer_schema - } - - /// Get a reference to the optional reader `Schema`. - #[inline] - pub fn reader_schema(&self) -> Option<&Schema> { - self.reader_schema - } - - /// Get a reference to the user metadata - #[inline] - pub fn user_metadata(&self) -> &HashMap> { - &self.block.user_metadata - } - - #[inline] - fn read_next(&mut self) -> AvroResult> { - let read_schema = if self.should_resolve_schema { - self.reader_schema - } else { - None - }; - - self.block.read_next(read_schema) - } -} - -impl Iterator for Reader<'_, R> { - type Item = AvroResult; - - fn next(&mut self) -> Option { - // to prevent keep on reading after the first error occurs - if self.errored { - return None; - }; - match self.read_next() { - Ok(opt) => opt.map(Ok), - Err(e) => { - self.errored = true; - Some(Err(e)) - } - } - } -} - -/// Decode a `Value` encoded in Avro format given its `Schema` and anything implementing `io::Read` -/// to read from. 
-/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. -/// -/// **NOTE** This function has a quite small niche of usage and does NOT take care of reading the -/// header and consecutive data blocks; use [`Reader`](struct.Reader.html) if you don't know what -/// you are doing, instead. -pub fn from_avro_datum( - writer_schema: &Schema, - reader: &mut R, - reader_schema: Option<&Schema>, -) -> AvroResult { - let value = decode(writer_schema, reader)?; - match reader_schema { - Some(schema) => value.resolve(schema), - None => Ok(value), - } -} - -/// Decode a `Value` encoded in Avro format given the provided `Schema` and anything implementing `io::Read` -/// to read from. -/// If the writer schema is incomplete, i.e. contains `Schema::Ref`s then it will use the provided -/// schemata to resolve any dependencies. -/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. -pub fn from_avro_datum_schemata( - writer_schema: &Schema, - writer_schemata: Vec<&Schema>, - reader: &mut R, - reader_schema: Option<&Schema>, -) -> AvroResult { - from_avro_datum_reader_schemata( - writer_schema, - writer_schemata, - reader, - reader_schema, - Vec::with_capacity(0), - ) -} - -/// Decode a `Value` encoded in Avro format given the provided `Schema` and anything implementing `io::Read` -/// to read from. -/// If the writer schema is incomplete, i.e. contains `Schema::Ref`s then it will use the provided -/// schemata to resolve any dependencies. -/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. -pub fn from_avro_datum_reader_schemata( - writer_schema: &Schema, - writer_schemata: Vec<&Schema>, - reader: &mut R, - reader_schema: Option<&Schema>, - reader_schemata: Vec<&Schema>, -) -> AvroResult { - let rs = ResolvedSchema::try_from(writer_schemata)?; - let value = decode_internal(writer_schema, rs.get_names(), &None, reader)?; - match reader_schema { - Some(schema) => { - if reader_schemata.is_empty() { - value.resolve(schema) - } else { - value.resolve_schemata(schema, reader_schemata) - } - } - None => Ok(value), - } -} - +/// [single-object encoding]: https://avro.apache.org/docs/++version++/specification/#single-object-encoding pub struct GenericSingleObjectReader { write_schema: ResolvedOwnedSchema, expected_header: Vec, } impl GenericSingleObjectReader { + /// Create a reader for the given schema. + /// + /// This will expect the input to use the [`RabinFingerprintHeader`]. pub fn new(schema: Schema) -> AvroResult { let header_builder = RabinFingerprintHeader::from_schema(&schema); Self::new_with_header_builder(schema, header_builder) } + /// Create a reader for the given schema with a custom fingerprint. + /// + /// See [`HeaderBuilder`] for details on how to implement a custom fingerprint. pub fn new_with_header_builder( schema: Schema, header_builder: HB, @@ -525,17 +70,36 @@ impl GenericSingleObjectReader { }) } + /// Read a [`Value`] from the reader. 
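+    ///
+    /// A minimal usage sketch; the schema literal and buffer here are
+    /// illustrative only, and `buf` must hold bytes produced by a matching
+    /// single-object writer:
+    ///
+    /// ```no_run
+    /// # use apache_avro::{GenericSingleObjectReader, Schema};
+    /// # let buf: Vec<u8> = Vec::new();
+    /// let schema = Schema::parse_str(r#""string""#).unwrap();
+    /// let reader = GenericSingleObjectReader::new(schema).unwrap();
+    /// let value = reader.read_value(&mut &buf[..]).unwrap();
+    /// println!("{value:?}");
+    /// ```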
     pub fn read_value<R: Read>(&self, reader: &mut R) -> AvroResult<Value> {
         let mut header = vec![0; self.expected_header.len()];
         match reader.read_exact(&mut header) {
             Ok(_) => {
                 if self.expected_header == header {
-                    decode_internal(
-                        self.write_schema.get_root_schema(),
-                        self.write_schema.get_names(),
-                        &None,
-                        reader,
+                    from_avro_datum(self.write_schema.get_root_schema(), reader, None)
+                } else {
+                    Err(
+                        Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header)
+                            .into(),
                     )
+                }
+            }
+            Err(io_error) => Err(Details::ReadHeader(io_error).into()),
+        }
+    }
+
+    /// Read a [`Value`] from the asynchronous reader.
+    pub async fn read_value_async<R: AsyncRead + Unpin>(
+        &self,
+        reader: &mut R,
+    ) -> AvroResult<Value> {
+        use futures::AsyncReadExt as _;
+
+        let mut header = vec![0; self.expected_header.len()];
+        match reader.read_exact(&mut header).await {
+            Ok(_) => {
+                if self.expected_header == header {
+                    async_reader::from_avro_datum(self.write_schema.get_root_schema(), reader, None)
+                        .await
                 } else {
                     Err(
                         Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header)
@@ -548,6 +112,9 @@ impl GenericSingleObjectReader {
     }
 }
 
+/// Reader for Avro objects created using the [single-object encoding], deserializing directly to `T`.
+///
+/// [single-object encoding]: https://avro.apache.org/docs/++version++/specification/#single-object-encoding
 pub struct SpecificSingleObjectReader<T>
 where
     T: AvroSchema,
@@ -560,6 +127,7 @@ impl<T> SpecificSingleObjectReader<T>
 where
     T: AvroSchema,
 {
+    /// Create the reader from the schema associated with `T`.
    pub fn new() -> AvroResult<SpecificSingleObjectReader<T>> {
         Ok(SpecificSingleObjectReader {
             inner: GenericSingleObjectReader::new(T::get_schema())?,
@@ -572,21 +140,37 @@ impl<T> SpecificSingleObjectReader<T>
 where
     T: AvroSchema + From<Value>,
 {
+    /// Read a `T` from the reader.
     pub fn read_from_value<R: Read>(&self, reader: &mut R) -> AvroResult<T> {
         self.inner.read_value(reader).map(|v| v.into())
     }
+
+    /// Read a `T` from the asynchronous reader.
+    pub async fn read_from_value_async<R: AsyncRead + Unpin>(
+        &self,
+        reader: &mut R,
+    ) -> AvroResult<T> {
+        self.inner.read_value_async(reader).await.map(|v| v.into())
+    }
 }
 
 impl<T> SpecificSingleObjectReader<T>
 where
     T: AvroSchema + DeserializeOwned,
 {
+    /// Read a `T` from the reader.
     pub fn read<R: Read>(&self, reader: &mut R) -> AvroResult<T> {
         from_value::<T>(&self.inner.read_value(reader)?)
     }
+
+    /// Read a `T` from the asynchronous reader.
+    pub async fn read_async<R: AsyncRead + Unpin>(&self, reader: &mut R) -> AvroResult<T> {
+        from_value::<T>(&self.inner.read_value_async(reader).await?)
+    }
 }
 
-/// Reads the marker bytes from Avro bytes generated earlier by a `Writer`
+/// Reads the marker bytes from Avro bytes generated earlier by a [`Writer`].
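+///
+/// A minimal sketch; `bytes` must hold a complete object container file (e.g.
+/// the output of a `Writer`), shown here with a dummy buffer:
+///
+/// ```no_run
+/// # use apache_avro::read_marker;
+/// # let bytes = vec![0u8; 32];
+/// let marker: [u8; 16] = read_marker(&bytes);
+/// ```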
+/// +/// [`Writer`]: crate::Writer pub fn read_marker(bytes: &[u8]) -> [u8; 16] { assert!( bytes.len() > 16, @@ -600,11 +184,13 @@ pub fn read_marker(bytes: &[u8]) -> [u8; 16] { #[cfg(test)] mod tests { use super::*; - use crate::{encode::encode, headers::GlueSchemaUuidHeader, rabin::Rabin, types::Record}; + use crate::{ + Error, encode::encode, headers::GlueSchemaUuidHeader, rabin::Rabin, types::Record, + }; use apache_avro_test_helper::TestResult; use pretty_assertions::assert_eq; use serde::Deserialize; - use std::io::Cursor; + use std::{collections::HashMap, io::Cursor}; use uuid::Uuid; const SCHEMA: &str = r#" @@ -704,22 +290,26 @@ mod tests { let schema = Schema::parse_str(TEST_RECORD_SCHEMA_3240)?; let mut encoded: &'static [u8] = &[54, 6, 102, 111, 111]; - let expected_record: TestRecord3240 = TestRecord3240 { - a: 27i64, - b: String::from("foo"), - a_nullable_array: None, - a_nullable_string: None, - }; + // The schema used to read is not compatible with what is written + assert!(from_avro_datum(&schema, &mut encoded, None).is_err()); - let avro_datum = from_avro_datum(&schema, &mut encoded, None)?; - let parsed_record: TestRecord3240 = match &avro_datum { - Value::Record(_) => from_value::(&avro_datum)?, - unexpected => { - panic!("could not map avro data to struct, found unexpected: {unexpected:?}") - } - }; + // let avro_datum = from_avro_datum(&schema, &mut encoded, None)?; - assert_eq!(parsed_record, expected_record); + // let expected_record: TestRecord3240 = TestRecord3240 { + // a: 27i64, + // b: String::from("foo"), + // a_nullable_array: None, + // a_nullable_string: None, + // }; + + // let parsed_record: TestRecord3240 = match &avro_datum { + // Value::Record(_) => from_value::(&avro_datum)?, + // unexpected => { + // panic!("could not map avro data to struct, found unexpected: {unexpected:?}") + // } + // }; + // + // assert_eq!(parsed_record, expected_record); Ok(()) } @@ -780,10 +370,13 @@ mod tests { .into_iter() .rev() .collect::>(); - let reader = Reader::with_schema(&schema, &invalid[..])?; - for value in reader { - assert!(value.is_err()); - } + let mut reader = Reader::with_schema(&schema, &invalid[..])?; + + // The block says it contains 2 values, but only contains one. + // The first value is successfully decoded + let _v = reader.next().unwrap().unwrap(); + // The second fails with an unexpected end of file error. 
+        assert!(reader.next().unwrap().is_err());
 
         Ok(())
     }
@@ -815,10 +408,7 @@ mod tests {
         let mut writer = Writer::new(&schema, Vec::new());
 
         let mut user_meta_data: HashMap<String, Vec<u8>> = HashMap::new();
-        user_meta_data.insert(
-            "stringKey".to_string(),
-            "stringValue".to_string().into_bytes(),
-        );
+        user_meta_data.insert("stringKey".to_string(), b"stringValue".to_vec());
         user_meta_data.insert("bytesKey".to_string(), b"bytesValue".to_vec());
         user_meta_data.insert("vecKey".to_string(), vec![1, 2, 3]);
diff --git a/avro/src/state_machines/mod.rs b/avro/src/state_machines/mod.rs
new file mode 100644
index 00000000..28157eae
--- /dev/null
+++ b/avro/src/state_machines/mod.rs
@@ -0,0 +1 @@
+pub mod reading;
diff --git a/avro/src/state_machines/reading/async_impl.rs b/avro/src/state_machines/reading/async_impl.rs
new file mode 100644
index 00000000..c4000f1f
--- /dev/null
+++ b/avro/src/state_machines/reading/async_impl.rs
@@ -0,0 +1,302 @@
+use async_stream::try_stream;
+use futures::{AsyncRead, AsyncReadExt, Stream};
+use oval::Buffer;
+use serde::Deserialize;
+use std::collections::HashMap;
+
+use crate::{
+    AvroResult, Error, Schema,
+    error::Details,
+    schema::{resolve_names, resolve_names_with_schemata},
+    state_machines::reading::{
+        ItemRead, StateMachine, StateMachineControlFlow,
+        commands::CommandTape,
+        datum::DatumStateMachine,
+        deserialize_from_tape,
+        object_container_file::{
+            ObjectContainerFileBodyStateMachine, ObjectContainerFileHeader,
+            ObjectContainerFileHeaderStateMachine,
+        },
+        value_from_tape,
+    },
+    types::Value,
+};
+
+// This should probably also be a state machine and be wrapped in sync and async versions.
+// But this suffices for the demonstration.
+pub struct Reader<'a, R> {
+    reader_schema: Option<&'a Schema>,
+    header: ObjectContainerFileHeader,
+    fsm: Option<ObjectContainerFileBodyStateMachine>,
+    reader: R,
+    buffer: Buffer,
+}
+
+impl<'a, R: AsyncRead + Unpin> Reader<'a, R> {
+    /// Creates a [`Reader`] that will use the schema from the file header.
+    ///
+    /// No reader [`Schema`] will be set.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub async fn new(reader: R) -> Result<Self, Error> {
+        Self::new_inner(reader, None, Vec::new()).await
+    }
+
+    /// Creates a [`Reader`] that will use the given schema for schema resolution.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub async fn with_schema(schema: &'a Schema, reader: R) -> Result<Self, Error> {
+        Self::new_inner(reader, Some(schema), Vec::new()).await
+    }
+
+    /// Creates a [`Reader`] that will use the given schema for schema resolution.
+    ///
+    /// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be
+    /// resolved and an error will be returned.
+    ///
+    /// Any [`Schema::Ref`] will be resolved using the schemata.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub async fn with_schemata(
+        schema: &'a Schema,
+        schemata: Vec<&'a Schema>,
+        reader: R,
+    ) -> Result<Self, Error> {
+        Self::new_inner(reader, Some(schema), schemata).await
+    }
+
+    /// Get a reference to the optional reader [`Schema`].
+    ///
+    /// This will only be set if there was a reader schema provided *and* it differed from the
+    /// writer schema.
+    pub fn reader_schema(&self) -> Option<&'a Schema> {
+        self.reader_schema
+    }
+
+    /// Get a reference to the user metadata.
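+    ///
+    /// A minimal usage sketch, assuming an executor such as
+    /// `futures::executor::block_on` (`&[u8]` implements `futures::AsyncRead`):
+    ///
+    /// ```no_run
+    /// # use apache_avro::not_sync::Reader;
+    /// # futures::executor::block_on(async {
+    /// # let data: Vec<u8> = Vec::new();
+    /// let reader = Reader::new(&data[..]).await.unwrap();
+    /// for (key, value) in reader.user_metadata() {
+    ///     println!("{key}: {value:?}");
+    /// }
+    /// # });
+    /// ```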
+    pub fn user_metadata(&self) -> &HashMap<String, Vec<u8>> {
+        &self.header.metadata
+    }
+
+    /// Get a reference to the file header.
+    pub fn header(&self) -> &ObjectContainerFileHeader {
+        &self.header
+    }
+
+    async fn new_inner(
+        mut reader: R,
+        reader_schema: Option<&'a Schema>,
+        schemata: Vec<&'a Schema>,
+    ) -> Result<Self, Error> {
+        // Read a maximum of 2 KiB per read
+        let mut buffer = Buffer::with_capacity(2 * 1024);
+
+        // Parse the header
+        let mut fsm = ObjectContainerFileHeaderStateMachine::new(schemata);
+        let header = loop {
+            // Fill the buffer
+            let n = reader
+                .read(buffer.space())
+                .await
+                .map_err(Details::ReadHeader)?;
+            if n == 0 {
+                return Err(Details::ReadHeader(std::io::ErrorKind::UnexpectedEof.into()).into());
+            }
+            buffer.fill(n);
+
+            // Start/continue the state machine
+            match fsm.parse(&mut buffer)? {
+                StateMachineControlFlow::NeedMore(new_fsm) => fsm = new_fsm,
+                StateMachineControlFlow::Done(header) => break header,
+            }
+        };
+
+        let tape = CommandTape::build_from_schema(&header.schema, &header.names)?;
+
+        let reader_schema = if let Some(schema) = reader_schema
+            && schema != &header.schema
+        {
+            Some(schema)
+        } else {
+            None
+        };
+
+        Ok(Self {
+            reader_schema,
+            fsm: Some(ObjectContainerFileBodyStateMachine::new(
+                tape,
+                header.sync,
+                header.codec,
+            )),
+            header,
+            reader,
+            buffer,
+        })
+    }
+
+    /// Get the next object in the file.
+    async fn next_object(&mut self) -> Option<Result<Vec<ItemRead>, Error>> {
+        if let Some(mut fsm) = self.fsm.take() {
+            loop {
+                match fsm.parse(&mut self.buffer) {
+                    Ok(StateMachineControlFlow::NeedMore(new_fsm)) => {
+                        fsm = new_fsm;
+                        let n = match self.reader.read(self.buffer.space()).await {
+                            Ok(0) => {
+                                return Some(Err(Details::ReadIntoBuf(
+                                    std::io::ErrorKind::UnexpectedEof.into(),
+                                )
+                                .into()));
+                            }
+                            Ok(n) => n,
+                            Err(e) => return Some(Err(Details::ReadIntoBuf(e).into())),
+                        };
+                        self.buffer.fill(n);
+                    }
+                    Ok(StateMachineControlFlow::Done(Some((object, fsm)))) => {
+                        self.fsm.replace(fsm);
+                        return Some(Ok(object));
+                    }
+                    Ok(StateMachineControlFlow::Done(None)) => {
+                        return None;
+                    }
+                    Err(e) => {
+                        return Some(Err(e));
+                    }
+                }
+            }
+        }
+        None
+    }
+
+    /// Stream the file's contents, deserializing each object directly to a `T`.
+    ///
+    /// Panics if a reader schema was set; schema resolution is not supported with Serde.
+    pub async fn stream_serde<'b, T: Deserialize<'b>>(
+        &mut self,
+    ) -> impl Stream<Item = Result<T, Error>> {
+        assert!(
+            self.reader_schema.is_none(),
+            "Reader schema is not supported with Serde!"
+        );
+        try_stream! {
+            while let Some(object) = self.next_object().await {
+                let mut tape = object?;
+                yield deserialize_from_tape(&mut tape, &self.header.schema)?;
+            }
+        }
+    }
+
+    /// Stream the file's contents as [`Value`]s, resolving against the reader schema if one is set.
+    pub async fn stream(&mut self) -> impl Stream<Item = Result<Value, Error>> {
+        try_stream! {
+            while let Some(object) = self.next_object().await {
+                let mut tape = object?;
+
+                let value = value_from_tape(&mut tape, &self.header.schema, &self.header.names)?;
+                let resolved = if let Some(schema) = self.reader_schema {
+                    value.resolve_internal(schema, &self.header.names, &None, &None)?
+                } else {
+                    value
+                };
+                yield resolved;
+            }
+        }
+    }
+}
+
+/// Decode a raw Avro datum using the provided [`Schema`].
+///
+/// In case a reader [`Schema`] is provided, schema resolution will be performed.
+///
+/// **NOTE** This function is very niche and does NOT take care of reading the header and
+/// consecutive data blocks. Use [`Reader`] if you just want to read an Avro encoded file.
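+///
+/// A minimal sketch; the `"long"` schema and encoded bytes are illustrative, and
+/// an executor such as `futures::executor::block_on` is assumed:
+///
+/// ```no_run
+/// # use apache_avro::{Schema, not_sync::from_avro_datum};
+/// # futures::executor::block_on(async {
+/// let schema = Schema::parse_str(r#""long""#).unwrap();
+/// let encoded: Vec<u8> = vec![54]; // zig-zag encoding of 27
+/// let value = from_avro_datum(&schema, &mut &encoded[..], None).await.unwrap();
+/// println!("{value:?}");
+/// # });
+/// ```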
+/// +/// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +pub async fn from_avro_datum_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, +) -> AvroResult { + from_avro_datum_reader_schemata( + writer_schema, + writer_schemata, + reader, + reader_schema, + Vec::new(), + ) + .await +} + +/// Decode a raw Avro datum using the provided [`Schema`] and schemata. +/// +/// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +pub async fn from_avro_datum_reader_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, + reader_schemata: Vec<&Schema>, +) -> AvroResult { + let mut names = HashMap::new(); + if writer_schemata.is_empty() { + resolve_names(writer_schema, &mut names, &None)?; + } else { + resolve_names_with_schemata(&writer_schemata, &mut names, &None)?; + } + + let tape = CommandTape::build_from_schema(writer_schema, &names)?; + + // Read a maximum of 2Kb per read + let mut buffer = Buffer::with_capacity(2 * 1024); + let mut fsm = DatumStateMachine::new(tape); + let value = loop { + // Fill the buffer + let n = reader + .read(buffer.space()) + .await + .map_err(Details::ReadIntoBuf)?; + if n == 0 { + return Err(Details::ReadIntoBuf(std::io::ErrorKind::UnexpectedEof.into()).into()); + } + buffer.fill(n); + + match fsm.parse(&mut buffer)? 
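+        // `parse` consumes the machine: it either hands it back via `NeedMore` so the
+        // caller can refill the buffer, or yields the finished tape via `Done`.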
{ + StateMachineControlFlow::NeedMore(new_fsm) => { + fsm = new_fsm; + } + StateMachineControlFlow::Done(mut tape) => { + break value_from_tape(&mut tape, writer_schema, &names)?; + } + } + }; + match reader_schema { + Some(schema) => { + if reader_schemata.is_empty() { + value.resolve(schema) + } else { + value.resolve_schemata(schema, reader_schemata) + } + } + None => Ok(value), + } +} diff --git a/avro/src/state_machines/reading/block.rs b/avro/src/state_machines/reading/block.rs new file mode 100644 index 00000000..cf03f6f2 --- /dev/null +++ b/avro/src/state_machines/reading/block.rs @@ -0,0 +1,110 @@ +use oval::Buffer; + +use crate::{ + Error, + error::Details, + state_machines::reading::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, datum::DatumStateMachine, + decode_zigzag_buffer, + }, +}; + +/// Are we currently parsing an object or just finished/reading a block header +enum TapeOrFsm { + Tape(Vec), + Fsm(DatumStateMachine), +} + +pub struct BlockStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, + left_in_current_block: usize, + need_to_read_block_byte_size: bool, +} + +impl BlockStateMachine { + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + // This clone is *cheap* + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + left_in_current_block: 0, + need_to_read_block_byte_size: false, + } + } +} + +impl StateMachine for BlockStateMachine { + type Output = Vec; + fn parse( + mut self, + buffer: &mut Buffer, + ) -> Result, Error> { + loop { + match self.tape_or_fsm { + TapeOrFsm::Tape(mut tape) => { + // If we finished the last block (or are newly created) read the block info + if self.left_in_current_block == 0 { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Need to read the block byte size when block is negative + self.need_to_read_block_byte_size = block.is_negative(); + + // We do the rest with the absolute block size + let abs_block = usize::try_from(block.unsigned_abs()) + .map_err(|e| Details::ConvertU64ToUsize(e, block.unsigned_abs()))?; + self.left_in_current_block = abs_block; + tape.push(ItemRead::Block(abs_block)); + + // Done parsing the blocks + if abs_block == 0 { + return Ok(StateMachineControlFlow::Done(tape)); + } + } + + // If the block length was negative we need to read the block size + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Make sure the value is sane + // TODO: Maybe use safe_len here? + let _ = usize::try_from(block) + .map_err(|e| Details::ConvertI64ToUsize(e, block))?; + + // This is not necessary, as it will be overwritten before being read again + // but it does show the intent more clearly + self.need_to_read_block_byte_size = false; + } + + // We've either finished reading the block header or the last object was read and + // left_in_current_block is not zero + self.tape_or_fsm = TapeOrFsm::Fsm(DatumStateMachine::new_with_tape( + self.command_tape.clone(), + tape, + )) + } + TapeOrFsm::Fsm(fsm) => { + // (Continue) reading the object + match fsm.parse(buffer)? 
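+                    // Delegate to the datum state machine: `NeedMore` suspends the whole
+                    // block machine until the caller refills the buffer, while `Done`
+                    // returns the tape and counts one object off the current block.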
{
+                        StateMachineControlFlow::NeedMore(fsm) => {
+                            self.tape_or_fsm = TapeOrFsm::Fsm(fsm);
+                            return Ok(StateMachineControlFlow::NeedMore(self));
+                        }
+                        StateMachineControlFlow::Done(tape) => {
+                            self.tape_or_fsm = TapeOrFsm::Tape(tape);
+                            self.left_in_current_block -= 1;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/avro/src/state_machines/reading/bytes.rs b/avro/src/state_machines/reading/bytes.rs
new file mode 100644
index 00000000..68bc529c
--- /dev/null
+++ b/avro/src/state_machines/reading/bytes.rs
@@ -0,0 +1,71 @@
+use oval::Buffer;
+
+use crate::{
+    error::Details,
+    state_machines::reading::{StateMachine, StateMachineControlFlow, decode_zigzag_buffer},
+};
+
+use super::StateMachineResult;
+
+// TODO: Also make a String-specific state machine. That would allow validating the UTF-8
+// while parsing, which would make the parser fail faster on large invalid strings.
+// TODO: This state machine could also produce inline strings (smol_str) for strings smaller
+// than `size_of::<String>()`, and use some extra bits to store well-known strings
+// like avro.schema and avro.codec as fixed strings.
+
+#[derive(Default)]
+pub struct BytesStateMachine {
+    length: Option<usize>,
+    data: Vec<u8>,
+}
+
+impl BytesStateMachine {
+    pub fn new() -> Self {
+        Self {
+            length: None,
+            data: Vec::new(),
+        }
+    }
+
+    pub fn new_with_length(length: usize) -> Self {
+        Self {
+            length: Some(length),
+            data: Vec::with_capacity(length),
+        }
+    }
+}
+
+impl StateMachine for BytesStateMachine {
+    // This is a Vec<u8> instead of a Box<[u8]> as it's easier to create a String from a Vec
+    type Output = Vec<u8>;
+
+    fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult<Self, Self::Output> {
+        if self.length.is_none() {
+            let Some(length) = decode_zigzag_buffer(buffer)? else {
+                // Not enough data left in the buffer to finish decoding the length varint
+                return Ok(StateMachineControlFlow::NeedMore(self));
+            };
+            let length =
+                usize::try_from(length).map_err(|e| Details::ConvertI64ToUsize(e, length))?;
+            self.length = Some(length);
+            self.data.reserve_exact(length);
+        }
+        // `length` was just set by the branch above, which returns early if the varint
+        // could not be decoded yet.
+ let Some(length) = self.length else { + unreachable!() + }; + + // How much more data is needed + let remaining = length - self.data.len(); + // How much of that is available in the buffer + let available = remaining.min(buffer.available_data()); + self.data.extend_from_slice(&buffer.data()[..available]); + buffer.consume(available); + if remaining - available == 0 { + Ok(StateMachineControlFlow::Done(self.data)) + } else { + Ok(StateMachineControlFlow::NeedMore(self)) + } + } +} diff --git a/avro/src/state_machines/reading/codec.rs b/avro/src/state_machines/reading/codec.rs new file mode 100644 index 00000000..b823ee6c --- /dev/null +++ b/avro/src/state_machines/reading/codec.rs @@ -0,0 +1,180 @@ +use crate::{ + Codec, + state_machines::reading::{StateMachine, StateMachineControlFlow, StateMachineResult}, +}; +use oval::Buffer; + +pub struct CodecStateMachine { + sub_machine: Option, + codec: Decoder, + buffer: Buffer, +} + +impl CodecStateMachine { + pub fn new(sub_machine: T, codec: Codec) -> Self { + Self { + sub_machine: Some(sub_machine), + codec: codec.into(), + buffer: Buffer::with_capacity(1024), + } + } + + pub fn reset(&mut self, sub_machine: T) { + self.buffer.reset(); + self.sub_machine = Some(sub_machine); + self.codec.reset(); + } +} + +pub enum Decoder { + Null, + Deflate(Box), + #[cfg(feature = "snappy")] + Snappy(snap::raw::Decoder), + #[cfg(feature = "zstandard")] + Zstandard(zstd::stream::raw::Decoder<'static>), + #[cfg(feature = "bzip")] + Bzip2(bzip2::Decompress), + #[cfg(feature = "xz")] + Xz(xz2::stream::Stream), +} + +impl From for Decoder { + fn from(value: Codec) -> Self { + match value { + Codec::Null => Self::Null, + Codec::Deflate(_) => { + use miniz_oxide::{DataFormat::Raw, inflate::stream::InflateState}; + Self::Deflate(InflateState::new_boxed(Raw)) + } + #[cfg(feature = "snappy")] + Codec::Snappy => Self::Snappy(snap::raw::Decoder::new()), + #[cfg(feature = "zstandard")] + Codec::Zstandard(_) => Self::Zstandard(zstd::stream::raw::Decoder::new().unwrap()), + #[cfg(feature = "bzip")] + Codec::Bzip2(_) => Self::Bzip2(bzip2::Decompress::new(false)), + #[cfg(feature = "xz")] + Codec::Xz(_) => Self::Xz(xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()), + } + } +} + +impl Decoder { + pub fn reset(&mut self) { + match self { + Decoder::Null => {} + Decoder::Deflate(decoder) => { + decoder.reset_as(miniz_oxide::inflate::stream::MinReset); + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => {} // No reset needed + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => zstd::stream::raw::Operation::reinit(decoder).unwrap(), + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace(decoder, bzip2::Decompress::new(false)); + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace( + decoder, + xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), + ); + } + } + } +} + +impl StateMachine for CodecStateMachine { + type Output = (T::Output, Self); + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + let buffer = match &mut self.codec { + Decoder::Null => buffer, + Decoder::Deflate(decoder) => { + use miniz_oxide::{MZFlush, StreamResult, inflate::stream::inflate}; + let StreamResult { + bytes_consumed, + bytes_written, + status, + } = inflate(decoder, buffer.data(), self.buffer.space(), MZFlush::None); + status.unwrap(); + buffer.consume(bytes_consumed); + 
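+                // Publish the inflated bytes into the side buffer; the sub-machine below
+                // parses from this decompressed view instead of the raw input.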
self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => { + todo!("Snap has no streaming decoder") + } + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => { + use zstd::stream::raw::{Operation, Status}; + let Status { + bytes_read, + bytes_written, + .. + } = decoder + .run_on_buffers(buffer.data(), self.buffer.space()) + .map_err(crate::error::Details::ZstdDecompress)?; + buffer.consume(bytes_read); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .decompress(buffer.data(), self.buffer.space()) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + use xz2::stream::Action::Run; + + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .process(buffer.data(), self.buffer.space(), Run) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + }; + match self + .sub_machine + .take() + .expect("CodecStateMachine was not reset!") + .parse(buffer)? + { + StateMachineControlFlow::NeedMore(fsm) => { + self.sub_machine = Some(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(result) => { + Ok(StateMachineControlFlow::Done((result, self))) + } + } + } +} diff --git a/avro/src/state_machines/reading/commands.rs b/avro/src/state_machines/reading/commands.rs new file mode 100644 index 00000000..734d5c04 --- /dev/null +++ b/avro/src/state_machines/reading/commands.rs @@ -0,0 +1,646 @@ +use crate::{ + Error, Schema, + error::Details, + schema::{ + ArraySchema, DecimalSchema, EnumSchema, FixedSchema, MapSchema, Name, Names, RecordSchema, + UnionSchema, + }, + state_machines::reading::{ + ItemRead, SubStateMachine, block::BlockStateMachine, bytes::BytesStateMachine, + datum::DatumStateMachine, union::UnionStateMachine, + }, +}; +use std::{collections::HashMap, ops::Range, sync::Arc}; + +/// The next item type that should be read. 
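+///
+/// Produced by [`CommandTape::command`] and turned into a runnable [`SubStateMachine`]
+/// via [`ToRead::into_state_machine`].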
+#[must_use] +pub enum ToRead { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Enum, + Ref(CommandTape), + Fixed(usize), + Block(CommandTape), + Union { + variants: CommandTape, + num_variants: usize, + }, +} + +impl ToRead { + pub fn into_state_machine(self, read: Vec) -> SubStateMachine { + match self { + ToRead::Null => SubStateMachine::Null(read), + ToRead::Boolean => SubStateMachine::Bool(read), + ToRead::Int => SubStateMachine::Int(read), + ToRead::Long => SubStateMachine::Long(read), + ToRead::Float => SubStateMachine::Float(read), + ToRead::Double => SubStateMachine::Double(read), + ToRead::Enum => SubStateMachine::Enum(read), + ToRead::Bytes => SubStateMachine::Bytes { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::String => SubStateMachine::String { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::Fixed(length) => SubStateMachine::Bytes { + fsm: BytesStateMachine::new_with_length(length), + read, + }, + ToRead::Ref(commands) => { + SubStateMachine::Object(DatumStateMachine::new_with_tape(commands, read)) + } + ToRead::Block(commands) => { + SubStateMachine::Block(BlockStateMachine::new_with_tape(commands, read)) + } + ToRead::Union { + variants, + num_variants, + } => SubStateMachine::Union(UnionStateMachine::new_with_tape( + variants, + num_variants, + read, + )), + } + } +} + +impl std::fmt::Debug for ToRead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int => write!(f, "Int"), + Self::Long => write!(f, "Long"), + Self::Float => write!(f, "Float"), + Self::Double => write!(f, "Double"), + Self::Bytes => write!(f, "Bytes"), + Self::String => write!(f, "String"), + Self::Enum => write!(f, "Enum"), + // We don't show the Ref command as that could recurse forever + Self::Ref(_) => write!(f, "Ref<...>"), + Self::Fixed(arg0) => write!(f, "Fixed<{arg0}>"), + Self::Block(arg0) => f.debug_tuple("Block").field(arg0).finish(), + Self::Union { variants, .. } => f.debug_tuple("Union").field(variants).finish(), + } + } +} + +/// A section of a tape of commands. +/// +/// This has a reference to the entire tape, so that references to types (for Union,Map,Array) can be resolved. +#[derive(Clone, PartialEq)] +#[must_use] +pub struct CommandTape { + inner: Arc<[u8]>, + read_range: Range, +} + +impl CommandTape { + pub const NULL: u8 = 0; + pub const BOOLEAN: u8 = 1; + pub const INT: u8 = 2; + pub const LONG: u8 = 3; + pub const FLOAT: u8 = 4; + pub const DOUBLE: u8 = 5; + pub const BYTES: u8 = 6; + pub const STRING: u8 = 7; + pub const ENUM: u8 = 8; + /// A fixed amount of bytes. + /// + /// If the amount of bytes is smaller than or equal to `0xF`, the amount is stored in the four + /// most significant bits of the byte. Otherwise, it's stored as a native endian usize directly + /// after the command byte. + pub const FIXED: u8 = 9; + /// A block based format follows (i.e. Map or Array). + /// + /// The command sequence of the type in the block follows immediately after the command byte. + /// The length of the sequence is stored in the most significant four bits of the command byte. + /// If the sequence is larger than `0xF`, then either the entire sequence or part of it is + /// put behind a [`Self::REF`]. + pub const BLOCK: u8 = 10; + pub const UNION: u8 = 11; + /// A reference to a command sequence somewhere else in the tape. 
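+    ///
+    /// As a worked example (assuming a little-endian 64-bit target): a reference to a
+    /// 3-command sequence at tape offset 40 is encoded as the single byte
+    /// `REF | (3 << 4)` = `0x3C`, followed by `40usize.to_ne_bytes()`.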
+ /// + /// If the length of the sequence is smaller than or equal to `0xF`, the length is stored in the + /// four most significant bits of the byte. Otherwise, it's stored as a native endian usize + /// directly after the command byte. After the length follows the offset as a native endian + /// usize. + pub const REF: u8 = 12; + /// Skip the next `n` commands. + /// + /// A SKIP command is not counted as a command. + /// + /// If `n` is smaller than or equal to `0xF`, the amount is stored in the four most significant + /// bits of the byte. Otherwise, it's stored as a native endian usize directly after the command + /// byte. + pub const SKIP: u8 = 13; + + /// Create a new tape that will be read from start to end. + pub fn new(command_tape: Arc<[u8]>) -> Self { + let length = command_tape.len(); + Self { + inner: command_tape, + read_range: 0..length, + } + } + + pub fn build_from_schema(schema: &Schema, names: &Names) -> Result { + CommandTapeBuilder::build(schema, names) + } + + /// Check if the section of the tape we're reading is finished. + pub fn is_finished(&self) -> bool { + self.read_range.is_empty() + } + + /// Extract a part from the tape to give to a sub-state machine. + /// + /// The tape will run from offset for the given amount of commands. + pub fn extract(&self, offset: usize, commands: usize) -> Self { + let mut temp = Self { + inner: self.inner.clone(), + read_range: offset..self.inner.len(), + }; + temp.skip(commands); + let max_index = temp.read_range.next().unwrap_or(self.inner.len()); + + assert!( + max_index <= self.inner.len(), + "Reference is (partly) outside the tape" + ); + Self { + inner: self.inner.clone(), + read_range: offset..max_index, + } + } + + /// Extract many parts from the tape to give to the Union state machine. + /// + /// The tapes will run from start to end (inclusive). + pub fn extract_many(&self, parts: &[(usize, usize)]) -> Box<[Self]> { + let mut vec = Vec::with_capacity(parts.len()); + for &(start, end) in parts { + vec.push(self.extract(start, end)); + } + vec.into_boxed_slice() + } + + /// Read an array of bytes from the tape. + fn read_array(&mut self) -> [u8; N] { + let start = self.read_range.next().expect("Read past the limit"); + let end = self.read_range.nth(N - 2).expect("Read past the limit"); + self.inner[start..=end].try_into().expect("Unreachable!") + } + + fn read_inline_or(&mut self, byte: u8) -> usize { + if byte >> 4 != 0 { + // Length is stored inline + (byte >> 4) as usize + } else { + usize::from_ne_bytes(self.read_array()) + } + } + + /// Get the next command from the tape. + /// + /// Will return `None` if exhausted. + pub fn command(&mut self) -> Option { + if let Some(position) = self.read_range.next() { + let byte = self.inner[position]; + match byte & 0xF { + Self::NULL => Some(ToRead::Null), + Self::BOOLEAN => Some(ToRead::Boolean), + Self::INT => Some(ToRead::Int), + Self::LONG => Some(ToRead::Long), + Self::FLOAT => Some(ToRead::Float), + Self::DOUBLE => Some(ToRead::Double), + Self::BYTES => Some(ToRead::Bytes), + Self::STRING => Some(ToRead::String), + Self::ENUM => Some(ToRead::Enum), + Self::FIXED => Some(ToRead::Fixed(self.read_inline_or(byte))), + Self::BLOCK => { + // ToRead::Block + let size = (byte >> 4) as usize; + self.skip(size); + Some(ToRead::Block(self.extract(position + 1, size))) + } + Self::UNION => { + // How many variants are there? 
+ let num_variants = self.read_inline_or(byte); + + // Skip over the union variants while keeping track of their start and end + // so we can easily create the command tape + let start = self.read_range.start; + self.skip(num_variants); + let end = self.read_range.start; + + // Create the command tape from the previously tracked start and end + let mut tape = self.clone(); + tape.read_range.start = start; + tape.read_range.end = end; + + Some(ToRead::Union { + variants: tape, + num_variants, + }) + } + Self::REF => { + let size = self.read_inline_or(byte); + let offset = usize::from_ne_bytes(self.read_array()); + Some(ToRead::Ref(self.extract(offset, size))) + } + Self::SKIP => { + // Read how many commands to skip and skip them + let commands = self.read_inline_or(byte); + self.skip(commands); + + // Return the next command + self.command() + } + _ => unreachable!(), // TODO: There is room here to specialize certain types, like a Union of Null and some other type + } + } else { + None + } + } + + /// Skip `amount` commands. + /// + /// If a command contains subcommands, these will also be skipped. + /// + /// # Returns + /// `None` if it read past the end of the tape + pub(crate) fn skip(&mut self, mut amount: usize) -> Option<()> { + let mut i = 0; + while i < amount { + let position = self.read_range.next()?; + let byte = self.inner[position]; + match byte & 0xF { + CommandTape::BOOLEAN + | CommandTape::INT + | CommandTape::LONG + | CommandTape::FLOAT + | CommandTape::DOUBLE + | CommandTape::BYTES + | CommandTape::STRING + | CommandTape::ENUM + | CommandTape::NULL => {} + CommandTape::FIXED => { + let _size = self.read_inline_or(byte); + } + CommandTape::REF => { + let _size = self.read_inline_or(byte); + let _offset = usize::from_ne_bytes(self.read_array()); + } + CommandTape::UNION | CommandTape::BLOCK | CommandTape::SKIP => { + // These commands can inline other commands, so add them to the skip list + let num_variants = self.read_inline_or(byte); + amount += num_variants; + + // Skip does not count as a command, but we do increment `i` so we compensate + // for that by incrementing the amount + if byte & 0xF == CommandTape::SKIP { + amount += 1; + } + } + _ => unreachable!(), + } + i += 1; + } + Some(()) + } +} + +impl std::fmt::Debug for CommandTape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut c = self.clone(); + + write!(f, "CommandTape: ")?; + let mut list = f.debug_list(); + while let Some(command) = c.command() { + list.entry(&command); + } + list.finish() + } +} + +struct CommandTapeBuilder<'a> { + tape: Vec, + references: HashMap<&'a Name, (usize, usize)>, + names: &'a Names, +} + +impl<'a> CommandTapeBuilder<'a> { + pub fn new(names: &'a Names) -> Self { + Self { + tape: Vec::new(), + references: HashMap::new(), + names, + } + } + + fn add_schema(&mut self, schema: &'a Schema, inline_up_to: usize) -> Result { + match schema { + Schema::Null => { + self.tape.push(CommandTape::NULL); + Ok(1) + } + Schema::Boolean => { + self.tape.push(CommandTape::BOOLEAN); + Ok(1) + } + Schema::Int | Schema::Date | Schema::TimeMillis => { + self.tape.push(CommandTape::INT); + Ok(1) + } + Schema::Long + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + self.tape.push(CommandTape::LONG); + Ok(1) + } + Schema::Float => { + self.tape.push(CommandTape::FLOAT); + Ok(1) + } + Schema::Double => { + 
self.tape.push(CommandTape::DOUBLE); + Ok(1) + } + Schema::Bytes | Schema::BigDecimal => { + self.tape.push(CommandTape::BYTES); + Ok(1) + } + Schema::String | Schema::Uuid => { + self.tape.push(CommandTape::STRING); + Ok(1) + } + Schema::Array(ArraySchema { items, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + let commands = self.add_schema(items, 16)?; + self.tape[block_offset] = CommandTape::BLOCK | (commands << 4) as u8; + Ok(1) + } + Schema::Map(MapSchema { types, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + self.tape.push(CommandTape::STRING); + let commands = self.add_schema(types, 15)?; + self.tape[block_offset] = CommandTape::BLOCK | ((commands + 1) << 4) as u8; + Ok(1) + } + Schema::Union(UnionSchema { schemas, .. }) => { + let schema_len = schemas.len(); + if 0 < schema_len && schema_len <= 0xF { + self.tape.push(CommandTape::UNION | (schema_len << 4) as u8); + } else { + self.tape.push(CommandTape::UNION); + self.tape.extend_from_slice(&schema_len.to_ne_bytes()); + } + for schema in schemas { + self.add_schema(schema, 1)?; + } + Ok(1) + } + Schema::Record(RecordSchema { name, fields, .. }) => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if fields.is_empty() { + panic!("Record has no fields! {schema:?}"); + } else { + let commands = fields.len(); + if commands > inline_up_to { + // If this record is larger than the amount we're allowed to inline, inject + // a SKIP command. + if commands <= 0xF { + self.tape.push(CommandTape::SKIP | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::SKIP); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + } + let offset = self.tape.len(); + self.references.insert(name, (offset, commands)); + for field in fields { + let _commands = self.add_schema(&field.schema, 1)?; + } + if commands > inline_up_to { + // Now refer back to the skip block + self.add_reference(offset, commands); + Ok(1) + } else { + Ok(commands) + } + } + } + Schema::Enum(EnumSchema { name, .. }) => { + let offset = self.tape.len(); + let commands = 1; + self.tape.push(CommandTape::ENUM); + self.references.insert(name, (offset, commands)); + Ok(1) + } + Schema::Fixed(FixedSchema { name, size, .. }) => { + let offset = self.tape.len(); + if 0 < *size && *size <= 0xF { + self.tape.push(CommandTape::FIXED | (*size << 4) as u8); + } else { + self.tape.push(CommandTape::FIXED); + self.tape.extend_from_slice(&size.to_ne_bytes()); + } + self.references.entry(name).or_insert((offset, 1)); + Ok(1) + } + Schema::Decimal(DecimalSchema { inner, .. 
}) => self.add_schema(inner, inline_up_to), + Schema::Duration => { + self.tape.push(CommandTape::FIXED | 12 << 4); + Ok(1) + } + Schema::Ref { name } => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if let Some(schema) = self.names.get(name).as_ref() { + self.add_schema(schema, inline_up_to) + } else { + Err(Details::SchemaResolutionError(name.clone()).into()) + } + } + } + } + + fn add_reference(&mut self, offset: usize, commands: usize) { + if commands == 0 { + self.tape.push(CommandTape::NULL); + } else if commands <= 0xF { + self.tape.push(CommandTape::REF | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::REF); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + self.tape.extend_from_slice(&offset.to_ne_bytes()); + } + + pub fn build(schema: &Schema, names: &'a Names) -> Result { + let mut builder = Self::new(names); + + builder.add_schema(schema, usize::MAX)?; + + let tape_len = builder.tape.len(); + Ok(CommandTape { + inner: Arc::from(builder.tape), + read_range: 0..tape_len, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn command_tape_simple() { + assert_eq!( + CommandTape::build_from_schema(&Schema::Null, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::NULL] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Boolean, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::BOOLEAN] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Int, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Date, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Long, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Float, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::FLOAT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Double, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::DOUBLE] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Bytes, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + 
&[CommandTape::BYTES] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::String, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Uuid, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + } +} diff --git a/avro/src/state_machines/reading/datum.rs b/avro/src/state_machines/reading/datum.rs new file mode 100644 index 00000000..2e9a4beb --- /dev/null +++ b/avro/src/state_machines/reading/datum.rs @@ -0,0 +1,80 @@ +use oval::Buffer; + +use super::StateMachineResult; +use crate::state_machines::reading::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, SubStateMachine, +}; + +enum TapeOrFsm { + Tape(Vec), + Fsm(Box), +} + +pub struct DatumStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, +} + +impl DatumStateMachine { + /// Create a new state machine that reads a datum from the commands. + pub fn new(command_tape: CommandTape) -> Self { + Self::new_with_tape(command_tape, Vec::new()) + } + + /// Create a new state machine that appends to the tape (the tape is returned on completion). + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + } + } +} + +impl StateMachine for DatumStateMachine { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + // While there's data and commands to process we keep progressing the state machines + while !buffer.data().is_empty() { + match self.tape_or_fsm { + TapeOrFsm::Fsm(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm)); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(read) => { + self.tape_or_fsm = TapeOrFsm::Tape(read); + } + }, + TapeOrFsm::Tape(tape) => { + if let Some(command) = self.command_tape.command() { + let fsm = command.into_state_machine(tape); + // This is a duplicate of the TapeOrFsm::Fsm logic, but saves us an allocation + // by doing it immediately. + match fsm.parse(buffer)? 
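+                        // The sub-machine is only boxed when it actually has to suspend;
+                        // when the buffer already holds enough bytes, nothing is allocated.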
{
+                        StateMachineControlFlow::NeedMore(fsm) => {
+                            self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm));
+                            return Ok(StateMachineControlFlow::NeedMore(self));
+                        }
+                        StateMachineControlFlow::Done(read) => {
+                            self.tape_or_fsm = TapeOrFsm::Tape(read);
+                        }
+                    }
+                } else {
+                    self.tape_or_fsm = TapeOrFsm::Tape(tape);
+                    break;
+                }
+            }
+        }
+    }
+
+    // Check if we're completely finished or need more data
+    match (self.tape_or_fsm, self.command_tape.is_finished()) {
+        (TapeOrFsm::Tape(read), true) => Ok(StateMachineControlFlow::Done(read)),
+        (tape_or_fsm, _) => {
+            self.tape_or_fsm = tape_or_fsm;
+            Ok(StateMachineControlFlow::NeedMore(self))
+        }
+    }
+    }
+}
diff --git a/avro/src/state_machines/reading/error.rs b/avro/src/state_machines/reading/error.rs
new file mode 100644
index 00000000..12bcefd8
--- /dev/null
+++ b/avro/src/state_machines/reading/error.rs
@@ -0,0 +1,16 @@
+use crate::{Schema, state_machines::reading::ItemRead};
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum ValueFromTapeError {
+    #[error("Unexpected end of tape while building Value")]
+    UnexpectedEndOfTape,
+    #[error(
+        "Mismatch between tape and schema while building Value: schema {schema}, tape: {item:?}"
+    )]
+    TapeSchemaMismatch { schema: Schema, item: ItemRead },
+    #[error(
+        "Mismatch between tape and schema while building Value: Schema::Fixed expected {expected} bytes, but tape had {actual}"
+    )]
+    TapeSchemaMismatchFixed { expected: usize, actual: usize },
+}
diff --git a/avro/src/state_machines/reading/mod.rs b/avro/src/state_machines/reading/mod.rs
new file mode 100644
index 00000000..465642ac
--- /dev/null
+++ b/avro/src/state_machines/reading/mod.rs
@@ -0,0 +1,1151 @@
+use crate::{
+    Decimal, Duration, Error, Schema,
+    bigdecimal::deserialize_big_decimal,
+    error::Details,
+    schema::{
+        ArraySchema, EnumSchema, FixedSchema, MapSchema, Name, Names, Namespace, RecordSchema,
+        ResolvedSchema, UnionSchema,
+    },
+    state_machines::reading::{
+        block::BlockStateMachine, bytes::BytesStateMachine, commands::CommandTape,
+        datum::DatumStateMachine, error::ValueFromTapeError, union::UnionStateMachine,
+    },
+    types::Value,
+    util::decode_variable,
+};
+use oval::Buffer;
+use serde::Deserialize;
+use std::{borrow::Borrow, collections::HashMap, io::Read, ops::Deref, str::FromStr};
+use uuid::Uuid;
+
+pub mod async_impl;
+pub mod block;
+pub mod bytes;
+pub mod codec;
+mod commands;
+pub mod datum;
+pub mod error;
+mod object_container_file;
+pub mod sync;
+mod union;
+
+pub trait StateMachine: Sized {
+    type Output: Sized;
+
+    /// Start/continue the state machine.
+    ///
+    /// Implementations must not return until they can make no further progress.
+    fn parse(self, buffer: &mut Buffer) -> StateMachineResult<Self, Self::Output>;
+}
+
+/// Indicates whether the state machine has completed or needs to be polled again.
+#[must_use]
+pub enum StateMachineControlFlow<StateMachine, Output> {
+    /// The state machine needs more data before it can continue.
+    NeedMore(StateMachine),
+    /// The state machine is done and the result is returned.
+    Done(Output),
+}
+
+pub type StateMachineResult<StateMachine, Output> =
+    Result<StateMachineControlFlow<StateMachine, Output>, Error>;
+
+/// The sub state machine that is currently being driven.
+///
+/// The `Int`, `Long`, `Float`, `Double`, and `Enum` state machines don't carry any state, as
+/// they don't consume the buffer if there are not enough bytes. This means that the only
+/// thing these state machines keep track of is which type we're actually decoding.
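+///
+/// Every machine in this module is driven the same way. A minimal sketch of that calling
+/// convention (hypothetical `drive` helper and blocking `source`; EOF and short reads are
+/// not handled, and the import path assumes this module is public):
+///
+/// ```no_run
+/// # use apache_avro::{Error, state_machines::reading::{StateMachine, StateMachineControlFlow}};
+/// # use oval::Buffer;
+/// # use std::io::Read;
+/// fn drive<M: StateMachine>(mut fsm: M, source: &mut impl Read) -> Result<M::Output, Error> {
+///     let mut buffer = Buffer::with_capacity(2 * 1024);
+///     loop {
+///         // Refill the buffer, then let the machine consume as much as it can
+///         let n = source.read(buffer.space()).unwrap();
+///         buffer.fill(n);
+///         match fsm.parse(&mut buffer)? {
+///             StateMachineControlFlow::NeedMore(next) => fsm = next,
+///             StateMachineControlFlow::Done(out) => return Ok(out),
+///         }
+///     }
+/// }
+/// ```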
+pub enum SubStateMachine { + Null(Vec), + Bool(Vec), + Int(Vec), + Long(Vec), + Float(Vec), + Double(Vec), + Enum(Vec), + Bytes { + fsm: BytesStateMachine, + read: Vec, + }, + String { + fsm: BytesStateMachine, + read: Vec, + }, + Block(BlockStateMachine), + Object(DatumStateMachine), + Union(UnionStateMachine), +} + +impl StateMachine for SubStateMachine { + type Output = Vec; + + fn parse(self, buffer: &mut Buffer) -> StateMachineResult { + match self { + SubStateMachine::Null(mut read) => { + read.push(ItemRead::Null); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Bool(mut read) => { + let mut byte = [0; 1]; + buffer + .read_exact(&mut byte) + .expect("Unreachable! Buffer is not empty"); + match byte { + [0] => read.push(ItemRead::Boolean(false)), + [1] => read.push(ItemRead::Boolean(true)), + [byte] => return Err(Details::BoolValue(byte).into()), + } + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Int(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Int(read))); + }; + let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + read.push(ItemRead::Int(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Long(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Long(read))); + }; + read.push(ItemRead::Long(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Float(mut read) => { + let Some(bytes) = buffer.data().first_chunk().copied() else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Float(read))); + }; + buffer.consume(4); + read.push(ItemRead::Float(f32::from_le_bytes(bytes))); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Double(mut read) => { + let Some(bytes) = buffer.data().first_chunk().copied() else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Double(read))); + }; + buffer.consume(8); + read.push(ItemRead::Double(f64::from_le_bytes(bytes))); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Enum(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Enum(read))); + }; + // TODO: Wrong error + let n = u32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + read.push(ItemRead::Enum(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Bytes { fsm, mut read } => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::Bytes { fsm, read })) + } + StateMachineControlFlow::Done(bytes) => { + read.push(ItemRead::Bytes(bytes)); + Ok(StateMachineControlFlow::Done(read)) + } + }, + SubStateMachine::String { fsm, mut read } => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::String { + fsm, + read, + })) + } + StateMachineControlFlow::Done(bytes) => { + let string = String::from_utf8(bytes).map_err(Details::ConvertToUtf8)?; + read.push(ItemRead::String(string)); + Ok(StateMachineControlFlow::Done(read)) + } + }, + SubStateMachine::Block(fsm) => match fsm.parse(buffer)? 
{
+                StateMachineControlFlow::NeedMore(fsm) => {
+                    Ok(StateMachineControlFlow::NeedMore(Self::Block(fsm)))
+                }
+                StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)),
+            },
+            SubStateMachine::Union(fsm) => match fsm.parse(buffer)? {
+                StateMachineControlFlow::NeedMore(fsm) => {
+                    Ok(StateMachineControlFlow::NeedMore(Self::Union(fsm)))
+                }
+                StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)),
+            },
+            SubStateMachine::Object(fsm) => match fsm.parse(buffer)? {
+                StateMachineControlFlow::NeedMore(fsm) => {
+                    Ok(StateMachineControlFlow::NeedMore(Self::Object(fsm)))
+                }
+                StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)),
+            },
+        }
+    }
+}
+
+/// An item that was read from the document.
+#[derive(Debug)]
+#[must_use]
+pub enum ItemRead {
+    Null,
+    Boolean(bool),
+    Int(i32),
+    Long(i64),
+    Float(f32),
+    Double(f64),
+    // TODO: smallvec/hipstr?
+    Bytes(Vec<u8>),
+    // TODO: smol_str/hipstr?
+    String(String),
+    /// The variant of the Enum that was read.
+    Enum(u32),
+    /// The variant of the Union that was read.
+    ///
+    /// The variant data is next.
+    Union(u32),
+    /// The start of a block of a Map or Array.
+    Block(usize),
+}
+
+/// Read a zigzagged varint from the buffer.
+///
+/// Will only consume the buffer if a whole number has been read.
+/// If insufficient bytes are available it will return `Ok(None)` to
+/// indicate it needs more bytes.
+pub fn decode_zigzag_buffer(buffer: &mut Buffer) -> Result<Option<i64>, Error> {
+    if let Some((decoded, consumed)) = decode_variable(buffer.data())? {
+        buffer.consume(consumed);
+        Ok(Some(decoded))
+    } else {
+        Ok(None)
+    }
+}
+
+/// Deserialize a tape to a [`Value`] using the provided [`Schema`].
+///
+/// The schema must be compatible with the schema used by the original writer.
+///
+/// `names` is consulted whenever a [`Schema::Ref`] is encountered.
+///
+/// # Panics
+/// Can panic if the provided schema does not exactly match the schema used to create the tape. To
+/// convert between the writer and reader schema use [`Value::resolve`] instead.
+pub fn value_from_tape(
+    tape: &mut Vec<ItemRead>,
+    schema: &Schema,
+    names: &Names,
+) -> Result<Value, Error> {
+    value_from_tape_internal(&mut tape.drain(..), schema, &None, names)
+}
+
+/// Recursively transform the `tape` into a [`Value`] according to the provided [`Schema`].
+///
+/// `names` is consulted whenever a [`Schema::Ref`] is encountered.
+pub fn value_from_tape_internal(
+    tape: &mut impl Iterator<Item = ItemRead>,
+    schema: &Schema,
+    enclosing_namespace: &Namespace,
+    names: &Names,
+) -> Result<Value, Error> {
+    match schema {
+        Schema::Null => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? {
+            ItemRead::Null => Ok(Value::Null),
+            item => Err(ValueFromTapeError::TapeSchemaMismatch {
+                schema: schema.clone(),
+                item,
+            }
+            .into()),
+        },
+        Schema::Boolean => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? {
+            ItemRead::Boolean(boolean) => Ok(Value::Boolean(boolean)),
+            item => Err(ValueFromTapeError::TapeSchemaMismatch {
+                schema: schema.clone(),
+                item,
+            }
+            .into()),
+        },
+        Schema::Int => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? {
+            ItemRead::Int(int) => Ok(Value::Int(int)),
+            item => Err(ValueFromTapeError::TapeSchemaMismatch {
+                schema: schema.clone(),
+                item,
+            }
+            .into()),
+        },
+        Schema::Long => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)?
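+        // Every scalar arm in this match follows the same shape: pop one item off the tape
+        // and demand that it matches the schema, otherwise report a tape/schema mismatch.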
{ + ItemRead::Long(long) => Ok(Value::Long(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Float => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Float(float) => Ok(Value::Float(float)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Double => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Double(double) => Ok(Value::Double(double)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Bytes => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => Ok(Value::Bytes(bytes)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::String => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Ok(Value::String(string)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Array(ArraySchema { items, .. }) => { + let mut collected = Vec::new(); + loop { + let n = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Block(n) => Ok(n), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + }), + }?; + if n == 0 { + break; + } + collected.reserve(n); + for _ in 0..n { + collected.push(value_from_tape_internal( + tape, + items, + enclosing_namespace, + names, + )?); + } + } + Ok(Value::Array(collected)) + } + Schema::Map(MapSchema { types, .. }) => { + let mut collected = HashMap::new(); + loop { + let n = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Block(n) => Ok(n), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + }), + }?; + if n == 0 { + break; + } + collected.reserve(n); + for _ in 0..n { + let key = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Ok(string), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: Schema::String, + item, + }), + }?; + let val = value_from_tape_internal(tape, types, enclosing_namespace, names)?; + collected.insert(key, val); + } + } + Ok(Value::Map(collected)) + } + Schema::Union(UnionSchema { schemas, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Union(variant) => { + let schema = schemas.get(usize::try_from(variant).unwrap()).ok_or( + Details::GetUnionVariant { + index: variant as i64, + num_variants: schemas.len(), + }, + )?; + let value = Box::new(value_from_tape_internal( + tape, + schema, + enclosing_namespace, + names, + )?); + Ok(Value::Union(variant, value)) + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Record(RecordSchema { name, fields, .. }) => { + let fqn = name.fully_qualified_name(enclosing_namespace); + let mut collected = Vec::with_capacity(fields.len()); + for field in fields { + let collect = value_from_tape_internal(tape, &field.schema, &fqn.namespace, names)?; + collected.push((field.name.clone(), collect)); + } + Ok(Value::Record(collected)) + } + Schema::Enum(EnumSchema { symbols, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Enum(val) => Ok(Value::Enum( + val, + symbols.get(usize::try_from(val).unwrap()).unwrap().clone(), + )), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Fixed(FixedSchema { size, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(fixed) => { + if *size == fixed.len() { + Ok(Value::Fixed(fixed.len(), fixed)) + } else { + Err(ValueFromTapeError::TapeSchemaMismatchFixed { + expected: *size, + actual: fixed.len(), + } + .into()) + } + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Decimal(_) => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => Ok(Value::Decimal(Decimal::from(&bytes))), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::BigDecimal => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => deserialize_big_decimal(&bytes).map(Value::BigDecimal), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Uuid => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Uuid::from_str(&string) + .map(Value::Uuid) + .map_err(|e| Details::ConvertStrToUuid(e).into()), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Date => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Int(int) => Ok(Value::Date(int)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimeMillis => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Int(int) => Ok(Value::TimeMillis(int)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimeMicros => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimeMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimestampMillis => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampMillis(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::TimestampMicros => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::TimestampNanos => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampNanos(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampMillis => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::LocalTimestampMillis(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampMicros => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Long(long) => Ok(Value::LocalTimestampMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampNanos => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::LocalTimestampNanos(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Duration => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => { + let array: [u8; 12] = bytes.deref().try_into().unwrap(); + Ok(Value::Duration(Duration::from(array))) + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Ref { name } => { + let fqn = name.fully_qualified_name(enclosing_namespace); + if let Some(resolved) = names.get(&fqn) { + value_from_tape_internal(tape, resolved, &fqn.namespace, names) + } else { + Err(Details::SchemaResolutionError(fqn).into()) + } + } + } +} + +/// Deserialize a tape to `T` using the provided [`Schema`]. +/// +/// The schema must be compatible with the schema used by the original writer. +pub fn deserialize_from_tape<'a, T: Deserialize<'a>>( + tape: &mut Vec, + schema: &Schema, +) -> Result { + let rs = ResolvedSchema::try_from(schema)?; + deserialize_from_tape_internal(tape, schema, rs.get_names(), &None) +} + +/// Recursively transform the `tape` into a `T` according to the provided [`Schema`]. +fn deserialize_from_tape_internal<'a, T: Deserialize<'a>, S: Borrow>( + tape: &mut Vec, + _schema: &Schema, + _names: &HashMap, + _enclosing_namespace: &Namespace, +) -> Result { + tape.clear(); + todo!() +} + +#[cfg(test)] +#[allow(clippy::expect_fun_call)] +mod tests { + use crate::{ + Decimal, + encode::{encode, tests::success}, + from_avro_datum, + schema::{DecimalSchema, FixedSchema, Schema}, + types::{ + Value, + Value::{Array, Int, Map}, + }, + }; + use apache_avro_test_helper::TestResult; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + use uuid::Uuid; + + #[test] + fn test_decode_array_without_size() -> TestResult { + let mut input: &[u8] = &[6, 2, 4, 6, 0]; + + let result = from_avro_datum(&Schema::array(Schema::Int), &mut input, None)?; + + assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result); + + Ok(()) + } + + #[test] + fn test_decode_array_with_size() -> TestResult { + let mut input: &[u8] = &[5, 6, 2, 4, 6, 0]; + let result = from_avro_datum(&Schema::array(Schema::Int), &mut input, None)?; + assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result); + + Ok(()) + } + + #[test] + fn test_decode_map_without_size() -> TestResult { + let mut input: &[u8] = &[0x02, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; + let result = from_avro_datum(&Schema::map(Schema::Int), &mut input, None)?; + let mut expected = HashMap::new(); + expected.insert(String::from("test"), Int(1)); + assert_eq!(Map(expected), result); + + Ok(()) + } + + #[test] + fn test_decode_map_with_size() -> TestResult { + let mut input: &[u8] = &[0x01, 0x0C, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; + let result = from_avro_datum(&Schema::map(Schema::Int), &mut input, None)?; + let mut expected = HashMap::new(); + expected.insert(String::from("test"), Int(1)); + assert_eq!(Map(expected), result); + + Ok(()) + } + + #[test] + fn test_negative_decimal_value() -> TestResult { + use crate::{encode::encode, schema::Name}; + use num_bigint::ToBigInt; + let inner = Box::new(Schema::Fixed( + 
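+            // A two-byte fixed is wide enough: -423 is 0xFE59 in big-endian two's complement.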
FixedSchema::builder() + .name(Name::new("decimal")?) + .size(2) + .build(), + )); + let schema = Schema::Decimal(DecimalSchema { + inner, + precision: 4, + scale: 2, + }); + let bigint = (-423).to_bigint().unwrap(); + let value = Value::Decimal(Decimal::from(bigint.to_signed_bytes_be())); + + let mut buffer = Vec::new(); + encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + + let mut bytes = &buffer[..]; + let result = from_avro_datum(&schema, &mut bytes, None)?; + assert_eq!(result, value); + + Ok(()) + } + + #[test] + fn test_decode_decimal_with_bigger_than_necessary_size() -> TestResult { + use crate::{encode::encode, schema::Name}; + use num_bigint::ToBigInt; + let inner = Box::new(Schema::Fixed(FixedSchema { + size: 13, + name: Name::new("decimal")?, + aliases: None, + doc: None, + default: None, + attributes: Default::default(), + })); + let schema = Schema::Decimal(DecimalSchema { + inner, + precision: 4, + scale: 2, + }); + let value = Value::Decimal(Decimal::from( + ((-423).to_bigint().unwrap()).to_signed_bytes_be(), + )); + let mut buffer = Vec::::new(); + + encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + let mut bytes: &[u8] = &buffer[..]; + let result = from_avro_datum(&schema, &mut bytes, None)?; + assert_eq!(result, value); + + Ok(()) + } + + #[test] + fn test_avro_3448_recursive_definition_decode_union() -> TestResult { + // if encoding fails in this test check the corresponding test in encode + let schema = Schema::parse_str( + r#" + { + "type":"record", + "name":"TestStruct", + "fields": [ + { + "name":"a", + "type":[ "null", { + "type":"record", + "name": "Inner", + "fields": [ { + "name":"z", + "type":"int" + }] + }] + }, + { + "name":"b", + "type":"Inner" + } + ] + }"#, + )?; + + let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); + let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); + let outer_value1 = Value::Record(vec![ + ("a".into(), Value::Union(1, Box::new(inner_value1))), + ("b".into(), inner_value2.clone()), + ]); + let mut buf = Vec::new(); + encode(&outer_value1, &schema, &mut buf).expect(&success(&outer_value1, &schema)); + assert!(!buf.is_empty()); + let mut bytes = &buf[..]; + assert_eq!( + outer_value1, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + let outer_value2 = Value::Record(vec![ + ("a".into(), Value::Union(0, Box::new(Value::Null))), + ("b".into(), inner_value2), + ]); + encode(&outer_value2, &schema, &mut buf).expect(&success(&outer_value2, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_value2, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn test_avro_3448_recursive_definition_decode_array() -> TestResult { + let schema = Schema::parse_str( + r#" + { + "type":"record", + "name":"TestStruct", + "fields": [ + { + "name":"a", + "type":{ + "type":"array", + "items": { + "type":"record", + "name": "Inner", + "fields": [ { + "name":"z", + "type":"int" + }] + } + } + }, + { + "name":"b", + "type": "Inner" + } + ] + }"#, + )?; + + let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); + let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); + let outer_value = Value::Record(vec![ + ("a".into(), Value::Array(vec![inner_value1])), + ("b".into(), inner_value2), 
+        ]);
+        let mut buf = Vec::new();
+        encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_value,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to decode using recursive definitions with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_avro_3448_recursive_definition_decode_map() -> TestResult {
+        let schema = Schema::parse_str(
+            r#"
+            {
+                "type":"record",
+                "name":"TestStruct",
+                "fields": [
+                    {
+                        "name":"a",
+                        "type":{
+                            "type":"map",
+                            "values": {
+                                "type":"record",
+                                "name": "Inner",
+                                "fields": [ {
+                                    "name":"z",
+                                    "type":"int"
+                                }]
+                            }
+                        }
+                    },
+                    {
+                        "name":"b",
+                        "type": "Inner"
+                    }
+                ]
+            }"#,
+        )?;
+
+        let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]);
+        let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]);
+        let outer_value = Value::Record(vec![
+            (
+                "a".into(),
+                Value::Map(vec![("akey".into(), inner_value1)].into_iter().collect()),
+            ),
+            ("b".into(), inner_value2),
+        ]);
+        let mut buf = Vec::new();
+        encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_value,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to decode using recursive definitions with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_avro_3448_proper_multi_level_decoding_middle_namespace() -> TestResult {
+        // if encoding fails in this test check the corresponding test in encode
+        let schema = r#"
+        {
+            "name": "record_name",
+            "namespace": "space",
+            "type": "record",
+            "fields": [
+                {
+                    "name": "outer_field_1",
+                    "type": [
+                        "null",
+                        {
+                            "type": "record",
+                            "name": "middle_record_name",
+                            "namespace":"middle_namespace",
+                            "fields":[
+                                {
+                                    "name":"middle_field_1",
+                                    "type":[
+                                        "null",
+                                        {
+                                            "type":"record",
+                                            "name":"inner_record_name",
+                                            "fields":[
+                                                {
+                                                    "name":"inner_field_1",
+                                                    "type":"double"
+                                                }
+                                            ]
+                                        }
+                                    ]
+                                }
+                            ]
+                        }
+                    ]
+                },
+                {
+                    "name": "outer_field_2",
+                    "type" : "middle_namespace.inner_record_name"
+                }
+            ]
+        }
+        "#;
+        let schema = Schema::parse_str(schema)?;
+        let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]);
+        let middle_record_variation_1 = Value::Record(vec![(
+            "middle_field_1".into(),
+            Value::Union(0, Box::new(Value::Null)),
+        )]);
+        let middle_record_variation_2 = Value::Record(vec![(
+            "middle_field_1".into(),
+            Value::Union(1, Box::new(inner_record.clone())),
+        )]);
+        let outer_record_variation_1 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(0, Box::new(Value::Null)),
+            ),
+            ("outer_field_2".into(), inner_record.clone()),
+        ]);
+        let outer_record_variation_2 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(1, Box::new(middle_record_variation_1)),
+            ),
+            ("outer_field_2".into(), inner_record.clone()),
+        ]);
+        let outer_record_variation_3 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(1, Box::new(middle_record_variation_2)),
+            ),
+            ("outer_field_2".into(), inner_record),
+        ]);
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_1, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_1, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_1,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_2, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_2, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_2,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_3, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_3, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_3,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_avro_3448_proper_multi_level_decoding_inner_namespace() -> TestResult {
+        // if encoding fails in this test check the corresponding test in encode
+        let schema = r#"
+        {
+            "name": "record_name",
+            "namespace": "space",
+            "type": "record",
+            "fields": [
+                {
+                    "name": "outer_field_1",
+                    "type": [
+                        "null",
+                        {
+                            "type": "record",
+                            "name": "middle_record_name",
+                            "namespace":"middle_namespace",
+                            "fields":[
+                                {
+                                    "name":"middle_field_1",
+                                    "type":[
+                                        "null",
+                                        {
+                                            "type":"record",
+                                            "name":"inner_record_name",
+                                            "namespace":"inner_namespace",
+                                            "fields":[
+                                                {
+                                                    "name":"inner_field_1",
+                                                    "type":"double"
+                                                }
+                                            ]
+                                        }
+                                    ]
+                                }
+                            ]
+                        }
+                    ]
+                },
+                {
+                    "name": "outer_field_2",
+                    "type" : "inner_namespace.inner_record_name"
+                }
+            ]
+        }
+        "#;
+        let schema = Schema::parse_str(schema)?;
+        let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]);
+        let middle_record_variation_1 = Value::Record(vec![(
+            "middle_field_1".into(),
+            Value::Union(0, Box::new(Value::Null)),
+        )]);
+        let middle_record_variation_2 = Value::Record(vec![(
+            "middle_field_1".into(),
+            Value::Union(1, Box::new(inner_record.clone())),
+        )]);
+        let outer_record_variation_1 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(0, Box::new(Value::Null)),
+            ),
+            ("outer_field_2".into(), inner_record.clone()),
+        ]);
+        let outer_record_variation_2 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(1, Box::new(middle_record_variation_1)),
+            ),
+            ("outer_field_2".into(), inner_record.clone()),
+        ]);
+        let outer_record_variation_3 = Value::Record(vec![
+            (
+                "outer_field_1".into(),
+                Value::Union(1, Box::new(middle_record_variation_2)),
+            ),
+            ("outer_field_2".into(), inner_record),
+        ]);
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_1, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_1, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_1,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_2, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_2, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_2,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        let mut buf = Vec::new();
+        encode(&outer_record_variation_3, &schema, &mut buf)
+            .expect(&success(&outer_record_variation_3, &schema));
+        let mut bytes = &buf[..];
+        assert_eq!(
+            outer_record_variation_3,
+            from_avro_datum(&schema, &mut bytes, None).expect(&format!(
+                "Failed to Decode with recursively defined namespace with schema:\n {:?}\n",
+                &schema
+            ))
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_3926_encode_decode_uuid_to_string() -> TestResult {
+        use crate::encode::encode;
+
+        let schema = Schema::String;
+        let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?);
+
+        let mut buffer = Vec::new();
+        encode(&value, &schema, &mut buffer).expect(&success(&value, &schema));
+
+        let result = from_avro_datum(&Schema::Uuid, &mut &buffer[..], None)?;
+        assert_eq!(result, value);
+
+        Ok(())
+    }
+
+    // TODO: Schema::Uuid needs a sub schema which is either String or Fixed. It's now part of the
+    // spec anyway.
+    // #[test]
+    // fn avro_3926_encode_decode_uuid_to_fixed() -> TestResult {
+    //     use crate::encode::encode;
+    //
+    //     let schema = Schema::Fixed(FixedSchema {
+    //         size: 16,
+    //         name: "uuid".into(),
+    //         aliases: None,
+    //         doc: None,
+    //         default: None,
+    //         attributes: Default::default(),
+    //     });
+    //     let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?);
+    //
+    //     let mut buffer = Vec::new();
+    //     encode(&value, &schema, &mut buffer).expect(&success(&value, &schema));
+    //
+    //     let result = from_avro_datum(&Schema::Uuid, &mut &buffer[..], None)?;
+    //     assert_eq!(result, value);
+    //
+    //     Ok(())
+    // }
+}
diff --git a/avro/src/state_machines/reading/object_container_file.rs b/avro/src/state_machines/reading/object_container_file.rs
new file mode 100644
index 00000000..8f9c2bd5
--- /dev/null
+++ b/avro/src/state_machines/reading/object_container_file.rs
@@ -0,0 +1,318 @@
+use crate::{
+    Codec, Error, Schema,
+    error::Details,
+    schema::{Names, ResolvedSchema, resolve_names, resolve_names_with_schemata},
+    state_machines::reading::{
+        CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult,
+        codec::CodecStateMachine, datum::DatumStateMachine, decode_zigzag_buffer,
+    },
+};
+use log::warn;
+use oval::Buffer;
+use serde_json::Value;
+use std::{collections::HashMap, io::Read, str::FromStr, sync::Arc};
+
+// TODO: Dynamically/const construct this, this one works only on 64-bit LE
+/// The tape corresponding to [`HEADER_JSON`].
+///
+/// ```json
+/// {
+///     "type": "record",
+///     "name": "org.apache.avro.file.HeaderNoMagic",
+///     "fields": [
+///         {"name": "meta", "type": {"type": "map", "values": "bytes"}},
+///         {"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}}
+///     ]
+/// }
+/// ```
+#[rustfmt::skip]
+const HEADER_TAPE: &[u8] = &[
+    CommandTape::BLOCK | 2 << 4,                    // Starts with a map
+    CommandTape::STRING,                            // The keys are strings
+    CommandTape::BYTES,                             // The values are bytes
+    CommandTape::FIXED,                             // After the map there is a Fixed amount of bytes
+    0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // The amount of bytes is 0x10 (16)
+];
+#[cfg(test)]
+const HEADER_JSON: &str = r#"{"type": "record","name": "org.apache.avro.file.HeaderNoMagic","fields": [{"name": "meta", "type": {"type": "map", "values": "bytes"}},{"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}}]}"#;
+
+/// The header as read from an Object Container file format.
+pub struct ObjectContainerFileHeader {
+    /// The schema used to write the file.
+    pub schema: Schema,
+    pub names: Names,
+    /// The compression used.
+    pub codec: Codec,
+    /// The sync marker used between blocks
+    pub sync: [u8; 16],
+    /// User metadata in the header
+    pub metadata: HashMap<String, Vec<u8>>,
+}
+
+impl ObjectContainerFileHeader {
+    pub fn command_tape() -> CommandTape {
+        CommandTape::new(Arc::from(HEADER_TAPE))
+    }
+
+    /// Create the header from an output tape.
+    ///
+    /// # Panics
+    /// Will panic if the tape was not produced from [`Self::command_tape()`].
+    pub fn from_tape(mut tape: Vec<ItemRead>, mut schemata: Vec<&Schema>) -> Result<Self, Error> {
+        // We want to read the tape from front to back
+        let mut tape = tape.drain(..);
+
+        let mut schema = None;
+        let mut codec = None;
+        let mut found_compression_level = false;
+        let mut metadata = HashMap::new();
+        let mut names = HashMap::new();
+
+        while let Some(ItemRead::Block(items_left)) = tape.next() {
+            if items_left == 0 {
+                // Got to the end of the map
+                break;
+            }
+            for _ in 0..items_left {
+                let Some(ItemRead::String(key)) = tape.next() else {
+                    panic!("The input does not correspond to the command tape");
+                };
+                let Some(ItemRead::Bytes(value)) = tape.next() else {
+                    panic!("The input does not correspond to the command tape");
+                };
+
+                match key.as_ref() {
+                    "avro.schema" => {
+                        if schema.is_some() {
+                            // Duplicate key
+                            return Err(Details::GetHeaderMetadata.into());
+                        }
+                        let json: Value =
+                            serde_json::from_slice(&value).map_err(Details::ParseSchemaJson)?;
+
+                        if !schemata.is_empty() {
+                            // TODO: Make parse_with_names accept NamesRef
+                            let schemata = std::mem::take(&mut schemata);
+                            resolve_names_with_schemata(&schemata, &mut names, &None)?;
+
+                            // TODO: Maybe we can avoid this and just pass &names to Schema::parse_with_names
+                            let rs = ResolvedSchema::try_from(schemata)?;
+                            let names: Names = rs
+                                .get_names()
+                                .iter()
+                                .map(|(name, &schema)| (name.clone(), schema.clone()))
+                                .collect();
+
+                            let parsed_schema = Schema::parse_with_names(&json, names)?;
+                            schema.replace(parsed_schema);
+                        } else {
+                            let parsed_schema = Schema::parse(&json)?;
+                            resolve_names(&parsed_schema, &mut names, &None)?;
+                            schema.replace(parsed_schema);
+                        }
+                    }
+                    "avro.codec" => {
+                        let string = String::from_utf8(value).map_err(Details::ConvertToUtf8)?;
+                        let parsed_codec = Codec::from_str(&string)
+                            .map_err(|_| Details::CodecNotSupported(string))?;
+                        if codec.replace(parsed_codec).is_some() {
+                            // Duplicate key
+                            return Err(Details::GetHeaderMetadata.into());
+                        }
+                    }
+                    "avro.codec.compression_level" => {
+                        // Compression level is not useful for decoding
+                        if found_compression_level {
+                            // Duplicate key
+                            return Err(Details::GetHeaderMetadata.into());
+                        }
+                        found_compression_level = true;
+                    }
+                    _ => {
+                        if key.starts_with("avro.") {
+                            warn!("Ignoring unknown metadata key: {key}");
+                        }
+                        if metadata.insert(key, value).is_some() {
+                            // Duplicate key
+                            return Err(Details::GetHeaderMetadata.into());
+                        }
+                    }
+                }
+            }
+        }
+        let Some(schema) = schema else {
+            return Err(Details::GetHeaderMetadata.into());
+        };
+        let codec = codec.unwrap_or(Codec::Null);
+        let Some(ItemRead::Bytes(raw_sync)) = tape.next() else {
+            panic!("The input does not correspond to the command tape");
+        };
+        let sync = raw_sync
+            .as_slice()
+            .try_into()
+            .expect("The input does not correspond to the command tape");
+        Ok(ObjectContainerFileHeader {
+            schema,
+            names,
+            codec,
+            sync,
+            metadata,
+        })
+    }
+}
+
+/// A state machine for parsing the header of the Object Container file format.
+///
+/// After finishing this state machine the body can be read with [`ObjectContainerFileBodyStateMachine`].
+pub struct ObjectContainerFileHeaderStateMachine<'a> {
+    /// The actual state machine used to parse the header.
+    ///
+    /// This doesn't actually need to be an [`Option`] as it's constructed in [`Self::new`]. However,
+    /// as [`StateMachine::parse`] takes `self` we need it in an `Option` so we can do [`Option::take`].
+    fsm: Option<DatumStateMachine>,
+    read_magic: bool,
+    schemata: Vec<&'a Schema>,
+}
+
+impl<'a> ObjectContainerFileHeaderStateMachine<'a> {
+    pub fn new(schemata: Vec<&'a Schema>) -> Self {
+        let commands = CommandTape::new(Arc::from(HEADER_TAPE));
+        Self {
+            fsm: Some(DatumStateMachine::new(commands)),
+            read_magic: false,
+            schemata,
+        }
+    }
+}
+
+impl StateMachine for ObjectContainerFileHeaderStateMachine<'_> {
+    type Output = ObjectContainerFileHeader;
+
+    fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult<Self> {
+        while !self.read_magic {
+            if buffer.available_data() < 4 {
+                return Ok(StateMachineControlFlow::NeedMore(self));
+            }
+            if buffer.data()[0..4] != [b'O', b'b', b'j', 1] {
+                return Err(Details::HeaderMagic.into());
+            }
+            buffer.consume(4);
+            self.read_magic = true;
+        }
+        match self.fsm.take().expect("Unreachable!").parse(buffer)? {
+            StateMachineControlFlow::NeedMore(fsm) => {
+                let _ = self.fsm.insert(fsm);
+                Ok(StateMachineControlFlow::NeedMore(self))
+            }
+            StateMachineControlFlow::Done(tape) => Ok(StateMachineControlFlow::Done(
+                ObjectContainerFileHeader::from_tape(tape, self.schemata)?,
+            )),
+        }
+    }
+}
+
+pub struct ObjectContainerFileBodyStateMachine {
+    fsm: Option<CodecStateMachine<DatumStateMachine>>,
+    tape: CommandTape,
+    sync: [u8; 16],
+    left_in_block: usize,
+    need_to_read_block_byte_size: bool,
+    need_to_read_sync: bool,
+}
+
+impl ObjectContainerFileBodyStateMachine {
+    pub fn new(tape: CommandTape, sync: [u8; 16], codec: Codec) -> Self {
+        Self {
+            fsm: Some(CodecStateMachine::new(
+                DatumStateMachine::new(tape.clone()),
+                codec,
+            )),
+            tape,
+            sync,
+            left_in_block: 0,
+            need_to_read_block_byte_size: false,
+            need_to_read_sync: false,
+        }
+    }
+}
+
+impl StateMachine for ObjectContainerFileBodyStateMachine {
+    type Output = Option<(Vec<ItemRead>, Self)>;
+
+    fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult<Self> {
+        if self.left_in_block == 0 {
+            if self.need_to_read_sync {
+                if buffer.available_data() < 16 {
+                    return Ok(StateMachineControlFlow::NeedMore(self));
+                }
+                let mut sync = [0; 16];
+                assert_eq!(
+                    buffer.read(&mut sync).expect("Unreachable!"),
+                    16,
+                    "Did not read enough data!"
+                );
+                if sync != self.sync {
+                    return Err(Details::GetBlockMarker.into());
+                }
+                self.need_to_read_sync = false;
+            }
+            let Some(block) = decode_zigzag_buffer(buffer)? else {
+                // Not enough data left in the buffer
+                return Ok(StateMachineControlFlow::NeedMore(self));
+            };
+            let abs_block = block.unsigned_abs();
+            let abs_block =
+                usize::try_from(abs_block).map_err(|e| Details::ConvertU64ToUsize(e, abs_block))?;
+            if abs_block == 0 {
+                // A block count of zero means we are done parsing the file
+                return Ok(StateMachineControlFlow::Done(None));
+            }
+            self.need_to_read_block_byte_size = true;
+            // This will only be done after this block is finished
+            self.need_to_read_sync = true;
+            self.left_in_block = abs_block;
+        }
+        if self.need_to_read_block_byte_size {
+            let Some(block) = decode_zigzag_buffer(buffer)? else {
+                // Not enough data left in the buffer
+                return Ok(StateMachineControlFlow::NeedMore(self));
+            };
+            // Make sure the value is sane
+            let _size = usize::try_from(block).map_err(|e| Details::ConvertI64ToUsize(e, block))?;
+            self.need_to_read_block_byte_size = false;
+        }
+
+        match self.fsm.take().expect("Unreachable!").parse(buffer)? {
+            StateMachineControlFlow::NeedMore(fsm) => {
+                self.fsm.replace(fsm);
+                Ok(StateMachineControlFlow::NeedMore(self))
+            }
+            StateMachineControlFlow::Done((result, mut codec)) => {
+                codec.reset(DatumStateMachine::new(self.tape.clone()));
+                self.fsm.replace(codec);
+                self.left_in_block -= 1;
+                Ok(StateMachineControlFlow::Done(Some((result, self))))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{collections::HashMap, sync::Arc};
+
+    use crate::{
+        Schema,
+        state_machines::reading::{
+            commands::CommandTape,
+            object_container_file::{HEADER_JSON, HEADER_TAPE},
+        },
+    };
+
+    #[test]
+    pub fn header_tape() {
+        let schema = Schema::parse_str(HEADER_JSON).unwrap();
+        let tape = CommandTape::build_from_schema(&schema, &HashMap::new()).unwrap();
+        assert_eq!(tape, CommandTape::new(Arc::from(HEADER_TAPE)));
+    }
+}
diff --git a/avro/src/state_machines/reading/sync.rs b/avro/src/state_machines/reading/sync.rs
new file mode 100644
index 00000000..daab82a0
--- /dev/null
+++ b/avro/src/state_machines/reading/sync.rs
@@ -0,0 +1,351 @@
+use oval::Buffer;
+use serde::Deserialize;
+use std::{collections::HashMap, io::Read};
+
+use crate::{
+    AvroResult, Error, Schema,
+    error::Details,
+    schema::{resolve_names, resolve_names_with_schemata},
+    state_machines::reading::{
+        ItemRead, StateMachine, StateMachineControlFlow,
+        commands::CommandTape,
+        datum::DatumStateMachine,
+        deserialize_from_tape,
+        object_container_file::{
+            ObjectContainerFileBodyStateMachine, ObjectContainerFileHeader,
+            ObjectContainerFileHeaderStateMachine,
+        },
+        value_from_tape,
+    },
+    types::Value,
+};
+
+/// Main interface for reading Avro formatted values.
+///
+/// To be used as an iterator:
+///
+/// ```no_run
+/// # use apache_avro::Reader;
+/// # use std::io::Cursor;
+/// # let input = Cursor::new(Vec::<u8>::new());
+/// for value in Reader::new(input).unwrap() {
+///     match value {
+///         Ok(v) => println!("{:?}", v),
+///         Err(e) => println!("Error: {}", e),
+///     };
+/// }
+/// ```
+pub struct Reader<'a, R> {
+    reader_schema: Option<&'a Schema>,
+    header: ObjectContainerFileHeader,
+    fsm: Option<ObjectContainerFileBodyStateMachine>,
+    reader: R,
+    buffer: Buffer,
+}
+
+impl<'a, R: Read> Reader<'a, R> {
+    /// Creates a [`Reader`] that will use the schema from the file header.
+    ///
+    /// No reader [`Schema`] will be set.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub fn new(reader: R) -> Result<Self, Error> {
+        Self::new_inner(reader, None, Vec::new())
+    }
+
+    /// Creates a [`Reader`] that will use the given schema for schema resolution.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub fn with_schema(schema: &'a Schema, reader: R) -> Result<Self, Error> {
+        Self::new_inner(reader, Some(schema), Vec::new())
+    }
+
+    /// Creates a [`Reader`] that will use the given schema for schema resolution.
+    ///
+    /// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be
+    /// resolved and an error will be returned.
+    ///
+    /// Any [`Schema::Ref`] will be resolved using the schemata.
+    ///
+    /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`].
+    pub fn with_schemata(
+        schema: &'a Schema,
+        schemata: Vec<&'a Schema>,
+        reader: R,
+    ) -> Result<Self, Error> {
+        Self::new_inner(reader, Some(schema), schemata)
+    }
+
+    /// Get a reference to the writer [`Schema`].
+    pub fn writer_schema(&self) -> &Schema {
+        &self.header.schema
+    }
+
+    /// Get a reference to the optional reader [`Schema`].
+    ///
+    /// This will only be set if there was a reader schema provided *and* it differed from the
+    /// writer schema.
+    pub fn reader_schema(&self) -> Option<&'a Schema> {
+        self.reader_schema
+    }
+
+    /// Get a reference to the user metadata.
+    pub fn user_metadata(&self) -> &HashMap<String, Vec<u8>> {
+        &self.header.metadata
+    }
+
+    /// Get a reference to the file header.
+    pub fn header(&self) -> &ObjectContainerFileHeader {
+        &self.header
+    }
+
+    fn new_inner(
+        mut reader: R,
+        reader_schema: Option<&'a Schema>,
+        schemata: Vec<&'a Schema>,
+    ) -> Result<Self, Error> {
+        // Read a maximum of 2 KiB per read
+        let mut buffer = Buffer::with_capacity(2 * 1024);
+
+        // Parse the header
+        let mut fsm = ObjectContainerFileHeaderStateMachine::new(schemata);
+        let header = loop {
+            // Fill the buffer
+            let n = reader.read(buffer.space()).map_err(Details::ReadHeader)?;
+            if n == 0 {
+                return Err(Details::ReadHeader(std::io::ErrorKind::UnexpectedEof.into()).into());
+            }
+            buffer.fill(n);
+
+            // Start/continue the state machine
+            match fsm.parse(&mut buffer)? {
+                StateMachineControlFlow::NeedMore(new_fsm) => fsm = new_fsm,
+                StateMachineControlFlow::Done(header) => break header,
+            }
+        };
+
+        let tape = CommandTape::build_from_schema(&header.schema, &header.names)?;
+
+        let reader_schema = if let Some(schema) = reader_schema
+            && schema != &header.schema
+        {
+            Some(schema)
+        } else {
+            None
+        };
+
+        Ok(Self {
+            reader_schema,
+            fsm: Some(ObjectContainerFileBodyStateMachine::new(
+                tape,
+                header.sync,
+                header.codec,
+            )),
+            header,
+            reader,
+            buffer,
+        })
+    }
+
+    /// Get the next object in the file
+    fn next_object(&mut self) -> Option<Result<Vec<ItemRead>, Error>> {
+        if let Some(mut fsm) = self.fsm.take() {
+            loop {
+                match fsm.parse(&mut self.buffer) {
+                    Ok(StateMachineControlFlow::NeedMore(new_fsm)) => {
+                        fsm = new_fsm;
+                        let n = match self.reader.read(self.buffer.space()) {
+                            Ok(0) => {
+                                return Some(Err(Details::ReadIntoBuf(
+                                    std::io::ErrorKind::UnexpectedEof.into(),
+                                )
+                                .into()));
+                            }
+                            Ok(n) => n,
+                            Err(e) => return Some(Err(Details::ReadIntoBuf(e).into())),
+                        };
+                        self.buffer.fill(n);
+                    }
+                    Ok(StateMachineControlFlow::Done(Some((object, fsm)))) => {
+                        self.fsm.replace(fsm);
+                        return Some(Ok(object));
+                    }
+                    Ok(StateMachineControlFlow::Done(None)) => {
+                        return None;
+                    }
+                    Err(e) => {
+                        return Some(Err(e));
+                    }
+                }
+            }
+        }
+        None
+    }
+
+    /// Deserialize the next object directly to `T`.
+    ///
+    /// This function goes immediately from the inner representation to `T` without going through
+    /// [`Value`] first. It does not support schema resolution using a reader [`Schema`].
+    ///
+    /// # Panics
+    /// Will panic if a reader [`Schema`] was supplied when creating the [`Reader`].
+    pub fn next_serde<'b, T: Deserialize<'b>>(&mut self) -> Option<Result<T, Error>> {
+        assert!(
+            self.reader_schema.is_none(),
+            "Schema resolution is not supported with this function!"
+        );
+        self.next_object()
+            .map(|r| r.and_then(|mut tape| deserialize_from_tape(&mut tape, &self.header.schema)))
+    }
+}
+
+impl<R: Read> Iterator for Reader<'_, R> {
+    type Item = Result<Value, Error>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.next_object().map(|r| {
+            r.and_then(|mut tape| {
+                value_from_tape(&mut tape, &self.header.schema, &self.header.names)
+            })
+            .and_then(|v| {
+                if let Some(schema) = &self.reader_schema {
+                    v.resolve_internal(schema, &self.header.names, &None, &None)
+                } else {
+                    Ok(v)
+                }
+            })
+        })
+    }
+}
+
+/// Decode a raw Avro datum using the provided [`Schema`].
+///
+/// In case a reader [`Schema`] is provided, schema resolution will be performed.
+///
+/// **NOTE** This function is very niche and does NOT take care of reading the header and
+/// consecutive data blocks. Use [`Reader`] if you just want to read an Avro encoded file.
+pub fn from_avro_datum<R: Read>(
+    writer_schema: &Schema,
+    reader: &mut R,
+    reader_schema: Option<&Schema>,
+) -> AvroResult<Value> {
+    from_avro_datum_reader_schemata(writer_schema, Vec::new(), reader, reader_schema, Vec::new())
+}
+
+/// Decode a raw Avro datum using the provided [`Schema`] and schemata.
+///
+/// `writer_schema` must be in `writer_schemata`. Otherwise, any self-referential [`Schema::Ref`]s
+/// can't be resolved and an error will be returned.
+///
+/// If the writer schema contains any [`Schema::Ref`] then it will use the provided
+/// schemata to resolve any dependencies.
+///
+/// In case a reader [`Schema`] is provided, schema resolution will be performed.
+pub fn from_avro_datum_schemata<R: Read>(
+    writer_schema: &Schema,
+    writer_schemata: Vec<&Schema>,
+    reader: &mut R,
+    reader_schema: Option<&Schema>,
+) -> AvroResult<Value> {
+    from_avro_datum_reader_schemata(
+        writer_schema,
+        writer_schemata,
+        reader,
+        reader_schema,
+        Vec::new(),
+    )
+}
+
+/// Decode a raw Avro datum using the provided [`Schema`] and schemata.
+///
+/// `writer_schema` must be in `writer_schemata`. Otherwise, any self-referential [`Schema::Ref`]s
+/// can't be resolved and an error will be returned.
+///
+/// If the writer schema contains any [`Schema::Ref`] then it will use the provided
+/// schemata to resolve any dependencies.
+///
+/// In case a reader [`Schema`] is provided, schema resolution will be performed.
+// TODO: These should really be a reusable reader, as quite a lot of work is done on creation
+pub fn from_avro_datum_reader_schemata<R: Read>(
+    writer_schema: &Schema,
+    writer_schemata: Vec<&Schema>,
+    reader: &mut R,
+    reader_schema: Option<&Schema>,
+    reader_schemata: Vec<&Schema>,
+) -> AvroResult<Value> {
+    let mut names = HashMap::new();
+    if writer_schemata.is_empty() {
+        resolve_names(writer_schema, &mut names, &None)?;
+    } else {
+        resolve_names_with_schemata(&writer_schemata, &mut names, &None)?;
+    }
+
+    let tape = CommandTape::build_from_schema(writer_schema, &names)?;
+
+    // Read a maximum of 2 KiB per read
+    let mut buffer = Buffer::with_capacity(2 * 1024);
+    let mut fsm = DatumStateMachine::new(tape);
+    let value = loop {
+        // Fill the buffer
+        let n = reader.read(buffer.space()).map_err(Details::ReadIntoBuf)?;
+        if n == 0 {
+            // If the writer schema is null, this is expected and we just return a null value
+            if matches!(writer_schema, &Schema::Null) {
+                break Value::Null;
+            }
+            return Err(Details::ReadIntoBuf(std::io::ErrorKind::UnexpectedEof.into()).into());
+        }
+        buffer.fill(n);
+
+        match fsm.parse(&mut buffer)? {
+            StateMachineControlFlow::NeedMore(new_fsm) => {
+                fsm = new_fsm;
+            }
+            StateMachineControlFlow::Done(mut tape) => {
+                break value_from_tape(&mut tape, writer_schema, &names)?;
+            }
+        }
+    };
+    match reader_schema {
+        Some(schema) => {
+            if reader_schemata.is_empty() {
+                value.resolve(schema)
+            } else {
+                value.resolve_schemata(schema, reader_schemata)
+            }
+        }
+        None => Ok(value),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{Schema, Writer, state_machines::reading::sync::Reader, types::Value};
+    use std::io::Cursor;
+
+    /// Test it reads all the sync markers
+    #[test]
+    fn sync_markers() {
+        let mut writer = Writer::new(&Schema::String, Vec::new());
+        writer.append(Value::String("Hello".to_string())).unwrap();
+        writer.flush().unwrap();
+        writer.append(Value::String("World".to_string())).unwrap();
+        writer.flush().unwrap();
+        let mut written = Cursor::new(writer.into_inner().unwrap());
+
+        let mut reader = Reader::new(&mut written).unwrap();
+        assert_eq!(
+            reader.next().unwrap().unwrap(),
+            Value::String("Hello".to_string())
+        );
+        assert_eq!(
+            reader.next().unwrap().unwrap(),
+            Value::String("World".to_string())
+        );
+
+        drop(reader);
+        let position = written.position();
+        let expected = written.into_inner().len();
+        assert_eq!(position, expected as u64);
+    }
+}
diff --git a/avro/src/state_machines/reading/union.rs b/avro/src/state_machines/reading/union.rs
new file mode 100644
index 00000000..ac5c5386
--- /dev/null
+++ b/avro/src/state_machines/reading/union.rs
@@ -0,0 +1,78 @@
+use crate::{
+    error::Details,
+    state_machines::reading::{
+        ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, SubStateMachine,
+        commands::CommandTape, decode_zigzag_buffer,
+    },
+};
+use oval::Buffer;
+
+enum VariantsOrFsm {
+    Variants {
+        variants: CommandTape,
+        read: Vec<ItemRead>,
+    },
+    Fsm(Box<SubStateMachine>),
+}
+
+pub struct UnionStateMachine {
+    variants_or_fsm: VariantsOrFsm,
+    num_variants: usize,
+}
+
+impl UnionStateMachine {
+    pub fn new_with_tape(variants: CommandTape, num_variants: usize, read: Vec<ItemRead>) -> Self {
+        Self {
+            variants_or_fsm: VariantsOrFsm::Variants { variants, read },
+            num_variants,
+        }
+    }
+}
+
+impl StateMachine for UnionStateMachine {
+    type Output = Vec<ItemRead>;
+
+    fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult<Self> {
+        match self.variants_or_fsm {
+            VariantsOrFsm::Variants {
+                mut variants,
+                mut read,
+            } => {
+                let Some(index) = decode_zigzag_buffer(buffer)? else {
+                    // Not enough data left in the buffer
+                    self.variants_or_fsm = VariantsOrFsm::Variants { variants, read };
+                    return Ok(StateMachineControlFlow::NeedMore(self));
+                };
+                let option =
+                    usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?;
+
+                variants.skip(option).ok_or(Details::GetUnionVariant {
+                    index,
+                    num_variants: self.num_variants,
+                })?;
+
+                let variant = variants.command().ok_or(Details::GetUnionVariant {
+                    index,
+                    num_variants: self.num_variants,
+                })?;
+
+                read.push(ItemRead::Union(u32::try_from(option).unwrap()));
+
+                match variant.into_state_machine(read).parse(buffer)? {
+                    StateMachineControlFlow::NeedMore(fsm) => {
+                        self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm));
+                        Ok(StateMachineControlFlow::NeedMore(self))
+                    }
+                    StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)),
+                }
+            }
+            VariantsOrFsm::Fsm(fsm) => match fsm.parse(buffer)? {
+                StateMachineControlFlow::NeedMore(fsm) => {
+                    self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm));
+                    Ok(StateMachineControlFlow::NeedMore(self))
+                }
+                StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)),
+            },
+        }
+    }
+}
diff --git a/avro/src/types.rs b/avro/src/types.rs
index 4448eef2..12e5909b 100644
--- a/avro/src/types.rs
+++ b/avro/src/types.rs
@@ -639,6 +639,7 @@ impl Value {
         mut self,
         schema: &Schema,
         names: &HashMap<Name, Schema>,
+        // TODO: These two should be Option<&T> instead of &Option<T>
         enclosing_namespace: &Namespace,
         field_default: &Option<JsonValue>,
     ) -> AvroResult<Self> {
diff --git a/avro/src/util.rs b/avro/src/util.rs
index a751fcd5..809e339e 100644
--- a/avro/src/util.rs
+++ b/avro/src/util.rs
@@ -15,10 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{AvroResult, error::Details, schema::Documentation};
+use crate::{AvroResult, Error, error::Details, schema::Documentation};
 use serde_json::{Map, Value};
 use std::{
-    io::{Read, Write},
+    io::Write,
     sync::{
         Once,
         atomic::{AtomicBool, AtomicUsize, Ordering},
@@ -75,10 +75,6 @@ impl MapHelper for Map<String, Value> {
     }
 }
 
-pub fn read_long<R: Read>(reader: &mut R) -> AvroResult<i64> {
-    zag_i64(reader)
-}
-
 pub fn zig_i32<W: Write>(n: i32, buffer: W) -> AvroResult<usize> {
     zig_i64(n as i64, buffer)
 }
@@ -87,20 +83,6 @@ pub fn zig_i64<W: Write>(n: i64, writer: W) -> AvroResult<usize> {
     encode_variable(((n << 1) ^ (n >> 63)) as u64, writer)
 }
 
-pub fn zag_i32<R: Read>(reader: &mut R) -> AvroResult<i32> {
-    let i = zag_i64(reader)?;
-    i32::try_from(i).map_err(|e| Details::ZagI32(e, i).into())
-}
-
-pub fn zag_i64<R: Read>(reader: &mut R) -> AvroResult<i64> {
-    let z = decode_variable(reader)?;
-    Ok(if z & 0x1 == 0 {
-        (z >> 1) as i64
-    } else {
-        !(z >> 1) as i64
-    })
-}
-
 fn encode_variable<W: Write>(mut z: u64, mut writer: W) -> AvroResult<usize> {
     let mut buffer = [0u8; 10];
     let mut i: usize = 0;
@@ -120,28 +102,66 @@ fn encode_variable<W: Write>(mut z: u64, mut writer: W) -> AvroResult<usize> {
         .map_err(|e| Details::WriteBytes(e).into())
 }
 
-fn decode_variable<R: Read>(reader: &mut R) -> AvroResult<u64> {
-    let mut i = 0u64;
-    let mut buf = [0u8; 1];
+/// Decode a zigzag encoded length.
+///
+/// This version of [`decode_len`] will return a [`Details::ReadVariableIntegerBytes`] error if
+/// there are not enough bytes, instead of returning `None`.
+///
+/// See [`decode_len`] for more details.
+pub fn decode_len_simple(buffer: &[u8]) -> AvroResult<(usize, usize)> {
+    decode_len(buffer)?.ok_or_else(|| {
+        Details::ReadVariableIntegerBytes(std::io::ErrorKind::UnexpectedEof.into()).into()
+    })
+}
 
-    let mut j = 0;
-    loop {
-        if j > 9 {
-            // if j * 7 > 64
-            return Err(Details::IntegerOverflow.into());
-        }
-        reader
-            .read_exact(&mut buf[..])
-            .map_err(Details::ReadVariableIntegerBytes)?;
-        i |= (u64::from(buf[0] & 0x7F)) << (j * 7);
-        if (buf[0] >> 7) == 0 {
+/// Decode a zigzag encoded length.
+///
+/// This will use [`safe_len`] to check if the length is in allowed bounds.
+///
+/// # Returns
+/// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough
+/// bytes in the slice.
+pub fn decode_len(buffer: &[u8]) -> AvroResult<Option<(usize, usize)>> {
+    if let Some((integer, bytes)) = decode_variable(buffer)? {
+        let length =
+            usize::try_from(integer).map_err(|e| Details::ConvertI64ToUsize(e, integer))?;
+        let safe = safe_len(length)?;
+        Ok(Some((safe, bytes)))
+    } else {
+        Ok(None)
+    }
+}
+
+/// Decode a zigzag encoded integer.
+///
+/// # Returns
+/// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough
+/// bytes in the slice.
+pub fn decode_variable(buffer: &[u8]) -> Result<Option<(i64, usize)>, Error> {
+    let mut decoded = 0;
+    let mut loops_done = 0;
+    let mut last_byte = 0x80; // sentinel: an empty buffer reports "need more bytes"
+
+    for (counter, &byte) in buffer.iter().take(10).enumerate() {
+        decoded |= u64::from(byte & 0x7F) << (counter * 7);
+        loops_done = counter;
+        last_byte = byte;
+        if byte >> 7 == 0 {
             break;
-        } else {
-            j += 1;
         }
     }
-    Ok(i)
+    if last_byte >> 7 != 0 {
+        if loops_done == 9 {
+            // All ten bytes had the continuation bit set: not a valid u64
+            Err(Details::IntegerOverflow.into())
+        } else {
+            Ok(None)
+        }
+    } else if decoded & 0x1 == 0 {
+        Ok(Some(((decoded >> 1) as i64, loops_done + 1)))
+    } else {
+        Ok(Some((!(decoded >> 1) as i64, loops_done + 1)))
+    }
 }
 
 /// Set a new maximum number of bytes that can be allocated when decoding data.
@@ -282,8 +302,8 @@ mod tests {
 
     #[test]
     fn test_overflow() {
-        let causes_left_shift_overflow: &[u8] = &[0xe1, 0xe1, 0xe1, 0xe1, 0xe1];
-        assert!(decode_variable(&mut &*causes_left_shift_overflow).is_err());
+        let not_enough_bytes: &[u8] = &[0xe1, 0xe1, 0xe1, 0xe1, 0xe1];
+        assert!(decode_variable(not_enough_bytes).unwrap().is_none());
     }
 
     #[test]
diff --git a/avro_derive/tests/derive.proptest-regressions b/avro_derive/tests/derive.proptest-regressions
new file mode 100644
index 00000000..093d3789
--- /dev/null
+++ b/avro_derive/tests/derive.proptest-regressions
@@ -0,0 +1,12 @@
+# Seeds for failure cases proptest has generated in the past. It is
+# automatically read and these particular cases re-run before any
+# novel cases are generated.
+#
+# It is recommended to check this file in to source control so that
+# everyone who runs the test benefits from these saved cases.
+cc 7808c1b336ad8808516f0338a69eb961f7b70c7700e6d00f528b86c9ec48b9e7 # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = -4611686018427387905, h = 0.0, i = 0.0, j = ""
+cc f9d4f3f6442d2c9a718a23a9450e94ec2cef74ac62776b068581376082f0aace # shrinks to a = 0, b = ""
+cc 7bc3220d3d8a41cdbf5ab3618aeb4c8dd8268b524c1feb510d4ad39d104b261a # shrinks to a = "", b = [], c = {}
+cc 2d98b332f36a28eba7e4a8d5e97856a4ba7ea54defb427285a70129cb4782b1a # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = -0.0, j = ""
+cc 5b163ba13d686d86e78c8139e88278720019789f6841e58ead62807244f7582a # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = 0.0, j = ""
+cc aa33e374dc52ac66d49a7bad56cb62eeb80c60f42c9dc8e831db1f581d4a2b07 # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = 0.0, j = "", aa = 0
diff --git a/rfc.md b/rfc.md
new file mode 100644
index 00000000..1c2c88dc
--- /dev/null
+++ b/rfc.md
@@ -0,0 +1,90 @@
+# Possible implementations
+
+## Maintaining two separate implementations
+
+Pros:
+  - Easy to implement (just copy-paste the blocking implementation and start inserting `async`/`await`)
+  - Allows for optimal performance in both situations
+  - *Should* be able to share at least a part of the implementation
+
+Cons:
+  - Maintenance: any bug needs to be fixed in both implementations, and the same goes for testing.
+  - Hard to onboard: new contributors will be confronted with a very large codebase (see [Good ol' copy-pasting](https://nullderef.com/blog/rust-async-sync/#good-ol-copy-pasting))
+  - Adding new functionality means implementing it twice.
+
+## Implement in async, use `block_on` for sync implementation
+
+In this implementation, the core codebase is implemented asynchronously. A `blocking` module is provided which wraps
+the async functions/types in `block_on` calls.
+Recreating the runtime on every call is very slow, so making this work would involve spawning a dedicated thread
+for the runtime and using it to spawn the async functions. This is how `reqwest` implements its async/sync split.
+
+Pros:
+  - Only need to maintain/test/upgrade one implementation
+  - Optimal performance for async code
+
+Cons:
+  - Degrades sync performance
+  - Need to pull in a runtime when the `blocking` feature is enabled (`reqwest` uses `tokio`, but something like `smol` might make more sense)
+
+## Implement in async, use `maybe_async` to generate sync implementation
+
+[`maybe_async`](https://crates.io/crates/maybe-async) is a proc macro that removes the `.await` from the async code and uses it to generate sync code. [`synca`](https://docs.rs/synca/latest/synca/) is another option where both sync and async code can coexist.
+
+Pros:
+  - Only need to maintain/test/upgrade one implementation
+  - Optimal performance for both async and sync code
+
+Cons:
+  - The crate breaks if both the `sync` and `async` features are enabled (only for `maybe_async`)
+  - `synca` hasn't seen an update in more than a year, but seems to be feature complete
+
+## Sans I/O
+
+Implement the parser as a state machine that can be driven by both async and sync code. This is how [`rc-zip`](https://lib.rs/crates/rc-zip)
+is implemented.
+
+Pros:
+  - Only need to maintain/test/upgrade one implementation
+  - Optimal performance for both async and sync code
+
+Cons:
+  - Have to manually implement the state machines
+  - In the distant future [it may become possible to use coroutines/generators](https://internals.rust-lang.org/t/using-coroutines-for-a-sans-io-parser/22968), but they're currently *very* unstable.
+  - You can use async functions to generate the state machines for you, [according to this blog post](https://jeffmcbride.net/blog/2025/05/16/rust-async-functions-as-state-machines/)
+
+## Do not provide an async implementation
+
+Pros:
+  - Easiest option, nothing has to change
+
+Cons:
+  - An async implementation is really nice for using Avro over the network
+
+# Serde
+
+One problem not mentioned yet is that Serde does not have an async interface. This doesn't necessarily have to be a problem:
+the current deserialize implementation first decodes an `avro::Value` and then uses that to deserialize into the Serde type (the reverse for serialize).
+The decoding to `avro::Value` can be made async, and the Serde part can then stay synchronous as it does not perform any I/O.
+
+Some alternative options:
+- [tokio-serde](https://docs.rs/tokio-serde/latest/tokio_serde/index.html)
+  - A wrapper around Serde that requires the user to split the input into frames containing one object.
+- [destream](https://docs.rs/destream/0.9.0/destream/index.html)
+  - Async versions of the Serde traits, but not compatible with Serde, so it lacks ecosystem support.
+
+# Best option?
+
+I'm currently leaning towards implementing Sans I/O. It provides an (almost) optimal implementation for both async and sync code.
+It doesn't duplicate code (except the interfaces) and doesn't require pulling in any runtime (only parts of `futures`).
+
+Care needs to be taken that the state machines are kept small and understandable.
+
+The second-best option is probably using `block_on` in a separate thread, but that seems unnecessarily heavy.
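+
+To make the Sans I/O option concrete, below is a minimal sketch of the pattern. The names (`Step`,
+`U32Parser`, `parse_sync`) are made up for illustration and are not an existing API; the real
+implementation would follow the same `NeedMore`/`Done` control flow with a proper ring buffer and
+error handling.
+
+```rust
+use std::io::Read;
+
+/// What the parser tells its driver after each `parse` call.
+enum Step<M, T> {
+    /// Feed more bytes and call `parse` again.
+    NeedMore(M),
+    /// Parsing finished; here is the result.
+    Done(T),
+}
+
+/// A toy parser that reads a little-endian `u32`. It owns no I/O at all:
+/// it only consumes bytes that the driver hands it.
+struct U32Parser {
+    bytes: Vec<u8>,
+}
+
+impl U32Parser {
+    fn new() -> Self {
+        Self { bytes: Vec::new() }
+    }
+
+    /// Pure parsing logic: no `Read`/`AsyncRead` anywhere.
+    fn parse(mut self, input: &mut Vec<u8>) -> Step<Self, u32> {
+        self.bytes.append(input);
+        if self.bytes.len() < 4 {
+            return Step::NeedMore(self);
+        }
+        let raw: [u8; 4] = self.bytes[..4].try_into().unwrap();
+        Step::Done(u32::from_le_bytes(raw))
+    }
+}
+
+/// The synchronous driver supplies bytes from `std::io::Read`.
+fn parse_sync<R: Read>(mut reader: R) -> std::io::Result<u32> {
+    let mut fsm = U32Parser::new();
+    loop {
+        let mut chunk = vec![0; 1024];
+        let n = reader.read(&mut chunk)?;
+        if n == 0 {
+            return Err(std::io::ErrorKind::UnexpectedEof.into());
+        }
+        chunk.truncate(n);
+        match fsm.parse(&mut chunk) {
+            Step::NeedMore(next) => fsm = next,
+            Step::Done(value) => return Ok(value),
+        }
+    }
+}
+```
+
+An async driver would be identical except that the read becomes `reader.read(&mut chunk).await?`;
+the parser itself never changes, so both front-ends share a single tested implementation.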
+
+# References
+
+- [Blog post by the maintainer of `RSpotify`, who tried multiple of the above options](https://nullderef.com/blog/rust-async-sync/)
+- [A discussion about Sans I/O](https://sdr-podcast.com/episodes/sans-io/)
+- [An explanation of Sans I/O by the author of `rc-zip`](https://fasterthanli.me/articles/the-case-for-sans-io)
+  - The blog post is currently not freely available, but the [video](https://www.youtube.com/watch?v=RYHYiXMJdZI) (which has the exact same content) is freely available
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 00000000..9fabc05e
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,2 @@
+imports_granularity="Crate"