diff --git a/Cargo.lock b/Cargo.lock index d34974e3..19254305 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,49 +19,58 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.2.6" +version = "0.6.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" dependencies = [ "anstyle", "anstyle-parse", + "anstyle-query", "anstyle-wincon", - "concolor-override", - "concolor-query", - "is-terminal", + "colorchoice", + "is_terminal_polyfill", "utf8parse", ] [[package]] name = "anstyle" -version = "0.3.5" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" -version = "0.1.1" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" dependencies = [ "utf8parse", ] +[[package]] +name = "anstyle-query" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "anstyle-wincon" -version = "0.2.0" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" dependencies = [ "anstyle", - "windows-sys 0.45.0", + "windows-sys 0.52.0", ] [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "autocfg" @@ -161,33 +170,31 @@ dependencies = [ [[package]] name = "clap" -version = "4.2.1" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" +checksum = "ed93b9805f8ba930df42c2590f05453d5ec36cbb85d018868a5b24d31f6ac000" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.2.1" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +checksum = "379026ff283facf611b0ea629334361c4211d1b12ee01024eec1591133b04120" dependencies = [ "anstream", "anstyle", - "bitflags", "clap_lex", - "strsim 0.10.0", + "strsim", ] [[package]] name = "clap_derive" -version = "4.2.0" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" +checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ "heck", "proc-macro2", @@ -197,24 +204,15 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.4.1" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" [[package]] -name = "concolor-override" -version = "1.0.0" +name = "colorchoice" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" - -[[package]] -name = "concolor-query" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" -dependencies = [ - "windows-sys 0.45.0", -] +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "cons-rs" @@ -309,7 +307,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim 0.11.1", + "strsim", "syn", ] @@ -388,6 +386,17 @@ dependencies = [ "serde_json", ] +[[package]] +name = "doodle-rec" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "doodle", + "serde", + "serde_json", +] + [[package]] name = "doodle_gencode" version = "0.1.0" @@ -552,9 +561,9 @@ dependencies = [ [[package]] name = "heck" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" @@ -600,6 +609,12 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.10.5" @@ -919,18 +934,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.159" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.159" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -939,21 +954,16 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.95" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "strsim" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index dc36b8c2..76067621 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [".", "generated/", "doodle-formats/"] +members = [".", "generated/", "doodle-formats/", "experiments/doodle-rec"] [package] name = "doodle" diff --git a/experiments/doodle-rec/Cargo.toml b/experiments/doodle-rec/Cargo.toml new file mode 100644 index 00000000..2f4401fe --- /dev/null +++ b/experiments/doodle-rec/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "doodle-rec" +version = "0.1.0" +edition = "2024" + +[lib] +path = "src/lib.rs" +bench = false + +[dependencies] +anyhow = "1.0.98" +clap = "4.5.38" +doodle = { path = "../../" } +linked_hash_set = "0.1.5" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" diff --git a/experiments/doodle-rec/src/decoder.rs b/experiments/doodle-rec/src/decoder.rs new file mode 100644 index 00000000..b44ed60e --- /dev/null +++ b/experiments/doodle-rec/src/decoder.rs @@ -0,0 +1,583 @@ +use serde::Serialize; +use std::{ + collections::{HashMap, HashSet}, + rc::Rc, +}; + +use crate::{ + Arith, Expr, Format, FormatDecl, FormatId, FormatModule, FormatType, IntRel, Label, RecId, + RecurseCtx, Span, Unary, + matchtree::{MatchTree, Next}, +}; +use anyhow::{Result as AResult, anyhow}; +use doodle::{IntWidth, byte_set::ByteSet, read::ReadCtxt}; + +#[derive(Debug, Clone, Serialize)] +pub enum Value { + // Primitive values + U8(u8), + U16(u16), + U32(u32), + U64(u64), + Bool(bool), + Char(char), + + // Shape-based values + Tuple(Vec), + Seq(Vec), + Option(Option>), + Variant(Label, Box), +} + +impl Value { + fn get_usize_with_precision(&self) -> (usize, IntWidth) { + match self { + Value::U8(n) => (*n as usize, IntWidth::Bits8), + Value::U16(n) => (*n as usize, IntWidth::Bits16), + Value::U32(n) => (*n as usize, IntWidth::Bits32), + Value::U64(n) => (*n as usize, IntWidth::Bits64), + _ => panic!("value is not a number: {self:?}"), + } + } + + pub(crate) fn unwrap_bool(&self) -> bool { + match self { + Value::Bool(b) => *b, + _ => panic!("value is not a bool"), + } + } +} + +#[derive(Debug, Clone)] +pub struct Program { + pub decoders: Vec<(Decoder, FormatType)>, +} + +impl Program { + fn new() -> Self { + let decoders = Vec::new(); + Program { decoders } + } + + pub fn run<'input>(&self, input: ReadCtxt<'input>) -> DecodeResult<(Value, ReadCtxt<'input>)> { + self.decoders[0].0.parse(self, input) + } +} + +type Batch = Option>; + +pub struct Compiler<'a> { + module: &'a FormatModule, + program: Program, + decoder_map: HashMap<(usize, Rc>), usize>, + compile_queue: Vec<(&'a Format, Rc>, usize, Batch)>, +} + +impl<'a> Compiler<'a> { + fn new(module: &'a FormatModule) -> Self { + let program = Program::new(); + let decoder_map = HashMap::new(); + let compile_queue = Vec::new(); + Compiler { + module, + program, + decoder_map, + compile_queue, + } + } + + pub fn compile_program( + module: &FormatModule, + format: &Format, + ctx: RecurseCtx, + ) -> AResult { + let mut compiler = Compiler::new(module); + + let mut visited = HashSet::new(); + + let batch = ctx.as_span(); + + let t = format.infer_type(&mut visited, module, batch)?; + compiler.queue_compile(t, format, Rc::new(Next::Empty), batch); + while let Some((f, next, n, batch)) = compiler.compile_queue.pop() { + let f_ctx = match batch { + Some(span) => RecurseCtx::Recurse { + span, + batch: &module.decls[span.start..=span.end], + entry_id: n - span.start, + }, + None => RecurseCtx::NonRec, + }; + let d = compiler.compile_format(f, next, f_ctx)?; + compiler.program.decoders[n].0 = d; + } + Ok(compiler.program) + } + + fn queue_compile( + &mut self, + t: FormatType, + f: &'a Format, + next: Rc>, + batch: Option>, + ) -> usize { + let n = self.program.decoders.len(); + self.program.decoders.push((Decoder::FAIL, t)); + self.compile_queue.push((f, next, n, batch)); + n + } + + fn queue_compile_batch( + &mut self, + decls: &'a [FormatDecl], + which_next: RecId, + next: Rc>, + span: Span, + ) -> usize { + let n = self.program.decoders.len(); + for (ix, d) in decls.into_iter().enumerate() { + let t = d.solve_type(self.module).unwrap().clone(); + self.program.decoders.push((Decoder::FAIL, t)); + let next = if ix == which_next { + next.clone() + } else { + Rc::new(Next::Empty) + }; + self.compile_queue + .push((&d.format, next, n + ix, Some(span))); + } + n + which_next + } + + pub fn compile_one(format: &Format) -> AResult { + let module = FormatModule::new(); + let mut compiler = Compiler::new(&module); + let ctx = RecurseCtx::NonRec; + compiler.compile_format(format, Rc::new(Next::Empty), ctx) + } + + fn compile_format( + &mut self, + format: &'a Format, + next: Rc>, + ctx: RecurseCtx<'a>, + ) -> AResult { + match format { + Format::ItemVar(level) => { + let f = self.module.get_format(*level); + let next = if f.depends_on_next(self.module, ctx) { + next + } else { + Rc::new(Next::Empty) + }; + let n = if let Some(n) = self.decoder_map.get(&(*level, next.clone())) { + *n + } else { + let t = self.module.get_format_type(*level).clone(); + let n = match self.module.get_batch(*level) { + Some(span) => { + let batch = &self.module.decls[span.start..=span.end]; + self.queue_compile_batch(batch, level - span.start, next.clone(), span) + } + None => self.queue_compile(t, f, next.clone(), None), + }; + self.decoder_map.insert((*level, next.clone()), n); + n + }; + Ok(Decoder::Call(n)) + } + Format::RecVar(batch_ix) => { + let (new_ctx, _) = ctx.enter(*batch_ix); + let level = new_ctx.get_level().unwrap(); + // REVIEW - do we need to do any work here? + Ok(Decoder::CallRec(level, *batch_ix)) + } + Format::FailWith(msg) => Ok(Decoder::FailWith(msg.clone())), + Format::EndOfInput => Ok(Decoder::EndOfInput), + Format::Byte(bs) => Ok(Decoder::Byte(*bs)), + Format::Variant(label, f) => { + let d = self.compile_format(f, next.clone(), ctx)?; + Ok(Decoder::Variant(label.clone(), Box::new(d))) + } + Format::Compute(expr) => Ok(Decoder::Compute(expr.clone())), + Format::Union(branches) => { + let mut ds = Vec::with_capacity(branches.len()); + for f in branches { + ds.push(self.compile_format(f, next.clone(), ctx)?); + } + if let Some(tree) = MatchTree::build(self.module, branches, next, ctx) { + Ok(Decoder::Branch(tree, ds)) + } else { + Err(anyhow!("cannot build match tree for {:?}", format)) + } + } + Format::Tuple(elems) => { + let mut decs = Vec::with_capacity(elems.len()); + let mut fields = elems.iter(); + while let Some(f) = fields.next() { + let next = Rc::new(Next::Sequence(fields.as_slice(), next.clone())); + let df = self.compile_format(f, next, ctx)?; + decs.push(df); + } + Ok(Decoder::Tuple(decs)) + } + Format::Seq(elems) => { + let mut decs = Vec::with_capacity(elems.len()); + let mut fields = elems.iter(); + while let Some(f) = fields.next() { + let next = Rc::new(Next::Sequence(fields.as_slice(), next.clone())); + let df = self.compile_format(f, next, ctx)?; + decs.push(df); + } + Ok(Decoder::Seq(decs)) + } + Format::Repeat(a) => { + if a.is_nullable(self.module) { + return Err(anyhow!("cannot repeat nullable format: {a:?}")); + } + let da = self.compile_format(a, Rc::new(Next::Repeat(a, next.clone())), ctx)?; + let astar = Format::Repeat(a.clone()); + let fa = Format::Tuple(vec![(**a).clone(), astar]); + let fb = Format::EMPTY; + if let Some(tree) = MatchTree::build(self.module, &[fa, fb], next, ctx) { + Ok(Decoder::While(tree, Box::new(da))) + } else { + Err(anyhow!("cannot build match tree for {:?}", format)) + } + } + Format::Maybe(x, a) => { + let da = Box::new(self.compile_format(a, Rc::new(Next::Empty), ctx)?); + Ok(Decoder::Maybe(x.clone(), da)) + } + } + } +} + +impl Expr { + pub fn eval(&self) -> Value { + match self { + Expr::U8(i) => Value::U8(*i), + Expr::U16(i) => Value::U16(*i), + Expr::U32(i) => Value::U32(*i), + Expr::U64(i) => Value::U64(*i), + Expr::Bool(b) => Value::Bool(*b), + + Expr::AsChar(expr) => match expr.eval() { + Value::U8(x) => Value::Char(char::from(x)), + Value::U16(x) => { + Value::Char(char::from_u32(x as u32).unwrap_or(char::REPLACEMENT_CHARACTER)) + } + Value::U32(x) => { + Value::Char(char::from_u32(x).unwrap_or(char::REPLACEMENT_CHARACTER)) + } + Value::U64(x) => Value::Char( + char::from_u32(u32::try_from(x).unwrap()) + .unwrap_or(char::REPLACEMENT_CHARACTER), + ), + _ => panic!("AsChar: expected U8, U16, U32, or U64"), + }, + Expr::AsU8(x) => { + match x.eval() { + Value::U8(x) => Value::U8(x), + Value::U16(x) => Value::U8(u8::try_from(x).unwrap_or_else(|err| { + panic!("cannot perform AsU8 cast on u16 {x}: {err}") + })), + Value::U32(x) => Value::U8(u8::try_from(x).unwrap_or_else(|err| { + panic!("cannot perform AsU8 cast on u32 {x}: {err}") + })), + Value::U64(x) => Value::U8(u8::try_from(x).unwrap_or_else(|err| { + panic!("cannot perform AsU8 cast on u64 {x}: {err}") + })), + x => panic!("cannot convert {x:?} to U8"), + } + } + + Expr::AsU16(x) => match x.eval() { + Value::U8(x) => Value::U16(u16::from(x)), + Value::U16(x) => Value::U16(x), + Value::U32(x) => Value::U16(u16::try_from(x).unwrap()), + Value::U64(x) => Value::U16(u16::try_from(x).unwrap()), + x => panic!("cannot convert {x:?} to U16"), + }, + Expr::AsU32(x) => match x.eval() { + Value::U8(x) => Value::U32(u32::from(x)), + Value::U16(x) => Value::U32(u32::from(x)), + Value::U32(x) => Value::U32(x), + Value::U64(x) => Value::U32(u32::try_from(x).unwrap()), + x => panic!("cannot convert {x:?} to U32"), + }, + Expr::AsU64(x) => match x.eval() { + Value::U8(x) => Value::U64(u64::from(x)), + Value::U16(x) => Value::U64(u64::from(x)), + Value::U32(x) => Value::U64(u64::from(x)), + Value::U64(x) => Value::U64(x), + x => panic!("cannot convert {x:?} to U64"), + }, + Expr::Seq(exprs) => Value::Seq(exprs.iter().map(Expr::eval).collect()), + Expr::Tuple(exprs) => Value::Tuple(exprs.iter().map(Expr::eval).collect()), + Expr::LiftOption(None) => Value::Option(None), + Expr::LiftOption(Some(expr)) => Value::Option(Some(Box::new(expr.eval()))), + Expr::Variant(lab, expr) => Value::Variant(lab.clone(), Box::new(expr.eval())), + Expr::IntRel(rel, lhs, rhs) => { + let lhs = lhs.eval(); + let rhs = rhs.eval(); + let (l, _lw) = lhs.get_usize_with_precision(); + let (r, _rw) = rhs.get_usize_with_precision(); + if _lw != _rw { + panic!("cannot compare {lhs:?} with {rhs:?}"); + } + match rel { + IntRel::Eq => Value::Bool(l == r), + IntRel::Lt => Value::Bool(l < r), + IntRel::Gt => Value::Bool(l > r), + IntRel::Neq => Value::Bool(l != r), + IntRel::Lte => Value::Bool(l <= r), + IntRel::Gte => Value::Bool(l >= r), + } + } + Expr::Arith(Arith::Add, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_add(r).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_add(r).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_add(r).unwrap()), + (Value::U64(l), Value::U64(r)) => Value::U64(l.checked_add(r).unwrap()), + (l, r) => panic!("cannot add {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Sub, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_sub(r).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_sub(r).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_sub(r).unwrap()), + (Value::U64(l), Value::U64(r)) => Value::U64(l.checked_sub(r).unwrap()), + (l, r) => panic!("cannot subtract {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Mul, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_mul(r).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_mul(r).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_mul(r).unwrap()), + (Value::U64(l), Value::U64(r)) => Value::U64(l.checked_mul(r).unwrap()), + (l, r) => panic!("cannot multiply {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Div, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_div(r).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_div(r).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_div(r).unwrap()), + (Value::U64(l), Value::U64(r)) => Value::U64(l.checked_div(r).unwrap()), + (l, r) => panic!("cannot divide {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Rem, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_rem(r).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_rem(r).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_rem(r).unwrap()), + (Value::U64(l), Value::U64(r)) => Value::U64(l.checked_rem(r).unwrap()), + (l, r) => panic!("cannot compute remainder {l:?} and {r:?}"), + }, + Expr::Arith(Arith::BitAnd, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l & r), + (Value::U16(l), Value::U16(r)) => Value::U16(l & r), + (Value::U32(l), Value::U32(r)) => Value::U32(l & r), + (Value::U64(l), Value::U64(r)) => Value::U64(l & r), + (l, r) => panic!("cannot bitwise and {l:?} and {r:?}"), + }, + Expr::Arith(Arith::BitOr, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l | r), + (Value::U16(l), Value::U16(r)) => Value::U16(l | r), + (Value::U32(l), Value::U32(r)) => Value::U32(l | r), + (Value::U64(l), Value::U64(r)) => Value::U64(l | r), + (l, r) => panic!("cannot bitwise or {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Shl, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_shl(r as u32).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_shl(r as u32).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_shl(r).unwrap()), + (Value::U64(l), Value::U64(r)) => { + Value::U64(l.checked_shl(u32::try_from(r).unwrap()).unwrap()) + } + (l, r) => panic!("cannot shift left {l:?} and {r:?}"), + }, + Expr::Arith(Arith::Shr, lhs, rhs) => match (lhs.eval(), rhs.eval()) { + (Value::U8(l), Value::U8(r)) => Value::U8(l.checked_shr(r as u32).unwrap()), + (Value::U16(l), Value::U16(r)) => Value::U16(l.checked_shr(r as u32).unwrap()), + (Value::U32(l), Value::U32(r)) => Value::U32(l.checked_shr(r).unwrap()), + (Value::U64(l), Value::U64(r)) => { + Value::U64(l.checked_shr(u32::try_from(r).unwrap()).unwrap()) + } + (l, r) => panic!("cannot shift right {l:?} and {r:?}"), + }, + Expr::Unary(Unary::BoolNot, expr) => match expr.eval() { + Value::Bool(x) => Value::Bool(!x), + x => panic!("cannot negate {x:?}"), + }, + } + } +} + +#[derive(Debug, Clone)] +pub enum Decoder { + Call(FormatId), + CallRec(FormatId, RecId), + + FailWith(Label), + EndOfInput, + Byte(ByteSet), + Compute(Box), + + Variant(Label, Box), + Branch(MatchTree, Vec), + + While(MatchTree, Box), // Repeat decoder while input matches + + Seq(Vec), + Tuple(Vec), + Maybe(Box, Box), +} + +pub(crate) mod error; +use error::DecodeError; + +pub type DecodeResult = Result; + +impl Decoder { + pub(crate) const FAIL: Self = Decoder::FailWith(Label::Borrowed("FAIL_CONST")); + + pub fn parse<'input>( + &self, + program: &Program, + input: ReadCtxt<'input>, + ) -> DecodeResult<(Value, ReadCtxt<'input>)> { + match self { + Decoder::FailWith(msg) => Err(DecodeError::fail(msg.clone(), input)), + Decoder::EndOfInput => match input.read_byte() { + None => Ok((Value::Tuple(vec![]), input)), + Some((b, _)) => Err(DecodeError::Trailing { + byte: b, + offset: input.offset, + }), + }, + Decoder::Byte(bs) => { + let (b, input) = input.read_byte().ok_or(DecodeError::Overbyte { + offset: input.offset, + })?; + if bs.contains(b) { + Ok((Value::U8(b), input)) + } else { + Err(DecodeError::Unexpected { + found: b, + expected: *bs, + offset: input.offset, + }) + } + } + Decoder::Call(ix) => program.decoders[*ix].0.parse(program, input), + Decoder::CallRec(level, _) => program.decoders[*level].0.parse(program, input), + Decoder::Compute(expr) => { + let v = expr.eval(); + Ok((v, input)) + } + Decoder::Variant(lab, da) => { + let (v, input) = da.parse(program, input)?; + Ok((Value::Variant(lab.clone(), Box::new(v)), input)) + } + Decoder::Branch(tree, branches) => { + let index = tree.matches(input).ok_or(DecodeError::NoValidBranch { + offset: input.offset, + })?; + let d = &branches[index]; + // let (v, input) = d.parse(program, input)?; + // Ok(Value::Branch(index, Box::new(v)), input)) + d.parse(program, input) + } + Decoder::Seq(decs) => { + let mut input = input; + let mut v = Vec::with_capacity(decs.len()); + for d in decs { + let (va, next_input) = d.parse(program, input)?; + input = next_input; + v.push(va); + } + Ok((Value::Seq(v), input)) + } + Decoder::Tuple(decs) => { + let mut input = input; + let mut v = Vec::with_capacity(decs.len()); + for d in decs { + let (va, next_input) = d.parse(program, input)?; + input = next_input; + v.push(va); + } + Ok((Value::Tuple(v), input)) + } + Decoder::While(tree, a) => { + let mut input = input; + let mut v = Vec::new(); + while tree.matches(input).ok_or(DecodeError::NoValidBranch { + offset: input.offset, + })? == 0 + { + let (va, next_input) = a.parse(program, input)?; + input = next_input; + v.push(va); + } + Ok((Value::Seq(v), input)) + } + Decoder::Maybe(expr, a) => { + let is_present = expr.eval().unwrap_bool(); + if is_present { + let (v, input) = a.parse(program, input)?; + Ok((Value::Option(Some(Box::new(v))), input)) + } else { + Ok((Value::Option(None), input)) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::ser::to_string_pretty; + + #[test] + fn not_actually_recursive() -> AResult<()> { + let dead_end = Format::Byte(ByteSet::from_bits([1, 0, 0, 0])); + let text = Format::Tuple(vec![ + Format::Repeat(Box::new(Format::Byte(ByteSet::from(0x01..=0x7f)))), + Format::RecVar(0), + ]); + let mut module = FormatModule::new(); + let frefs = module.declare_rec_formats(vec![ + (Label::Borrowed("text.null"), dead_end), + (Label::Borrowed("text.cstring"), text), + ]); + let f = frefs[1].call(); + let program = Compiler::compile_program(&module, &f, RecurseCtx::NonRec)?; + let input = ReadCtxt::new(b"hello world\x00"); + let (value, _) = program.run(input)?; + eprintln!("{value:?}"); + Ok(()) + } + + #[test] + fn auto_recursive() -> AResult<()> { + let peano = Format::Union(vec![ + Format::Variant( + Label::Borrowed("peanoZ"), + Box::new(Format::Byte(ByteSet::from([b'Z']))), + ), + Format::Variant( + Label::Borrowed("peanoS"), + Box::new(Format::Tuple(vec![ + Format::Byte(ByteSet::from([b'S'])), + Format::RecVar(0), + ])), + ), + ]); + let mut module = FormatModule::new(); + let frefs = module.declare_rec_formats(vec![(Label::Borrowed("test.peano"), peano)]); + let f = Format::Tuple(vec![frefs[0].call(), Format::EndOfInput]); + let program = Compiler::compile_program(&module, &f, RecurseCtx::NonRec)?; + let input = ReadCtxt::new(b"SSSSZ"); + let (value, _) = program.run(input)?; + eprintln!("{}", to_string_pretty(&value).unwrap()); + Ok(()) + } +} diff --git a/experiments/doodle-rec/src/decoder/error.rs b/experiments/doodle-rec/src/decoder/error.rs new file mode 100644 index 00000000..817920ad --- /dev/null +++ b/experiments/doodle-rec/src/decoder/error.rs @@ -0,0 +1,82 @@ +use crate::Label; +use doodle::{prelude::ByteSet, read::ReadCtxt}; + +#[derive(Debug)] +pub enum DecodeError { + Fail { + message: Label, + offset: usize, + buffer: Vec, + }, + Trailing { + byte: u8, + offset: usize, + }, + Overbyte { + offset: usize, + }, + Unexpected { + found: u8, + expected: ByteSet, + offset: usize, + }, + NoValidBranch { + offset: usize, + }, +} + +impl std::fmt::Display for DecodeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Fail { + message, + offset, + buffer, + } => { + write!( + f, + "failed with message \"{message}\" at Offset={offset}, Buffer={buffer:#?}" + ) + } + Self::Trailing { byte, offset } => { + write!( + f, + "byte `{byte:02x}` found when end-of-input expected (offset = {offset})" + ) + } + Self::Overbyte { offset } => { + write!( + f, + "attempted read of byte would overrun buffer (offset = {offset})" + ) + } + Self::Unexpected { + found, + expected, + offset, + } => { + write!( + f, + "byte `{found:02x}` found when {expected:?} expected (offset = {offset})" + ) + } + Self::NoValidBranch { offset } => { + write!(f, "no valid branch at offset {offset}") + } + } + } +} + +impl std::error::Error for DecodeError {} + +impl DecodeError { + pub fn fail(message: Label, input: ReadCtxt<'_>) -> Self { + let offset = input.offset; + let buffer = input.input.to_owned(); + Self::Fail { + message, + offset, + buffer, + } + } +} diff --git a/experiments/doodle-rec/src/helper.rs b/experiments/doodle-rec/src/helper.rs new file mode 100644 index 00000000..8934a4d0 --- /dev/null +++ b/experiments/doodle-rec/src/helper.rs @@ -0,0 +1,42 @@ +use crate::{Format, Label}; +use doodle::prelude::ByteSet; + +pub fn tuple(formats: impl IntoIterator) -> Format { + Format::Tuple(formats.into_iter().collect()) +} + +pub fn is_byte(x: u8) -> Format { + Format::Byte(ByteSet::from([x])) +} + +pub fn alts>(branches: impl IntoIterator) -> Format { + Format::Union( + branches + .into_iter() + .map(|(name, f)| Format::Variant(name.into(), Box::new(f))) + .collect(), + ) +} + +pub fn byte_seq(bytes: impl IntoIterator) -> Format { + Format::Seq(bytes.into_iter().map(is_byte).collect()) +} + +pub fn repeat(format: Format) -> Format { + Format::Repeat(Box::new(format)) +} + +pub fn var(ix: usize) -> Format { + Format::RecVar(ix) +} + +pub fn fmt_variant>(name: Name, format: Format) -> Format { + Format::Variant(name.into(), Box::new(format)) +} + +pub fn optional(format: Format) -> Format { + Format::Union(vec![ + fmt_variant("no", Format::EMPTY), + fmt_variant("yes", format), + ]) +} diff --git a/experiments/doodle-rec/src/lib.rs b/experiments/doodle-rec/src/lib.rs new file mode 100644 index 00000000..5fdc7226 --- /dev/null +++ b/experiments/doodle-rec/src/lib.rs @@ -0,0 +1,781 @@ +pub mod decoder; +pub(crate) mod matchtree; +pub use matchtree::determinations; +pub mod helper; +pub(crate) use matchtree::{MatchTree, Next}; + +use anyhow::{Result as AResult, anyhow}; +use doodle::{bounds::Bounds, byte_set::ByteSet}; +use std::{ + borrow::Cow, + cell::OnceCell, + cmp::Ordering, + collections::{BTreeMap, HashSet}, + ops::{Add as _, RangeInclusive}, + rc::Rc, +}; + +pub type Label = Cow<'static, str>; + +/// Global index into the total set of formats within a Module +pub type FormatId = usize; + +/// Local index into a Batch of formats (e.g. 0 would be 'self' in a singleton-batch) +pub type RecId = usize; + +#[derive(Debug, Clone, Copy, Default)] +pub enum RecurseCtx<'a> { + #[default] + NonRec, + Recurse { + entry_id: RecId, + span: Span, + batch: &'a [FormatDecl], + }, +} + +impl<'a> RecurseCtx<'a> { + pub const fn is_recursive(&self) -> bool { + matches!(self, RecurseCtx::Recurse { .. }) + } + + pub const fn as_span(&self) -> Option> { + match self { + RecurseCtx::NonRec => None, + RecurseCtx::Recurse { span, .. } => Some(*span), + } + } + + /// Returns `(new_ctx, is_auto)` + pub fn enter(&self, ix: RecId) -> (Self, Ordering) { + match self { + RecurseCtx::NonRec => panic!("cannot recurse into non-recursive context"), + RecurseCtx::Recurse { + batch, + span, + entry_id, + } => { + assert!(ix < batch.len(), "batch index out of range"); + let ret = RecurseCtx::Recurse { + entry_id: ix, + span: *span, + batch, + }; + (ret, ix.cmp(entry_id)) + } + } + } + + pub fn convert_rec_var(&self, ix: RecId) -> Option { + self.as_span().map(|span| span.index(ix)) + } + + /// Returns the global format-level of the closest entry-point + pub fn get_level(&self) -> Option { + match self { + RecurseCtx::NonRec => None, + RecurseCtx::Recurse { span, entry_id, .. } => Some(span.index(*entry_id)), + } + } + + pub fn get_format(&self) -> Option<&'a Format> { + match self { + RecurseCtx::NonRec => None, + RecurseCtx::Recurse { + batch, entry_id, .. + } => Some(&batch[*entry_id].format), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct FormatRef(FormatId); + +impl FormatRef { + pub const fn get_level(self) -> usize { + self.0 + } + + pub fn call(self) -> Format { + Format::ItemVar(self.0) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct Span { + pub start: Idx, + pub end: Idx, +} + +impl Span { + pub const fn new(start: Idx, end: Idx) -> Self { + Self { start, end } + } +} + +impl Span { + pub fn index(self, ix: usize) -> usize { + assert!(self.start + ix <= self.end); + self.start + ix + } +} + +impl From> for Span { + fn from(value: RangeInclusive) -> Self { + Self { + start: *value.start(), + end: *value.end(), + } + } +} + +#[derive(Debug, Clone)] +pub struct FormatDecl { + format: Format, + pub fmt_id: FormatId, + f_type: Rc>, + batch: Option>, +} + +impl FormatDecl { + pub fn solve_type(&self, module: &FormatModule) -> AResult<&FormatType> { + let mut visited = HashSet::new(); + self.solve_type_with(module, &mut visited) + } + + pub(crate) fn solve_type_with( + &self, + module: &FormatModule, + visited: &mut HashSet, + ) -> AResult<&FormatType> { + match self.f_type.get() { + None => { + visited.insert(self.fmt_id); + let f_type = self.format.infer_type(visited, module, self.batch)?; + let Ok(_) = self.f_type.set(f_type) else { + unreachable!("synchronous TOCTOU!?") + }; + Ok(self.f_type.get().unwrap()) + } + Some(f_type) => Ok(f_type), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BaseType { + Bool, + U8, + U16, + U32, + U64, + Char, +} + +impl BaseType { + pub fn is_numeric(&self) -> bool { + matches!( + self, + BaseType::U8 | BaseType::U16 | BaseType::U32 | BaseType::U64 + ) + } +} + +#[derive(Debug, Clone)] +pub enum FormatType { + Any, + Void, + Base(BaseType), + Ref(FormatId), + Shape(TypeShape), +} + +impl FormatType { + pub const UNIT: FormatType = FormatType::Shape(TypeShape::Tuple(Vec::new())); + + pub fn is_numeric(&self) -> bool { + match self { + FormatType::Base(base) => base.is_numeric(), + _ => false, + } + } + + fn unify(&self, other: &FormatType) -> AResult { + match (self, other) { + (FormatType::Any, _) => Ok(other.clone()), + (_, FormatType::Any) => Ok(self.clone()), + (FormatType::Ref(id0), FormatType::Ref(id1)) => { + if id0 == id1 { + Ok(FormatType::Ref(*id0)) + } else { + unimplemented!("cross-ref unification not implemented"); + } + } + (FormatType::Void, _) | (_, FormatType::Void) => Ok(FormatType::Void), + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 => Ok(FormatType::Base(*b1)), + (FormatType::Shape(s1), FormatType::Shape(s2)) => { + let s = s1.unify(s2)?; + Ok(FormatType::Shape(s)) + } + _ => Err(anyhow!( + "cannot unify incompatible types: {self:?}, {other:?}" + )), + } + } +} + +#[derive(Debug, Clone)] +pub enum TypeShape { + Tuple(Vec), + Seq(Box), + Option(Box), + Union(BTreeMap), +} + +impl TypeShape { + fn unify(&self, other: &Self) -> AResult { + match (self, other) { + (TypeShape::Tuple(t1), TypeShape::Tuple(t2)) => { + if t1.len() != t2.len() { + return Err(anyhow!( + "cannot unify tuples of different arity: {t1:?}, {t2:?}" + )); + } + let mut unified = Vec::with_capacity(t1.len()); + for (t1, t2) in t1.iter().zip(t2.iter()) { + unified.push(t1.unify(t2)?); + } + Ok(TypeShape::Tuple(unified)) + } + (TypeShape::Seq(t1), TypeShape::Seq(t2)) => Ok(TypeShape::Seq(Box::new(t1.unify(t2)?))), + (TypeShape::Option(t1), TypeShape::Option(t2)) => { + Ok(TypeShape::Option(Box::new(t1.unify(t2)?))) + } + (TypeShape::Union(bs1), TypeShape::Union(bs2)) => { + let mut bs = BTreeMap::new(); + + let keys1 = bs1.keys().collect::>(); + let keys2 = bs2.keys().collect::>(); + + let all_keys = HashSet::union(&keys1, &keys2).cloned(); + + for key in all_keys.into_iter() { + match (bs1.get(key), bs2.get(key)) { + (Some(t1), Some(t2)) => { + let t = t1.unify(t2)?; + bs.insert(key.clone(), t); + } + (Some(t), None) | (None, Some(t)) => { + bs.insert(key.clone(), t.clone()); + } + (None, None) => unreachable!("key must appear in at least one operand"), + } + } + Ok(TypeShape::Union(bs)) + } + _ => Err(anyhow!("cannot unify shapes: {self:?}, {other:?}")), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Format { + // References to other formats + ItemVar(FormatId), + RecVar(RecId), + + // Basic Primitives + FailWith(Label), + EndOfInput, + Byte(ByteSet), + Compute(Box), + + // Union-Based + Variant(Label, Box), + Union(Vec), + + // Sequential + Repeat(Box), + Seq(Vec), + + // Higher-Order + Tuple(Vec), + Maybe(Box, Box), +} + +impl Format { + pub const EMPTY: Self = Format::Tuple(vec![]); + + fn infer_type<'ctx>( + &'ctx self, + visited: &mut HashSet, + module: &'ctx FormatModule, + batch: Option>, + ) -> AResult { + match self { + Format::ItemVar(level) => { + if visited.contains(level) { + Ok(FormatType::Ref(*level)) + } else { + let decl = &module.decls[*level]; + Ok(decl.solve_type_with(module, visited)?.clone()) + } + } + Format::RecVar(batch_ix) => match batch { + None => Err(anyhow!("Recursion without a batch")), + Some(range) => { + let level = range.start + batch_ix; + if level > range.end { + return Err(anyhow!("batch index out of range")); + } + if visited.contains(&level) { + Ok(FormatType::Ref(level)) + } else { + let decl = &module.decls[level]; + Ok(decl.solve_type_with(module, visited)?.clone()) + } + } + }, + Format::FailWith(_msg) => Ok(FormatType::Void), + Format::EndOfInput => Ok(FormatType::UNIT), + Format::Byte(bs) if bs.is_empty() => Ok(FormatType::Void), + Format::Byte(_) => Ok(FormatType::Base(BaseType::U8)), + Format::Compute(expr) => expr.as_ref().infer_type(), + Format::Variant(label, inner) => { + let inner_type = inner.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Union(BTreeMap::from([( + label.clone(), + inner_type, + )])))) + } + Format::Union(branches) => { + let mut t = FormatType::Any; + for f in branches { + t = t.unify(&f.infer_type(visited, module, batch)?)?; + } + Ok(t) + } + Format::Repeat(inner) => { + let t = inner.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Seq(Box::new(t)))) + } + Format::Seq(elts) => { + let mut elem_type = FormatType::Any; + for elt in elts { + elem_type = elem_type.unify(&elt.infer_type(visited, module, batch)?)?; + } + Ok(FormatType::Shape(TypeShape::Seq(Box::new(elem_type)))) + } + Format::Tuple(elts) => { + let mut types = Vec::with_capacity(elts.len()); + for elt in elts { + types.push(elt.infer_type(visited, module, batch)?); + } + Ok(FormatType::Shape(TypeShape::Tuple(types))) + } + Format::Maybe(expr, format) => match expr.infer_type()? { + FormatType::Base(BaseType::Bool) => { + let t = format.infer_type(visited, module, batch)?; + Ok(FormatType::Shape(TypeShape::Option(Box::new(t)))) + } + other => Err(anyhow!( + "maybe expression type was inferred to be non-bool: {other:?}" + )), + }, + } + } + + fn depends_on_next<'a>(&self, module: &'a FormatModule, ctx: RecurseCtx<'a>) -> bool { + match self { + Format::ItemVar(level) => { + let ctx = module.get_ctx(*level); + module.get_format(*level).depends_on_next(module, ctx) + } + Format::FailWith(..) => false, + Format::EndOfInput => false, + Format::Byte(..) => false, + Format::Compute(..) => false, + Format::RecVar(..) => { + // REVIEW - are there any recursive formats that *don't* depend on next? + // FIXME[epic=hardcoded] - this is a placeholder for future improvements to classification logic + // false + true + } + Format::Variant(_, f) => f.depends_on_next(module, ctx), + Format::Union(branches) => Format::union_depends_on_next(branches, module, ctx), + Format::Repeat(..) => true, + Format::Seq(formats) | Format::Tuple(formats) => { + formats.iter().any(|f| f.depends_on_next(module, ctx)) + } + Format::Maybe(..) => true, + } + } + + fn union_depends_on_next<'a>( + branches: &'a [Format], + module: &'a FormatModule, + ctx: RecurseCtx<'a>, + ) -> bool { + let mut fs = Vec::with_capacity(branches.len()); + for f in branches { + if f.depends_on_next(module, ctx) { + return true; + } + fs.push(f.clone()); + } + MatchTree::build(module, &fs, Rc::new(Next::Empty), ctx).is_none() + } + + fn is_nullable(&self, module: &FormatModule) -> bool { + self.match_bounds(module).min() == 0 + } + + fn match_bounds(&self, module: &FormatModule) -> Bounds { + match self { + Format::ItemVar(level) => module.get_format(*level).match_bounds(module), + Format::FailWith(..) | Format::EndOfInput | Format::Compute(..) => Bounds::exact(0), + Format::Byte(_) => Bounds::exact(1), + Format::Variant(_, f) => f.match_bounds(module), + Format::Union(branches) => branches + .iter() + .map(|f| f.match_bounds(module)) + .reduce(Bounds::union) + .unwrap(), + Format::Tuple(fields) | Format::Seq(fields) => fields + .iter() + .map(|f| f.match_bounds(module)) + .reduce(Bounds::add) + .unwrap_or(Bounds::exact(0)), + Format::Repeat(_) => Bounds::any(), + Format::Maybe(_, f) => Bounds::union(Bounds::exact(0), f.match_bounds(module)), + Format::RecVar(..) => { + // REVIEW - we cannot get better than this without a complex model, and certainly not without adding more parameters + Bounds::any() + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Expr { + // Primitive Values + U8(u8), + U16(u16), + U32(u32), + U64(u64), + Bool(bool), + + // Primitive Value Casts + AsChar(Box), + AsU8(Box), + AsU16(Box), + AsU32(Box), + AsU64(Box), + + // Higher-Order Exprs + Seq(Vec), + Tuple(Vec), + LiftOption(Option>), + Variant(Label, Box), + + // Operational + IntRel(IntRel, Box, Box), + Arith(Arith, Box, Box), + Unary(Unary, Box), +} + +impl Expr { + fn infer_type(&self) -> AResult { + match self { + Expr::U8(_) => Ok(FormatType::Base(BaseType::U8)), + Expr::U16(_) => Ok(FormatType::Base(BaseType::U16)), + Expr::U32(_) => Ok(FormatType::Base(BaseType::U32)), + Expr::U64(_) => Ok(FormatType::Base(BaseType::U64)), + Expr::Bool(_) => Ok(FormatType::Base(BaseType::Bool)), + Expr::AsChar(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::Char)) + } else { + Err(anyhow!("invalid char type conversion from {expr_type:?}")) + } + } + Expr::AsU8(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U8)) + } else { + Err(anyhow!("invalid u8 type conversion from {expr_type:?}")) + } + } + Expr::AsU16(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U16)) + } else { + Err(anyhow!("invalid u16 type conversion from {expr_type:?}")) + } + } + Expr::AsU32(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U32)) + } else { + Err(anyhow!("invalid u32 type conversion from {expr_type:?}")) + } + } + Expr::AsU64(expr) => { + let expr_type = expr.infer_type()?; + if expr_type.is_numeric() { + Ok(FormatType::Base(BaseType::U64)) + } else { + Err(anyhow!("invalid u64 type conversion from {expr_type:?}")) + } + } + Expr::Seq(exprs) => { + let mut elem_type = FormatType::Any; + for expr in exprs { + elem_type = expr.infer_type()?.unify(&elem_type)?; + } + Ok(FormatType::Shape(TypeShape::Seq(Box::new(elem_type)))) + } + Expr::Tuple(exprs) => { + let mut elem_types = Vec::with_capacity(exprs.len()); + for expr in exprs { + elem_types.push(expr.infer_type()?); + } + Ok(FormatType::Shape(TypeShape::Tuple(elem_types))) + } + Expr::LiftOption(None) => Ok(FormatType::Shape(TypeShape::Option(Box::new( + FormatType::Any, + )))), + Expr::LiftOption(Some(expr)) => { + let expr_type = expr.infer_type()?; + Ok(FormatType::Shape(TypeShape::Option(Box::new(expr_type)))) + } + Expr::Variant(lab, expr) => { + let expr_type = expr.infer_type()?; + Ok(FormatType::Shape(TypeShape::Union(BTreeMap::from([( + lab.clone(), + expr_type, + )])))) + } + Expr::IntRel(_rel, lhs, rhs) => { + let lhs_type = lhs.infer_type()?; + let rhs_type = rhs.infer_type()?; + match (lhs_type, rhs_type) { + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 && b1.is_numeric() => { + Ok(FormatType::Base(BaseType::Bool)) + } + (lhs_type, rhs_type) => Err(anyhow!( + "invalid integer relation between {lhs_type:?} and {rhs_type:?}" + )), + } + } + Expr::Arith(_arith, lhs, rhs) => { + let lhs_type = lhs.infer_type()?; + let rhs_type = rhs.infer_type()?; + match (lhs_type, rhs_type) { + (FormatType::Base(b1), FormatType::Base(b2)) if b1 == b2 && b1.is_numeric() => { + Ok(FormatType::Base(b1)) + } + (lhs_type, rhs_type) => Err(anyhow!( + "invalid arithmetic operation between {lhs_type:?} and {rhs_type:?}" + )), + } + } + Expr::Unary(Unary::BoolNot, expr) => { + let expr_type = expr.infer_type()?; + if matches!(expr_type, FormatType::Base(BaseType::Bool)) { + Ok(FormatType::Base(BaseType::Bool)) + } else { + Err(anyhow!("invalid bool-not on {expr_type:?}")) + } + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IntRel { + Eq, + Neq, + Gt, + Gte, + Lt, + Lte, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Arith { + Add, + Sub, + Mul, + Div, + Rem, + Shl, + Shr, + BitOr, + BitAnd, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Unary { + BoolNot, +} + +#[derive(Debug)] +pub struct FormatModule { + names: Vec