diff --git a/dialectic-compiler/src/cfg.rs b/dialectic-compiler/src/cfg.rs index 78d16cdf..76a6a78f 100644 --- a/dialectic-compiler/src/cfg.rs +++ b/dialectic-compiler/src/cfg.rs @@ -47,10 +47,13 @@ pub enum Ir { /// continuation, which is stored in the "next" pointer of the node. The scope resolution pass /// "lowers" this next pointer continuation into the arms of the `Choose`, and so after scope /// resolution all `Choose` nodes' next pointers should be `None`. - Choose(Vec>), - /// Like `Choose`, `Offer` nodes have a list of choices, and after scope resolution have no - /// continuation. - Offer(Vec>), + /// + /// The `Choose` may also have an optional "carrier" type. If none, this is defaulted to be a + /// `Choice` corresponding to the `N` number of branches of the `Choose`. + Choose(Vec>, Option), + /// Like `Choose`, `Offer` nodes have a list of choices and an optional carrier type, and after + /// scope resolution have no continuation. + Offer(Vec>, Option), /// `Split` nodes have a transmit-only half and a receive-only half. The nodes' semantics are /// similar to `Call`. Split { @@ -236,7 +239,7 @@ impl Cfg { follow(None, *tx_only); follow(None, *rx_only); } - Ir::Choose(choices) | Ir::Offer(choices) => { + Ir::Choose(choices, _) | Ir::Offer(choices, _) => { // Inline the current implicit continuation into the next pointer of the node, // if it is `Some` follow(implicit_cont, *next); @@ -285,7 +288,7 @@ impl Cfg { // `Offer` (which have had their explicit continuations erased and lowered into the // arms) then we need to assign the scoped implicit continuation, if there is one, // as the new explicit continuation - None if !matches!(expr, Ir::Choose(_) | Ir::Offer(_)) => *next = implicit_cont, + None if !matches!(expr, Ir::Choose(..) | Ir::Offer(..)) => *next = implicit_cont, // This will only be reached if there is no explicit continuation and the node is a // `Choose` or `Offer`, in which case we want to do nothing. _ => {} @@ -321,7 +324,7 @@ impl Cfg { stack.extend(node.next); } } - Ir::Choose(choices) | Ir::Offer(choices) => { + Ir::Choose(choices, _) | Ir::Offer(choices, _) => { stack.extend(choices.iter().filter_map(Option::as_ref)); assert!(node.next.is_none(), "at this point in the compiler, all continuations of \ @@ -479,23 +482,23 @@ impl Cfg { cont: Rc::new(cont), } } - Ir::Choose(choices) => { + Ir::Choose(choices, carrier_type) => { let targets = choices .iter() .map(|&choice| generate_inner(cfg, errors, loop_env, choice)) .collect(); debug_assert!(node.next.is_none(), "non-`Done` continuation for `Choose`"); - Target::Choose(targets) + Target::Choose(targets, carrier_type.clone()) } - Ir::Offer(choices) => { + Ir::Offer(choices, carrier_type) => { let targets = choices .iter() .map(|&choice| generate_inner(cfg, errors, loop_env, choice)) .collect(); debug_assert!(node.next.is_none(), "non-`Done` continuation for `Offer`"); - Target::Offer(targets) + Target::Offer(targets, carrier_type.clone()) } Ir::Loop(body) => { loop_env.push(node_index); diff --git a/dialectic-compiler/src/flow.rs b/dialectic-compiler/src/flow.rs index 453775c1..807b8757 100644 --- a/dialectic-compiler/src/flow.rs +++ b/dialectic-compiler/src/flow.rs @@ -287,7 +287,7 @@ fn preconditions(cfg: &Cfg, constraint: Constraint) -> Dnf { // passability of an `Offer` or `Choose` typically does not matter because as control // flow analysis must be run after scope resolution, they should not have an implicit // continuation. - Ir::Offer(choices) | Ir::Choose(choices) => Dnf(choices + Ir::Offer(choices, _) | Ir::Choose(choices, _) => Dnf(choices .iter() .filter_map(Option::as_ref) .map(|&c| vec![Constraint::Passable(c)]) @@ -325,7 +325,7 @@ fn preconditions(cfg: &Cfg, constraint: Constraint) -> Dnf { Dnf::only_if(conj) } // `Choose`/`Offer` are haltable only if any of their choices are haltable. - Ir::Choose(choices) | Ir::Offer(choices) => Dnf(choices + Ir::Choose(choices, _) | Ir::Offer(choices, _) => Dnf(choices .iter() // If any of the choices are `Done`, we want to emit an empty `Vec` instead, // denoting that this constraint is trivially satisfiable. @@ -371,7 +371,7 @@ fn preconditions(cfg: &Cfg, constraint: Constraint) -> Dnf { Dnf::only_if(conj) } // `Choose` and `Offer` break only if any arm breaks. - Ir::Choose(choices) | Ir::Offer(choices) => Dnf(choices + Ir::Choose(choices, _) | Ir::Offer(choices, _) => Dnf(choices .iter() .filter_map(Option::as_ref) .chain(node.next.as_ref()) diff --git a/dialectic-compiler/src/lib.rs b/dialectic-compiler/src/lib.rs index 4413006e..1247509a 100644 --- a/dialectic-compiler/src/lib.rs +++ b/dialectic-compiler/src/lib.rs @@ -393,7 +393,7 @@ mod tests { let continue1 = cfg.singleton(Ir::Continue(client)); cfg[recv].next = Some(continue1); let choose_opts = vec![Some(send), Some(recv)]; - let choose = cfg.singleton(Ir::Choose(choose_opts)); + let choose = cfg.singleton(Ir::Choose(choose_opts, None)); cfg[client_tally].expr = Ir::Loop(Some(choose)); @@ -401,12 +401,12 @@ mod tests { let send = cfg.send("Operation"); cfg[send].next = Some(client_tally); let choose_opts = vec![Some(break0), Some(send)]; - let choose = cfg.singleton(Ir::Choose(choose_opts)); + let choose = cfg.singleton(Ir::Choose(choose_opts, None)); cfg[client].expr = Ir::Loop(Some(choose)); let s = format!("{}", cfg.generate_target(Some(client)).unwrap()); - assert_eq!(s, "Loop>, Recv>)>>>)>>"); + assert_eq!(s, "Loop>, Recv>), Choice<2>>>>), Choice<2>>>"); } #[test] @@ -421,14 +421,14 @@ mod tests { let continue0 = cfg.singleton(Ir::Continue(client)); cfg[call].next = Some(continue0); let choose_opts = vec![Some(break0), Some(send)]; - let choose = cfg.singleton(Ir::Choose(choose_opts)); + let choose = cfg.singleton(Ir::Choose(choose_opts, None)); cfg[client].expr = Ir::Loop(Some(choose)); let s = format!("{}", cfg.generate_target(Some(client)).unwrap()); assert_eq!( s, - "Loop>>)>>" + "Loop>>), Choice<2>>>" ); } } diff --git a/dialectic-compiler/src/parse.rs b/dialectic-compiler/src/parse.rs index 51020320..06169d76 100644 --- a/dialectic-compiler/src/parse.rs +++ b/dialectic-compiler/src/parse.rs @@ -228,6 +228,12 @@ impl Parse for Spanned { let kw_span = input.parse::()?.span(); let choose_span = kw_span.join(input.span()).unwrap_or(kw_span); + let carrier_type = if input.peek(token::Brace) { + None + } else { + Some(input.parse()?) + }; + let content; braced!(content in input); let mut choice_arms = Vec::new(); @@ -251,7 +257,7 @@ impl Parse for Spanned { } Ok(Spanned { - inner: Syntax::Choose(arm_asts), + inner: Syntax::Choose(arm_asts, carrier_type), span: choose_span, }) } else if lookahead.peek(kw::offer) { @@ -259,6 +265,12 @@ impl Parse for Spanned { let kw_span = input.parse::()?.span(); let offer_span = kw_span.join(input.span()).unwrap_or(kw_span); + let carrier_type = if input.peek(token::Brace) { + None + } else { + Some(input.parse()?) + }; + let content; braced!(content in input); let mut choice_arms = Vec::new(); @@ -282,7 +294,7 @@ impl Parse for Spanned { } Ok(Spanned { - inner: Syntax::Offer(arm_asts), + inner: Syntax::Offer(arm_asts, carrier_type), span: offer_span, }) } else if lookahead.peek(kw::split) { diff --git a/dialectic-compiler/src/syntax.rs b/dialectic-compiler/src/syntax.rs index b00a5f64..435d9797 100644 --- a/dialectic-compiler/src/syntax.rs +++ b/dialectic-compiler/src/syntax.rs @@ -42,9 +42,9 @@ pub enum Syntax { /// Syntax: `call T` or `call { ... }`. Call(Box>), /// Syntax: `choose { 0 => ..., ... }`. - Choose(Vec>), + Choose(Vec>, Option), /// Syntax: `offer { 0 => ..., ... }`. - Offer(Vec>), + Offer(Vec>, Option), /// Syntax: `split { -> ..., <- ... }`. Split { /// The transmit-only half. @@ -172,19 +172,19 @@ fn to_cfg<'a>( let rx_only = to_cfg(rx_only, cfg, env).0; Ir::Split { tx_only, rx_only } } - Choose(choices) => { + Choose(choices, carrier_type) => { let choice_nodes = choices .iter() .map(|choice| to_cfg(choice, cfg, env).0) .collect(); - Ir::Choose(choice_nodes) + Ir::Choose(choice_nodes, carrier_type.clone()) } - Offer(choices) => { + Offer(choices, carrier_type) => { let choice_nodes = choices .iter() .map(|choice| to_cfg(choice, cfg, env).0) .collect(); - Ir::Offer(choice_nodes) + Ir::Offer(choice_nodes, carrier_type.clone()) } Continue(label) => { return convert_jump_to_cfg(label, CompileError::ContinueOutsideLoop, Ir::Continue) @@ -306,13 +306,13 @@ impl Spanned { let rx = rx_only.to_token_stream_with(add_optional); quote_spanned! {sp=> split { -> #tx, <- #rx, } } } - Choose(choices) => { + Choose(choices, carrier_type) => { let arms = choice_arms_to_tokens(&mut add_optional, choices); - quote_spanned! {sp=> choose { #arms } } + quote_spanned! {sp=> choose #carrier_type { #arms } } } - Offer(choices) => { + Offer(choices, carrier_type) => { let arms = choice_arms_to_tokens(&mut add_optional, choices); - quote_spanned! {sp=> offer { #arms } } + quote_spanned! {sp=> offer #carrier_type { #arms } } } Loop(None, body) => { let body_tokens = body.to_token_stream_with(add_optional); @@ -354,7 +354,7 @@ impl Spanned { // rather than required. let ends_with_block = matches!( &stmt.inner, - Block(_) | Split { .. } | Offer(_) | Choose(_) | Loop(_, _), + Block(_) | Split { .. } | Offer(..) | Choose(..) | Loop(_, _), ); if !(is_call_of_block || ends_with_block) || add_optional() { @@ -438,8 +438,14 @@ mod tests { "type" => Type(parse_quote!(())), other => unreachable!("{}", other), }, - "choose" => Choose(Arbitrary::arbitrary(g)), - "offer" => Offer(Arbitrary::arbitrary(g)), + "choose" => Choose( + Arbitrary::arbitrary(g), + Some(parse_quote!(())).filter(|_| Arbitrary::arbitrary(g)), + ), + "offer" => Offer( + Arbitrary::arbitrary(g), + Some(parse_quote!(())).filter(|_| Arbitrary::arbitrary(g)), + ), "split" => Split { tx_only: Arbitrary::arbitrary(g), rx_only: Arbitrary::arbitrary(g), @@ -475,8 +481,8 @@ mod tests { }) }) .collect(), - Choose(choices) => choices.shrink().map(Choose).collect(), - Offer(choices) => choices.shrink().map(Offer).collect(), + Choose(choices, _) => choices.shrink().map(|cs| Choose(cs, None)).collect(), + Offer(choices, _) => choices.shrink().map(|cs| Offer(cs, None)).collect(), Loop(label, body) => body .shrink() .map(|body_shrunk| Loop(label.clone(), body_shrunk)) diff --git a/dialectic-compiler/src/target.rs b/dialectic-compiler/src/target.rs index 77dab3f5..bd160f6b 100644 --- a/dialectic-compiler/src/target.rs +++ b/dialectic-compiler/src/target.rs @@ -4,7 +4,7 @@ use { proc_macro2::TokenStream, quote::{quote_spanned, ToTokens}, std::{fmt, rc::Rc}, - syn::{Path, Type}, + syn::{parse_quote, Path, Type}, }; use crate::Spanned; @@ -24,10 +24,10 @@ pub enum Target { Recv(Type, Rc>), /// Session type: `Send`. Send(Type, Rc>), - /// Session type: `Choose<(P, ...)>`. - Choose(Vec>), - /// Session type: `Offer<(P, ...)>`. - Offer(Vec>), + /// Session type: `Choose<(P, ...), Carrier>`. + Choose(Vec>, Option), + /// Session type: `Offer<(P, ...), Carrier>`. + Offer(Vec>, Option), /// Session type: `Loop<...>`. Loop(Rc>), /// Session type: `Continue`. @@ -64,33 +64,55 @@ impl fmt::Display for Target { } => write!(f, "Split<{}, {}, {}>", s, p, q)?, Call(s, p) => write!(f, "Call<{}, {}>", s, p)?, Then(s, p) => write!(f, "<{} as Then<{}>>::Combined", s, p)?, - Choose(cs) => { + Choose(cs, carrier_type) => { let count = cs.len(); + write!(f, "Choose<(")?; + for (i, c) in cs.iter().enumerate() { write!(f, "{}", c)?; if i + 1 < count { write!(f, ", ")?; } } + if count == 1 { write!(f, ",")?; } - write!(f, ")>")?; + + write!(f, "), ")?; + + match carrier_type { + Some(carrier) => write!(f, "CustomChoice<{}>", carrier.to_token_stream())?, + None => write!(f, "Choice<{}>", count)?, + } + + write!(f, ">")?; } - Offer(cs) => { + Offer(cs, carrier_type) => { let count = cs.len(); + write!(f, "Offer<(")?; + for (i, c) in cs.iter().enumerate() { write!(f, "{}", c)?; if i + 1 < count { write!(f, ", ")?; } } + if count == 1 { write!(f, ",")?; } - write!(f, ")>")?; + + write!(f, "), ")?; + + match carrier_type { + Some(carrier) => write!(f, "CustomChoice<{}>", carrier.to_token_stream())?, + None => write!(f, "Choice<{}>", count)?, + } + + write!(f, ">")?; } Continue(n) => { write!(f, "Continue<{}>", n)?; @@ -154,17 +176,33 @@ impl Spanned { quote_spanned!(span=> <#s as #dialectic_crate::types::Then<#p>>::Combined) .to_tokens(tokens); } - Choose(cs) => { + Choose(cs, carrier_type) => { + let carrier: syn::Type = match carrier_type { + Some(ty) => parse_quote!(#dialectic_crate::backend::CustomChoice<#ty>), + None => { + let n = cs.len(); + parse_quote!(#dialectic_crate::backend::Choice<#n>) + } + }; let cs = cs .iter() .map(|c| c.to_token_stream_with_crate_name(dialectic_crate)); - quote_spanned!(span=> #dialectic_crate::types::Choose<(#(#cs,)*)>).to_tokens(tokens) + quote_spanned!(span=> #dialectic_crate::types::Choose<(#(#cs,)*), #carrier>) + .to_tokens(tokens) } - Offer(cs) => { + Offer(cs, carrier_type) => { + let carrier: syn::Type = match carrier_type { + Some(ty) => parse_quote!(#dialectic_crate::backend::CustomChoice<#ty>), + None => { + let n = cs.len(); + parse_quote!(#dialectic_crate::backend::Choice<#n>) + } + }; let cs = cs .iter() .map(|c| c.to_token_stream_with_crate_name(dialectic_crate)); - quote_spanned!(span=> #dialectic_crate::types::Offer<(#(#cs,)*)>).to_tokens(tokens) + quote_spanned!(span=> #dialectic_crate::types::Offer<(#(#cs,)*), #carrier>) + .to_tokens(tokens) } Continue(n) => { quote_spanned!(span=> #dialectic_crate::types::Continue<#n>).to_tokens(tokens) diff --git a/dialectic-compiler/tests/tally_client.rs b/dialectic-compiler/tests/tally_client.rs index 4d2c9d71..f70c5133 100644 --- a/dialectic-compiler/tests/tally_client.rs +++ b/dialectic-compiler/tests/tally_client.rs @@ -9,14 +9,17 @@ fn tally_client_expr_call_ast() { let client_ast: Spanned = Syntax::Loop( None, Box::new( - Syntax::Choose(vec![ - Syntax::Break(None).into(), - Syntax::Block(vec![ - Syntax::send("Operation").into(), - Syntax::call(Syntax::type_("ClientTally")).into(), - ]) - .into(), - ]) + Syntax::Choose( + vec![ + Syntax::Break(None).into(), + Syntax::Block(vec![ + Syntax::send("Operation").into(), + Syntax::call(Syntax::type_("ClientTally")).into(), + ]) + .into(), + ], + None, + ) .into(), ), ) @@ -25,7 +28,7 @@ fn tally_client_expr_call_ast() { let s = format!("{}", syntax::compile(&client_ast).unwrap()); assert_eq!( s, - "Loop>>)>>" + "Loop>>), Choice<2>>>" ); } @@ -45,7 +48,7 @@ fn tally_client_expr_call_parse_string() { let s = format!("{}", syntax::compile(&ast).unwrap()); assert_eq!( s, - "Loop>>)>>" + "Loop>>), Choice<2>>>" ); } @@ -65,7 +68,7 @@ fn tally_client_invocation_call_parse_string() { let s = format!("{}", syntax::compile(&ast).unwrap()); assert_eq!( s, - "Loop>>)>>" + "Loop>>), Choice<2>>>" ); } @@ -85,7 +88,7 @@ fn tally_client_invocation_direct_subst_parse_string() { let s = format!("{}", syntax::compile(&ast).unwrap()); assert_eq!( s, - "Loop>>::Combined>)>>" + "Loop>>::Combined>), Choice<2>>>" ); } @@ -128,10 +131,10 @@ fn tally_client_direct_subst_nested_loop_break() { ::dialectic::types::Choose<( ::dialectic::types::Send>, ::dialectic::types::Recv>, - )> + ), ::dialectic::backend::Choice<2usize>> > >, - )> + ), ::dialectic::backend::Choice<2usize>> >", ) .unwrap(); diff --git a/dialectic-macro/Cargo.toml b/dialectic-macro/Cargo.toml index b13f8adb..5019c3c2 100644 --- a/dialectic-macro/Cargo.toml +++ b/dialectic-macro/Cargo.toml @@ -15,6 +15,7 @@ proc-macro = true [dependencies] dialectic-compiler = { version = "0.1", path = "../dialectic-compiler" } +vesta-syntax = { version = "0.1" } syn = { version = "1.0", features = ["full", "parsing"] } proc-macro2 = "1.0" quote = "1.0" @@ -24,6 +25,7 @@ dialectic = { version = "0.4", path = "../dialectic" } dialectic-tokio-mpsc = { version = "0.1", path = "../dialectic-tokio-mpsc" } static_assertions = "1.1" tokio = "1.2" +vesta = "0.1" [package.metadata.docs.rs] all-features = true diff --git a/dialectic-macro/src/lib.rs b/dialectic-macro/src/lib.rs index 725fd801..d7d83e2a 100644 --- a/dialectic-macro/src/lib.rs +++ b/dialectic-macro/src/lib.rs @@ -4,9 +4,10 @@ extern crate proc_macro; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned, ToTokens, TokenStreamExt}; use syn::{ - braced, parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, - spanned::Spanned, Arm, Ident, LitInt, Pat, Token, + braced, parse::Parse, parse::ParseStream, parse_macro_input, parse_quote, + punctuated::Punctuated, spanned::Spanned, token::Brace, Ident, LitInt, Token, }; +use vesta_syntax::{CaseArm, CaseInput, CaseOutput}; /** The `Session!` macro compiles a small domain-specific language for describing session types into @@ -85,12 +86,12 @@ and the [`offer!`] macro. # type_eq!( Session! { offer { 0 => {}, 1 => {} } }, - Offer<(Done, Done)> + Offer<(Done, Done), Choice<2>> ); type_eq!( Session! { choose { 0 => {}, 1 => {} } }, - Choose<(Done, Done)> + Choose<(Done, Done), Choice<2>> ); ``` @@ -171,7 +172,7 @@ type R = Session! { } }; -type_eq!(R, Loop, Recv>>>, Done)>>); +type_eq!(R, Loop, Recv>>>, Done), Choice<2>>>); ``` ## The `split` keyword @@ -220,7 +221,7 @@ type Protocol = Session! { type_eq!( Protocol, - Loop>>>>, Done)>>, + Loop>>>>, Done), Choice<2>>>, ); ``` @@ -295,7 +296,7 @@ pub fn Session(input: proc_macro::TokenStream) -> proc_macro::TokenStream { /// }); /// /// // Choose to send an integer -/// c1.choose::<0>().await?.send(42).await?; +/// c1.choose::<0>(()).await?.send(42).await?; /// /// // Wait for the offering thread to finish /// t1.await??; @@ -317,9 +318,11 @@ pub fn offer(input: proc_macro::TokenStream) -> proc_macro::TokenStream { struct OfferInvocation { /// The identifier of the channel to be offered upon. chan: syn::Ident, + /// The brace token for the whole of the offer. + brace_token: Brace, /// The syntactic branches of the invocation: this could be invalid, because [`Arm`] contains /// many variants for its patterns which are not legitimate in our context. - branches: Vec, + cases: Vec, } /// The output information generated by the `offer!` macro, as required to quote it back into a @@ -327,8 +330,7 @@ struct OfferInvocation { struct OfferOutput { /// The identifier of the channel to be offered upon. chan: syn::Ident, - /// The branches, in order, to be emitted for the `match` statement generated by the macro. - branches: Vec, + case_output: CaseOutput, } impl Parse for OfferInvocation { @@ -336,49 +338,32 @@ impl Parse for OfferInvocation { let _ = input.parse::()?; let chan = input.parse::()?; let content; - let _ = braced!(content in input); - let mut branches = Vec::new(); + let brace_token = braced!(content in input); + let mut arms = Vec::new(); while !content.is_empty() { - branches.push(content.call(Arm::parse)?); + arms.push(content.call(CaseArm::parse)?); } - Ok(OfferInvocation { chan, branches }) + Ok(OfferInvocation { + chan, + brace_token, + cases: arms, + }) } } impl ToTokens for OfferOutput { fn to_tokens(&self, tokens: &mut TokenStream) { - let OfferOutput { chan, branches } = self; + let OfferOutput { chan, case_output } = self; // Find the path necessary to refer to types in the dialectic crate. let dialectic_crate = dialectic_compiler::dialectic_path(); - // Create an iterator of token streams, one for each branch of the generated `match` - let arms = branches.iter().enumerate().map(|(choice, body)| { - let byte = choice as u8; - quote! { - #byte => { - let #chan = #dialectic_crate::Branches::case::<#choice>(#chan); - let #chan = match #chan { - Ok(chan) => chan, - Err(_) => unreachable!("malformed generated code from `offer!` macro: mismatch between variant and choice: this is a bug!"), - }; - #body - } - } - }); - // Merge all the branches into one `match` statement that branches precisely once, on the // discriminant of the `Branches` structure tokens.append_all(quote! { { - match #dialectic_crate::Chan::offer( #chan ).await { - Ok(#chan) => Ok(match { - let choice: u8 = std::convert::Into::into(dialectic::Branches::choice(& #chan)); - choice - } { - #(#arms),* - _ => unreachable!("malformed code generated from `offer!` macro: impossible variant encountered: this is a bug!"), - }), + match #dialectic_crate::Chan::offer(#chan).await { + Ok(#chan) => Ok(#case_output), Err(error) => Err(error), } } @@ -389,64 +374,25 @@ impl ToTokens for OfferOutput { impl OfferInvocation { /// Compile an [`OfferInvocation`] to an [`OfferOutput`], or return an error if it was not a /// valid macro invocation. - fn compile(self) -> Result { - // Validate the structure, collecting all errors - let mut errors: Vec = Vec::new(); - for (choice, arm) in self.branches.iter().enumerate() { - if choice > 256 { - let message = format!("at most 256 arms (labeled `0` through `255` in ascending order) are permitted in the `offer!` macro; this arm is number {}", choice); - errors.push(syn::Error::new(arm.span(), message)); - } - match &arm.pat { - Pat::Lit(pat_lit) => { - let lit_int = match &*pat_lit.expr { - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Int(lit_int), - .. - }) => lit_int, - _ => { - let message = "expected a usize literal"; - errors.push(syn::Error::new(pat_lit.expr.span(), message)); - continue; - } - }; - - match lit_int.base10_parse::() { - Ok(n) if n == choice => {} - Ok(_) => { - let message = format!("expected the `usize` literal `{}` for this arm of the `offer!` macro (note: arms must be in ascending order)", choice); - errors.push(syn::Error::new(pat_lit.span(), message)); - } - Err(e) => { - let message = format!("could not parse literal: `{}`", e); - errors.push(syn::Error::new(pat_lit.span(), message)); - } - } - } - _ => { - let message = format!("expected the `usize` literal `{}` for this arm (note: arms must be in ascending order)", choice); - errors.push(syn::Error::new(arm.pat.span(), message)) - } - } + fn compile(mut self) -> Result { + let chan = &self.chan; + for case in self.cases.iter_mut() { + let pat = &case.arm.pat; + let pair = parse_quote!((#chan, #pat)); + case.arm.pat = pair; } - // Either collect the valid expressions, one for each branch, or return the errors - match errors.pop() { - None => Ok(OfferOutput { - chan: self.chan, - branches: self - .branches - .into_iter() - .map(|Arm { body, .. }| *body) - .collect(), - }), - Some(mut error) => { - for e in errors.drain(..) { - error.combine(e) - } - Err(error) - } - } + let case_input = CaseInput { + scrutinee: parse_quote!(#chan), + brace_token: self.brace_token, + arms: self.cases, + }; + + let case_output = case_input.compile()?; + Ok(OfferOutput { + chan: self.chan, + case_output, + }) } } @@ -488,58 +434,84 @@ impl Parse for Mutability { } } +enum TransmitterBound { + Send(Mutability, syn::Type), + Choice(Token![match]), +} + struct TransmitterSpec { name: syn::Type, - types: Punctuated<(Mutability, syn::Type), Token![,]>, + bounds: Punctuated, +} + +enum ReceiverBound { + Recv(syn::Type), + Choice(Token![match]), } struct ReceiverSpec { name: syn::Type, - types: Punctuated, + bounds: Punctuated, } impl Parse for TransmitterSpec { fn parse(input: ParseStream) -> Result { - fn parse_convention_type_pair( - input: ParseStream, - ) -> Result<(Mutability, syn::Type), syn::Error> { - Ok((input.parse()?, input.parse()?)) + fn parse_bound(input: ParseStream) -> Result { + let mutability = input.parse()?; + let match_token = input.parse::>()?; + + match match_token { + Some(token) => Ok(TransmitterBound::Choice(token)), + None => Ok(TransmitterBound::Send(mutability, input.parse()?)), + } } let name: syn::Type = input.parse()?; if input.is_empty() { Ok(TransmitterSpec { name, - types: Punctuated::new(), + bounds: Punctuated::new(), }) } else { let _for: Token![for] = input.parse()?; - let types = if input.is_empty() { + let bounds = if input.is_empty() { Punctuated::new() } else { - input.parse_terminated(parse_convention_type_pair)? + input.parse_terminated(parse_bound)? }; - Ok(TransmitterSpec { name, types }) + Ok(TransmitterSpec { name, bounds }) } } } impl Parse for ReceiverSpec { fn parse(input: ParseStream) -> Result { + fn parse_bound(input: ParseStream) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![match]) || input.is_empty() { + Ok(ReceiverBound::Choice(input.parse()?)) + } else { + Ok(ReceiverBound::Recv(input.parse()?)) + } + } + let name: syn::Type = input.parse()?; if input.is_empty() { Ok(ReceiverSpec { name, - types: Punctuated::new(), + bounds: Punctuated::new(), }) } else { let _for: Token![for] = input.parse()?; let types = if input.is_empty() { Punctuated::new() } else { - input.parse_terminated(syn::Type::parse)? + input.parse_terminated(parse_bound)? }; - Ok(ReceiverSpec { name, types }) + Ok(ReceiverSpec { + name, + bounds: types, + }) } } } @@ -624,7 +596,7 @@ fn where_predicates_mut( /// # fn e() {} /// #[Transmitter(Tx for bool, i64, Vec)] /// # fn f() {} -/// #[Transmitter(Tx for bool, ref i64, ref mut Vec)] +/// #[Transmitter(Tx for match, bool, Option, ref i64, ref mut Vec)] /// # fn g() {} /// ``` /// @@ -637,7 +609,7 @@ fn where_predicates_mut( /// or `mut`), and types `T1`, `T2`, `...`, the invocation: /// /// ```ignore -/// #[Transmitter(Tx for C1? T1, C2? T2, ...)] +/// #[Transmitter(Tx for (match,)? C1? T1, C2? T2, ...)] /// fn f() {} /// ``` /// @@ -645,15 +617,19 @@ fn where_predicates_mut( /// /// ``` /// use dialectic::prelude::*; +/// # use vesta::Match; /// /// # type C1 = Ref; /// # type C2 = Ref; /// # struct T1; +/// # #[derive(Match)] /// # struct T2; /// # /// fn f() /// where /// Tx: Transmitter + Send + 'static, +/// // If `match` is specified, a `TransmitChoice` bound is emitted: +/// Tx: TransmitChoice, /// // For each of the types `T1`, `T2`, ... /// // If the convention is unspecified, `C` is left unspecified; /// // otherwise, we translate into `call_by` conventions using @@ -669,7 +645,10 @@ pub fn Transmitter( params: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let TransmitterSpec { name, types } = parse_macro_input!(params as TransmitterSpec); + let TransmitterSpec { + name, + bounds: types, + } = parse_macro_input!(params as TransmitterSpec); let dialectic_path = dialectic_compiler::dialectic_path(); let mut item = parse_macro_input!(input as syn::Item); if let Some(predicates) = where_predicates_mut(&mut item) { @@ -678,10 +657,15 @@ pub fn Transmitter( + #dialectic_path::backend::Transmitter + 'static }); - for (mutability, ty) in types { - predicates.push(syn::parse_quote! { - #name: #dialectic_path::backend::Transmit<#ty, #mutability> - }); + for bound in types { + match bound { + TransmitterBound::Send(mutability, ty) => predicates.push(syn::parse_quote! { + #name: #dialectic_path::backend::Transmit<#ty, #mutability> + }), + TransmitterBound::Choice(_) => predicates.push(syn::parse_quote! { + #name: #dialectic_path::backend::TransmitChoice + }), + } } item.into_token_stream().into() } else { @@ -734,7 +718,7 @@ pub fn Transmitter( /// # fn a() {} /// #[Receiver(Rx for bool)] /// # fn b() {} -/// #[Receiver(Rx for bool, i64, Vec)] +/// #[Receiver(Rx for match, bool, i64, Vec)] /// # fn c() {} /// ``` /// @@ -746,7 +730,7 @@ pub fn Transmitter( /// For a transmitter type `Rx`, and types `T1`, `T2`, `...`, the invocation: /// /// ```ignore -/// #[Receiver(Rx for T1, T2, ...)] +/// #[Receiver(Rx for (match,)? T1, T2, ...)] /// fn f() {} /// ``` /// @@ -761,6 +745,8 @@ pub fn Transmitter( /// fn f() /// where /// Rx: Receiver + Send + 'static, +/// // If `match` is present in the list, a `ReceiveChoice` bound is generated: +/// Rx: ReceiveChoice, /// // For each of the types `T1`, `T2`, ... /// Rx: Receive, /// Rx: Receive, @@ -774,7 +760,10 @@ pub fn Receiver( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { let dialectic_path = dialectic_compiler::dialectic_path(); - let ReceiverSpec { name, types } = parse_macro_input!(params as ReceiverSpec); + let ReceiverSpec { + name, + bounds: types, + } = parse_macro_input!(params as ReceiverSpec); let mut item = parse_macro_input!(input as syn::Item); if let Some(predicates) = where_predicates_mut(&mut item) { predicates.push(syn::parse_quote! { @@ -782,10 +771,15 @@ pub fn Receiver( + #dialectic_path::backend::Receiver + 'static }); - for ty in types { - predicates.push(syn::parse_quote! { - #name: #dialectic_path::backend::Receive<#ty> - }); + for bound in types { + match bound { + ReceiverBound::Recv(ty) => predicates.push(syn::parse_quote! { + #name: #dialectic_path::backend::Receive<#ty> + }), + ReceiverBound::Choice(_) => predicates.push(syn::parse_quote! { + #name: #dialectic_path::backend::ReceiveChoice + }), + } } item.into_token_stream().into() } else { diff --git a/dialectic-macro/tests/loop_choose_empty.rs b/dialectic-macro/tests/loop_choose_empty.rs index 7b1f6011..492fb36c 100644 --- a/dialectic-macro/tests/loop_choose_empty.rs +++ b/dialectic-macro/tests/loop_choose_empty.rs @@ -10,4 +10,4 @@ type Bug = Session! { } }; -assert_type_eq_all!(Bug, Loop,)>>,); +assert_type_eq_all!(Bug, Loop,), Choice<1>>>,); diff --git a/dialectic-macro/tests/lots_of_jumps.rs b/dialectic-macro/tests/lots_of_jumps.rs index e983dfb0..93142880 100644 --- a/dialectic-macro/tests/lots_of_jumps.rs +++ b/dialectic-macro/tests/lots_of_jumps.rs @@ -29,13 +29,16 @@ assert_type_eq_all!( Loop< Recv< bool, - Offer<( - Done, - Continue<1>, - Recv>, - Continue<0>, - Send>> - )>, + Offer< + ( + Done, + Continue<1>, + Recv>, + Continue<0>, + Send>> + ), + Choice<5>, + >, >, >, >, diff --git a/dialectic-macro/tests/optional_terminators.rs b/dialectic-macro/tests/optional_terminators.rs index 09857e8b..2936297b 100644 --- a/dialectic-macro/tests/optional_terminators.rs +++ b/dialectic-macro/tests/optional_terminators.rs @@ -22,11 +22,14 @@ type BigChoose = Session! { assert_type_eq_all!( BigChoose, - Choose<( - Send<(), Done>, - Send<(), Done>, - Call, - Call, - Split, Recv, Done> - )> + Choose< + ( + Send<(), Done>, + Send<(), Done>, + Call, + Call, + Split, Recv, Done> + ), + Choice<5>, + > ); diff --git a/dialectic-null/src/lib.rs b/dialectic-null/src/lib.rs index 1150048c..83b2336e 100644 --- a/dialectic-null/src/lib.rs +++ b/dialectic-null/src/lib.rs @@ -73,7 +73,9 @@ impl std::error::Error for Error {} impl backend::Transmitter for Sender { type Error = Error; +} +impl backend::TransmitChoice for Sender { fn send_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, _choice: Choice, @@ -121,27 +123,8 @@ impl backend::Transmit<(), Mut> for Sender { } } -impl backend::Transmit> for Sender { - fn send<'a, 'async_lifetime>( - &'async_lifetime mut self, - _message: as By>::Type, - ) -> Pin> + Send + 'async_lifetime>> - where - 'a: 'async_lifetime, - { - Box::pin(async { Ok(()) }) - } -} - impl backend::Receiver for Receiver { type Error = Error; - - fn recv_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - ) -> Pin, Self::Error>> + Send + 'async_lifetime>> - { - Box::pin(async { 0.try_into().map_err(|_| Error { _private: () }) }) - } } impl backend::Receive<()> for Receiver { @@ -152,8 +135,8 @@ impl backend::Receive<()> for Receiver { } } -impl backend::Receive> for Receiver { - fn recv<'async_lifetime>( +impl backend::ReceiveChoice for Receiver { + fn recv_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, ) -> Pin, Self::Error>> + Send + 'async_lifetime>> { diff --git a/dialectic-reconnect/Cargo.toml b/dialectic-reconnect/Cargo.toml index 48e6b6be..27248625 100644 --- a/dialectic-reconnect/Cargo.toml +++ b/dialectic-reconnect/Cargo.toml @@ -28,6 +28,7 @@ dialectic-tokio-serde = { version = "0.1", path = "../dialectic-tokio-serde" } dialectic-tokio-serde-json = { version = "0.1", path = "../dialectic-tokio-serde-json" } anyhow = "1" futures = { version = "0.3", features = ["std"], default-features = false } +vesta = "0.1" [package.metadata.docs.rs] all-features = true diff --git a/dialectic-reconnect/examples/ping.rs b/dialectic-reconnect/examples/ping.rs index d756da8f..2562f05c 100644 --- a/dialectic-reconnect/examples/ping.rs +++ b/dialectic-reconnect/examples/ping.rs @@ -69,12 +69,12 @@ async fn main() -> Result<(), Error> { // How to perform the initial handshake, receiving a session key from the server async fn init(chan: Chan) -> Result { - Ok(chan.choose::<0>().await?.recv().await?.0) + Ok(chan.choose::<0>(()).await?.recv().await?.0) } // How to perform a retry handshake, submitting a session key to the server async fn retry(key: usize, chan: Chan) -> Result<(), Error> { - let chan = chan.choose::<1>().await?; + let chan = chan.choose::<1>(()).await?; chan.send(key).await?.close(); Ok(()) } diff --git a/dialectic-reconnect/src/lib.rs b/dialectic-reconnect/src/lib.rs index 447dbf7d..3b14a852 100644 --- a/dialectic-reconnect/src/lib.rs +++ b/dialectic-reconnect/src/lib.rs @@ -84,12 +84,12 @@ //! //! // How to perform the initial handshake, receiving a session key from the server //! async fn init(chan: Chan) -> Result { -//! Ok(chan.choose::<0>().await?.recv().await?.0) +//! Ok(chan.choose::<0>(()).await?.recv().await?.0) //! } //! //! // How to perform a retry handshake, submitting a session key to the server //! async fn retry(key: usize, chan: Chan) -> Result<(), Error> { -//! let chan = chan.choose::<1>().await?; +//! let chan = chan.choose::<1>(()).await?; //! chan.send(key).await?.close(); //! Ok(()) //! } @@ -146,12 +146,12 @@ //! # //! # // How to perform the initial handshake, receiving a session key from the server //! # async fn init(chan: Chan) -> Result { -//! # Ok(chan.choose::<0>().await?.recv().await?.0) +//! # Ok(chan.choose::<0>(()).await?.recv().await?.0) //! # } //! # //! # // How to perform a retry handshake, submitting a session key to the server //! # async fn retry(key: usize, chan: Chan) -> Result<(), Error> { -//! # let chan = chan.choose::<1>().await?; +//! # let chan = chan.choose::<1>(()).await?; //! # chan.send(key).await?.close(); //! # Ok(()) //! # } diff --git a/dialectic-reconnect/src/resume.rs b/dialectic-reconnect/src/resume.rs index f485e4b9..a4d418e9 100644 --- a/dialectic-reconnect/src/resume.rs +++ b/dialectic-reconnect/src/resume.rs @@ -441,7 +441,15 @@ where Tx::Error: Send + Sync, { type Error = ResumeError; +} +#[Transmitter(Tx for match)] +#[Receiver(Rx)] +impl TransmitChoice for Sender +where + Key: Sync + Send + Eq + Hash, + Tx::Error: Send + Sync, +{ fn send_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, choice: Choice, @@ -501,13 +509,6 @@ where Rx::Error: Send + Sync, { type Error = ResumeError; - - fn recv_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - ) -> Pin, Self::Error>> + Send + 'async_lifetime>> - { - Box::pin(async move { retry_loop!(self.recv_choice::()) }) - } } #[Transmitter(Tx)] @@ -524,3 +525,18 @@ where Box::pin(async move { retry_loop!(self.recv()) }) } } + +#[Transmitter(Tx)] +#[Receiver(Rx for match)] +impl ReceiveChoice for Receiver +where + Key: Send + Sync + Eq + Hash, + Rx::Error: Send + Sync, +{ + fn recv_choice<'async_lifetime, const LENGTH: usize>( + &'async_lifetime mut self, + ) -> Pin, Self::Error>> + Send + 'async_lifetime>> + { + Box::pin(async move { retry_loop!(self.recv_choice::()) }) + } +} diff --git a/dialectic-reconnect/src/retry.rs b/dialectic-reconnect/src/retry.rs index 4203c35c..d81ae494 100644 --- a/dialectic-reconnect/src/retry.rs +++ b/dialectic-reconnect/src/retry.rs @@ -394,13 +394,6 @@ where HandshakeErr: Send, { type Error = RetryError; - - fn send_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - choice: Choice, - ) -> Pin> + Send + 'async_lifetime>> { - Box::pin(async move { retry_loop!(self.tx.send_choice(choice)) }) - } } #[Transmitter(Tx for ref T)] @@ -456,6 +449,26 @@ where } } +#[Transmitter(Tx for match)] +#[Receiver(Rx)] +impl TransmitChoice + for Sender +where + Tx: Sync, + Address: Clone + Sync + Send, + Key: Clone + Sync + Send, + Tx::Error: Sync + Send, + ConnectErr: Send, + HandshakeErr: Send, +{ + fn send_choice<'async_lifetime, const LENGTH: usize>( + &'async_lifetime mut self, + choice: Choice, + ) -> Pin> + Send + 'async_lifetime>> { + Box::pin(async move { retry_loop!(self.tx.send_choice(choice)) }) + } +} + #[Transmitter(Tx)] #[Receiver(Rx)] impl backend::Receiver @@ -469,13 +482,6 @@ where HandshakeErr: Send, { type Error = RetryError; - - fn recv_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - ) -> Pin, Self::Error>> + Send + 'async_lifetime>> - { - Box::pin(async move { retry_loop!(self.rx.recv_choice()) }) - } } #[Transmitter(Tx)] @@ -497,3 +503,23 @@ where Box::pin(async move { retry_loop!(self.rx.recv()) }) } } + +#[Transmitter(Tx)] +#[Receiver(Rx for match)] +impl ReceiveChoice + for Receiver +where + Rx: Sync, + Address: Clone + Sync + Send, + Key: Clone + Sync + Send, + Rx::Error: Sync + Send, + ConnectErr: Send, + HandshakeErr: Send, +{ + fn recv_choice<'async_lifetime, const LENGTH: usize>( + &'async_lifetime mut self, + ) -> Pin, Self::Error>> + Send + 'async_lifetime>> + { + Box::pin(async move { retry_loop!(self.rx.recv_choice()) }) + } +} diff --git a/dialectic-tokio-mpsc/src/lib.rs b/dialectic-tokio-mpsc/src/lib.rs index 050bf9fd..4f4b126b 100644 --- a/dialectic-tokio-mpsc/src/lib.rs +++ b/dialectic-tokio-mpsc/src/lib.rs @@ -143,7 +143,9 @@ pub enum RecvError { impl backend::Transmitter for Sender { type Error = SendError>; +} +impl backend::TransmitChoice for Sender { fn send_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, choice: Choice, @@ -198,7 +200,9 @@ where impl backend::Receiver for Receiver { type Error = RecvError; +} +impl backend::ReceiveChoice for Receiver { fn recv_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, ) -> Pin, Self::Error>> + Send + 'async_lifetime>> @@ -225,7 +229,9 @@ impl backend::Receive for Receiver { impl backend::Transmitter for UnboundedSender { type Error = SendError>; +} +impl backend::TransmitChoice for UnboundedSender { fn send_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, choice: Choice, @@ -280,7 +286,9 @@ where impl backend::Receiver for UnboundedReceiver { type Error = RecvError; +} +impl backend::ReceiveChoice for UnboundedReceiver { fn recv_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, ) -> Pin, Self::Error>> + Send + 'async_lifetime>> diff --git a/dialectic-tokio-serde/src/lib.rs b/dialectic-tokio-serde/src/lib.rs index 25d754fc..9ab3bd20 100644 --- a/dialectic-tokio-serde/src/lib.rs +++ b/dialectic-tokio-serde/src/lib.rs @@ -31,7 +31,10 @@ use std::{future::Future, pin::Pin}; use dialectic::{ - backend::{self, By, Choice, Mut, Receive, Ref, Transmit, Transmittable, Val}, + backend::{ + self, By, Choice, Mut, Receive, ReceiveChoice, Ref, Transmit, TransmitChoice, + Transmittable, Val, + }, Chan, }; use futures::sink::SinkExt; @@ -163,22 +166,21 @@ where W: AsyncWrite + Unpin + Send, { type Error = SendError; +} +impl TransmitChoice for Sender +where + F: Serializer + Unpin + Send, + F::Output: Send, + F::Error: Send, + E: Encoder + Send, + W: AsyncWrite + Unpin + Send, +{ fn send_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, choice: Choice, ) -> Pin> + Send + 'async_lifetime>> { - Box::pin(async move { - let serialized = self - .serializer - .serialize(&choice) - .map_err(SendError::Serialize)?; - self.framed_write - .send(serialized) - .await - .map_err(SendError::Encode)?; - Ok(()) - }) + >>::send(self, choice) } } @@ -296,7 +298,14 @@ where R: AsyncRead + Unpin + Send, { type Error = RecvError; +} +impl ReceiveChoice for Receiver +where + F: Deserializer + Unpin + Send, + D: Decoder + Send, + R: AsyncRead + Unpin + Send, +{ fn recv_choice<'async_lifetime, const LENGTH: usize>( &'async_lifetime mut self, ) -> Pin, Self::Error>> + Send + 'async_lifetime>> diff --git a/dialectic/Cargo.toml b/dialectic/Cargo.toml index 28ec49f7..9289d44c 100644 --- a/dialectic/Cargo.toml +++ b/dialectic/Cargo.toml @@ -13,7 +13,8 @@ readme = "../README.md" [dependencies] thiserror = "1" -call-by = { version = "^0.2.2" } +call-by = { version = "^0.2.3" } +vesta = { version = "0.1" } tokio = { version = "1", optional = true } tokio-util = { version = "0.6", features = ["codec"], optional = true } serde = { version = "1", features = ["derive"], optional = true } diff --git a/dialectic/benches/micro.rs b/dialectic/benches/micro.rs index 09683be9..d61db092 100644 --- a/dialectic/benches/micro.rs +++ b/dialectic/benches/micro.rs @@ -41,7 +41,7 @@ where chan.recv().await.unwrap().1 } -#[Transmitter(Tx)] +#[Transmitter(Tx for match)] async fn choose( chan: Chan {} } } }, Tx, Rx>, ) -> Chan {} } } }, Tx, Rx> @@ -49,10 +49,10 @@ where Rx: Send, Tx::Error: Debug, { - chan.choose::<0>().await.unwrap() + chan.choose::<0>(()).await.unwrap() } -#[Receiver(Rx)] +#[Receiver(Rx for match)] async fn offer( chan: Chan {} } } }, Tx, Rx>, ) -> Chan {} } } }, Tx, Rx> @@ -143,8 +143,8 @@ fn bench_chan_loop_group( }); } -#[Transmitter(Tx for ())] -#[Receiver(Rx for ())] +#[Transmitter(Tx for match, ())] +#[Receiver(Rx for match, ())] fn bench_all_on( c: &mut Criterion, rt_name: &str, diff --git a/dialectic/build.rs b/dialectic/build.rs index 665a8996..c2c30e90 100644 --- a/dialectic/build.rs +++ b/dialectic/build.rs @@ -26,6 +26,7 @@ fn main() -> Result<(), Box> { "use crate::types::{{Send, Recv, Choose, Offer, Call, Split, Loop, Continue, Done}};" )?; writeln!(f, "use crate::Session;")?; + writeln!(f, "use crate::backend::Choice;")?; writeln!(f, "use static_assertions::assert_impl_all;")?; writeln!(f)?; @@ -90,7 +91,7 @@ impl Display for Session { if count == 1 { write!(f, ",")?; } - write!(f, ")>")?; + write!(f, "), Choice<{}>>", count)?; } Offer(cs) => { let count = cs.len(); @@ -104,7 +105,7 @@ impl Display for Session { if count == 1 { write!(f, ",")?; } - write!(f, ")>")?; + write!(f, "), Choice<{}>>", count)?; } Continue(n) => { write!(f, "Continue<{}>", n)?; diff --git a/dialectic/examples/stack.rs b/dialectic/examples/stack.rs index 6b2dac00..0e9c0878 100644 --- a/dialectic/examples/stack.rs +++ b/dialectic/examples/stack.rs @@ -26,7 +26,7 @@ type Client = Session! { }; /// The implementation of the client. -#[Transmitter(Tx for ref str)] +#[Transmitter(Tx for match, ref str)] #[Receiver(Rx for String)] async fn client( mut input: BufReader, @@ -64,7 +64,7 @@ async fn client_prompt( /// reference instead of by value. This function can't be written in `async fn` style because it is /// recursive, and current restrictions in Rust mean that recursive functions returning futures must /// explicitly return a boxed `dyn Future` object. -#[Transmitter(Tx for ref str)] +#[Transmitter(Tx for match, ref str)] #[Receiver(Rx for String)] fn client_rec<'a, Tx, Rx>( size: usize, @@ -83,10 +83,10 @@ where let string = client_prompt(input, output, size).await?; if string.is_empty() { // Break this nested loop (about to go to pop/quit) - break chan.choose::<0>().await?.close(); + break chan.choose::<0>(()).await?.close(); } else { // Push the string to the stack - let chan = chan.choose::<1>().await?.send_ref(&string).await?; + let chan = chan.choose::<1>(()).await?.send_ref(&string).await?; // Recursively do `Client` let chan = chan .call(|chan| client_rec(size + 1, input, output, chan)) @@ -114,7 +114,7 @@ type Server = ::Dual; /// `async fn` style because it is recursive, and current restrictions in Rust mean that recursive /// functions returning futures must explicitly return a boxed `dyn Future` object. #[Transmitter(Tx for ref str)] -#[Receiver(Rx for String)] +#[Receiver(Rx for match, String)] fn server( mut chan: Chan, ) -> Pin>> + Send>> diff --git a/dialectic/examples/tally.rs b/dialectic/examples/tally.rs index c65da7d4..2af7349b 100644 --- a/dialectic/examples/tally.rs +++ b/dialectic/examples/tally.rs @@ -31,7 +31,7 @@ pub type Client = Session! { }; /// The implementation of the client. -#[Transmitter(Tx for Operation, i64)] +#[Transmitter(Tx for match, Operation, i64)] #[Receiver(Rx for i64)] async fn client( mut input: BufReader, @@ -47,7 +47,7 @@ where chan = if let Ok(operation) = prompt("Operation (+ or *): ", &mut input, &mut output, str::parse).await { - let chan = chan.choose::<0>().await?.send(operation).await?; + let chan = chan.choose::<0>(()).await?.send(operation).await?; output .write_all("Enter numbers (press ENTER to tally):\n".as_bytes()) .await?; @@ -57,13 +57,13 @@ where .await?; let chan = chan.unwrap(); if done { - break chan.choose::<1>().await?; + break chan.choose::<1>(()).await?; } else { chan } } else { // End of input, so quit - break chan.choose::<1>().await?; + break chan.choose::<1>(()).await?; } } .close(); @@ -85,7 +85,7 @@ pub type ClientTally = Session! { }; /// The implementation of the client's tally subroutine. -#[Transmitter(Tx for Operation, i64)] +#[Transmitter(Tx for match, Operation, i64)] #[Receiver(Rx for i64)] async fn client_tally( operation: &Operation, @@ -115,10 +115,10 @@ where .await; match user_input { // User wants to add another number to the tally - Ok(Some(n)) => chan = chan.choose::<0>().await?.send(n).await?, + Ok(Some(n)) => chan = chan.choose::<0>(()).await?.send(n).await?, // User wants to finish this tally Ok(None) | Err(_) => { - let (tally, chan) = chan.choose::<1>().await?.recv().await?; + let (tally, chan) = chan.choose::<1>(()).await?.recv().await?; output .write_all(format!("= {}\n", tally).as_bytes()) .await?; @@ -136,7 +136,7 @@ type Server = ::Dual; /// The implementation of the server for each client connection. #[Transmitter(Tx for i64)] -#[Receiver(Rx for Operation, i64)] +#[Receiver(Rx for match, Operation, i64)] async fn server(mut chan: Chan) -> Result<(), Box> where Tx::Error: Error + Send, @@ -161,7 +161,7 @@ type ServerTally = ::Dual; /// The implementation of the server's tally subroutine. #[Transmitter(Tx for i64)] -#[Receiver(Rx for Operation, i64)] +#[Receiver(Rx for match, Operation, i64)] async fn server_tally( op: Operation, mut chan: Chan, diff --git a/dialectic/src/backend.rs b/dialectic/src/backend.rs index 504b3f60..ab5a5aba 100644 --- a/dialectic/src/backend.rs +++ b/dialectic/src/backend.rs @@ -4,7 +4,13 @@ //! receiving channel `Rx`. In order to use a `Chan` to run a session, these underlying channels //! must implement the traits [`Transmitter`] and [`Receiver`], as well as [`Transmit`](Transmit) //! and [`Receive`](Receive) for at least the types `T` used in those capacities in any given -//! session. +//! session. If you want to use [`Choice`] to carry your branching messages, you will also need +//! [`TransmitChoice`] and [`ReceiveChoice`]. These two pairs of traits are what need to be +//! implemented for a transport backend to function with Dialectic; if a custom backend for an +//! existing protocol is being implemented, [`TransmitChoice`] and [`ReceiveChoice`] do not +//! necessarily need to be implemented and can be left out as necessary/desired. Branching with +//! `offer` and `choose` can still be achieved with any type that implements the Vesta `Match` +//! trait, as long as that type can be transmitted or received by the backend in question. //! //! Functions which are generic over their backend will in turn need to specify the bounds //! [`Transmit`](Transmit) and [`Receive`](Receive) for all `T`s they send and receive, @@ -14,15 +20,21 @@ #[doc(no_inline)] pub use call_by::{By, Convention, Mut, Ref, Val}; -use std::{future::Future, pin::Pin}; +use std::{ + convert::{TryFrom, TryInto}, + future::Future, + pin::Pin, +}; +pub use vesta::{Case, Match}; mod choice; pub use choice::*; /// A backend transport used for transmitting (i.e. the `Tx` parameter of [`Chan`](crate::Chan)) /// must implement [`Transmitter`], which specifies what type of errors it might return, as well as -/// giving a method to send [`Choice`]s across the channel. This is a super-trait of [`Transmit`], -/// which is what's actually needed to receive particular values over a [`Chan`](crate::Chan). +/// giving a method to send [`Choice`]s across the channel. This is a super-trait of [`Transmit`] +/// and [`TransmitChoice`], which are what's actually needed to receive particular values over a +/// [`Chan`](crate::Chan). /// /// If you're writing a function and need a lot of different [`Transmit`](Transmit) bounds, the /// [`Transmitter`](macro@crate::Transmitter) attribute macro can help you specify them more @@ -30,12 +42,6 @@ pub use choice::*; pub trait Transmitter { /// The type of possible errors when sending. type Error; - - /// Send any `Choice` using the [`Convention`] specified by the trait implementation. - fn send_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - choice: Choice, - ) -> Pin> + Send + 'async_lifetime>>; } /// A marker trait indicating that some type `T` is transmittable as the associated type @@ -112,22 +118,120 @@ where 'a: 'async_lifetime; } +/// A trait describing how we can transmit a [`Choice`](Choice) for any `N: usize`. If your +/// backend should be able to transmit any [`Choice`](Choice), this is the trait to implement; +/// you'll get a blanket impl from it for `TransmitCase>`. Implementing this trait will +/// not allow your backend to [`Transmit`] a `Choice`, but there is generally no reason to do so. +pub trait TransmitChoice: Transmitter { + /// Send a [`Choice`](Choice) over the backend, for some `N`. + fn send_choice<'async_lifetime, const LENGTH: usize>( + &'async_lifetime mut self, + choice: Choice, + ) -> Pin> + Send + 'async_lifetime>>; +} + +/// If a transport is [`TransmitCase`](TransmitCase), we can use it to +/// [`send_case`](TransmitCase::send_case) a message of type `T` by [`Val`], [`Ref`], or [`Mut`], +/// depending on the calling convention specified by `C`. [`TransmitCase`] is highly similar to +/// [`Transmit`], with a major difference: [`TransmitCase`] may be used to send only *part* of an +/// `enum` datatype, as part of a [`Chan::choose`](crate::Chan::choose) call. This is because the +/// discriminant is actually the constant `N: usize` parameter to the [`TransmitCase::send_case`] +/// method. This matching/construction/deconstruction is done through the +/// [`vesta`](https://docs.rs/vesta) crate and its `Match` and `Case` traits; implementation of +/// these traits does not need to be done by hand as Vesta provides a derive macro for them. +/// Similarly, [`TransmitCase`] never needs to be implemented by hand; blanket implementations are +/// provided to allow branching transmissions of `Choice` for all `N` (when a backend supports +/// `TransmitChoice`) and `T: Match` (when a backend supports [`Transmit`](Transmit).) +pub trait TransmitCase: + Transmitter + sealed::TransmitCase +where + T: Transmittable + Match, +{ + /// Send a "case" of a [`Match`]-able type. + fn send_case<'a, 'async_lifetime>( + &'async_lifetime mut self, + message: <>::Case as By<'a, C>>::Type, + ) -> Pin> + Send + 'async_lifetime>> + where + T: Case, + >::Case: By<'a, C>, + 'a: 'async_lifetime; +} + +/// This is a wrapper type which disambiguates, at the type level, "custom" choice types, and makes +/// sure Rust sees them as a different type from `Choice`. `CustomChoice` is an implementation +/// detail which is generated by the `Session!` macro when outputting [`Choose`](crate::Choose) and +/// [`Offer`](crate::Offer) types. If you see it in an error message, it means you wrote a `choose` +/// or `offer` statement with a custom carrier type, which is held as the type parameter to +/// `CustomChoice`. +#[derive(Debug)] +pub struct CustomChoice(pub T); + +unsafe impl Match for CustomChoice { + type Range = T::Range; + + fn tag(&self) -> Option { + self.0.tag() + } +} + +impl, const N: usize> Case for CustomChoice { + type Case = T::Case; + + unsafe fn case(this: Self) -> Self::Case { + T::case(this.0) + } + + fn uncase(case: Self::Case) -> Self { + CustomChoice(T::uncase(case)) + } +} + +impl + TransmitCase, Val, N> for Tx +{ + fn send_case<'a, 'async_lifetime>( + &'async_lifetime mut self, + _message: < as Case>::Case as By<'a, Val>>::Type, + ) -> Pin> + Send + 'async_lifetime>> + where + Choice: Case, + as Case>::Case: By<'a, Val>, + 'a: 'async_lifetime, + { + // FIXME(sleffy): no unwrap + self.send_choice::( + TryFrom::::try_from(N.try_into().unwrap()).expect("N < LENGTH by trait properties"), + ) + } +} + +impl, T: Match + Transmittable, const N: usize> + TransmitCase, Val, N> for Tx +{ + fn send_case<'a, 'async_lifetime>( + &'async_lifetime mut self, + message: < as Case>::Case as By<'a, Val>>::Type, + ) -> Pin> + Send + 'async_lifetime>> + where + CustomChoice: Case, + as Case>::Case: By<'a, Val>, + 'a: 'async_lifetime, + { + self.send( as Case>::uncase(call_by::to_val(message)).0) + } +} + /// A backend transport used for receiving (i.e. the `Rx` parameter of [`Chan`](crate::Chan)) must -/// implement [`Receiver`], which specifies what type of errors it might return, as well as giving a -/// method to send [`Choice`]s across the channel. This is a super-trait of [`Receive`], which is -/// what's actually needed to receive particular values over a [`Chan`](crate::Chan). +/// implement [`Receiver`], which specifies what type of errors it might return. This is a +/// super-trait of [`Receive`] and [`ReceiveChoice`], which are the traits that are actually needed +/// to receive particular values over a [`Chan`](crate::Chan). /// /// If you're writing a function and need a lot of different [`Receive`](Receive) bounds, the /// [`Receiver`](macro@crate::Receiver) attribute macro can help you specify them more succinctly. pub trait Receiver { /// The type of possible errors when receiving. type Error; - - /// Receive any `Choice`. It is impossible to construct a `Choice<0>`, so if `N = 0`, a - /// [`Receiver::Error`] must be returned. - fn recv_choice<'async_lifetime, const LENGTH: usize>( - &'async_lifetime mut self, - ) -> Pin, Self::Error>> + Send + 'async_lifetime>>; } /// If a transport is [`Receive`](Receive), we can use it to [`recv`](Receive::recv) a message of @@ -152,3 +256,78 @@ pub trait Receive: Receiver { &'async_lifetime mut self, ) -> Pin> + Send + 'async_lifetime>>; } + +/// A trait describing how we can receive a [`Choice`](Choice) for any `N: usize`. If your +/// backend should be able to receive any [`Choice`](Choice), this is the trait to implement; +/// you'll get a blanket impl from it for `ReceiveCase>`. Implementing this trait will +/// not allow your backend to [`Receive`] a `Choice`, but there is generally no reason to do so. +pub trait ReceiveChoice: Receiver { + /// Receive any `Choice`. It is impossible to construct a `Choice<0>`, so if `N = 0`, a + /// [`Receiver::Error`] must be returned. + fn recv_choice<'async_lifetime, const LENGTH: usize>( + &'async_lifetime mut self, + ) -> Pin, Self::Error>> + Send + 'async_lifetime>>; +} + +/// A trait describing how a receiver can receive some `T` which can be matched on and has one or +/// more associated "cases". [`ReceiveCase`] is highly similar to +/// [`Receive`], with a major difference: [`ReceiveCase`] may be used to send only *part* of an +/// `enum` datatype, as part of a [`Chan::choose`](crate::Chan::choose) call. This is because the +/// discriminant is actually the constant `N: usize` parameter to the [`ReceiveCase::recv_case`] +/// method. This matching/construction/deconstruction is done through the +/// [`vesta`](https://docs.rs/vesta) crate and its `Match` and `Case` traits; implementation of +/// these traits does not need to be done by hand as Vesta provides a derive macro for them. +/// Similarly, [`ReceiveCase`] never needs to be implemented by hand; blanket implementations are +/// provided to allow branching transmissions of `Choice` for all `N` (when a backend supports +/// `ReceiveChoice`) and `T: Match` (when a backend supports [`Receive`](Receive).) +pub trait ReceiveCase: Receiver +where + T: Match, +{ + /// Receive a case of some type which implements [`Match`]. This may require type annotations + /// for disambiguation. + fn recv_case<'async_lifetime>( + &'async_lifetime mut self, + ) -> Pin> + Send + 'async_lifetime>>; +} + +impl ReceiveCase> for Rx +where + Rx: ReceiveChoice, +{ + fn recv_case<'async_lifetime>( + &'async_lifetime mut self, + ) -> Pin, Self::Error>> + Send + 'async_lifetime>> + { + self.recv_choice::() + } +} + +impl ReceiveCase> for Rx +where + Rx: Receive + Send, +{ + fn recv_case<'async_lifetime>( + &'async_lifetime mut self, + ) -> Pin, Self::Error>> + Send + 'async_lifetime>> + { + Box::pin(async move { + let t = self.recv().await?; + Ok(CustomChoice(t)) + }) + } +} + +mod sealed { + use super::*; + + pub trait TransmitCase {} + + impl, T: Transmittable> TransmitCase, Val> for Tx {} + impl TransmitCase, Val> for Tx {} + + pub trait ReceiveCase {} + + impl, T> ReceiveCase> for Rx {} + impl ReceiveCase> for Rx {} +} diff --git a/dialectic/src/backend/choice.rs b/dialectic/src/backend/choice.rs index aa687e1e..351863e1 100644 --- a/dialectic/src/backend/choice.rs +++ b/dialectic/src/backend/choice.rs @@ -1,6 +1,18 @@ use crate::unary::*; use std::convert::{TryFrom, TryInto}; use thiserror::Error; +use vesta::{Case, Exhaustive, Match}; + +/// A trait mapping a `Number` to a `Choice`, so that wrapped/calculated const generic +/// parameters can be transformed to `Choice`s without needing an extra parameter. +pub trait ToChoice: Constant { + /// The resulting `Choice`. + type AsChoice; +} + +impl ToChoice for Number { + type AsChoice = Choice; +} /// A `Choice` represents a selection between several protocols offered by [`offer!`](crate::offer). /// @@ -118,6 +130,29 @@ impl From> for u8 { } } +unsafe impl Match for Choice { + type Range = Exhaustive; + + fn tag(&self) -> Option { + Some(self.choice as usize) + } +} + +impl Case for Choice +where + Number: ToUnary, + Number: ToUnary, + UnaryOf: LessThan>, +{ + type Case = (); + + unsafe fn case(_: Self) -> Self::Case {} + + fn uncase(_: ()) -> Self { + Choice { choice: M as u8 } + } +} + // If the serde feature is enabled, do custom serialization for `Choice` that fails when receiving // an out-of-bounds choice. #[cfg(feature = "serde")] diff --git a/dialectic/src/chan.rs b/dialectic/src/chan.rs index ea7f78cd..5d6fb55d 100644 --- a/dialectic/src/chan.rs +++ b/dialectic/src/chan.rs @@ -4,15 +4,15 @@ use futures::Future; use pin_project::pin_project; use std::{ any::TypeId, - convert::{TryFrom, TryInto}, marker::{self, PhantomData}, mem, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, }; +use vesta::{Case, Exhaustive, Match}; -use crate::tuple::{HasLength, List, Tuple}; +use crate::tuple::{HasLength, Tuple}; use crate::Unavailable; use crate::{backend::*, IncompleteHalf, SessionIncomplete}; use crate::{prelude::*, types::*, unary::*}; @@ -291,18 +291,17 @@ where } } -impl Chan +impl Chan where - S: Session>, + S: Session>, Choices: Tuple, Choices::AsList: HasLength, - ::Length: ToConstant>, Tx: Transmitter + marker::Send + 'static, Rx: marker::Send + 'static, { /// Actively choose to enter the `N`th protocol offered via [`offer!`](crate::offer) by the - /// other end of the connection, alerting the other party to this choice by sending the number - /// `N` over the channel. + /// other end of the connection, alerting the other party to this choice by sending + /// corresponding `N`th case of the "carrier type" ([`Choice`] by default) over the channel. /// /// The choice `N` is specified as a `const` generic `usize`. /// @@ -338,7 +337,7 @@ where /// }); /// /// // Choose to send an integer - /// c1.choose::<0>().await?.send(42).await?; + /// c1.choose::<0>(()).await?.send(42).await?; /// /// // Wait for the offering thread to finish /// t1.await??; @@ -358,7 +357,7 @@ where /// let (c1, c2) = OnlyTwoChoices::channel(|| mpsc::channel(1)); /// /// // Try to choose something out of range (this doesn't typecheck) - /// c1.choose::<2>().await?; + /// c1.choose::<2>(()).await?; /// /// # // Wait for the offering thread to finish /// # t1.await??; @@ -367,38 +366,77 @@ where /// ``` pub async fn choose( mut self, + choice: >::Case, ) -> Result< Chan< as ToUnary>::AsUnary>>::Selected, Tx, Rx>, Tx::Error, > where + Carrier: Case, Number: ToUnary, Choices::AsList: Select< as ToUnary>::AsUnary>, as ToUnary>::AsUnary>>::Selected: Session, + Tx: TransmitCase, { - let choice: Choice = u8::try_from(N) - .expect("choices must fit into a byte") - .try_into() - .expect("type system prevents out of range choice in `choose`"); - self.tx.as_mut().unwrap().send_choice(choice).await?; + self.tx.as_mut().unwrap().send_case(choice).await?; Ok(self.unchecked_cast()) } + + // /// Identical to [`Chan::choose`], but allows you to send the carrier's case value by reference. + // /// Useful for custom carrier types. + // pub async fn choose_ref( + // mut self, + // choice: &>::Case, + // ) -> Result< + // Chan< as ToUnary>::AsUnary>>::Selected, Tx, Rx>, + // Tx::Error, + // > + // where + // Carrier: Case, + // Number: ToUnary, + // Choices::AsList: Select< as ToUnary>::AsUnary>, + // as ToUnary>::AsUnary>>::Selected: Session, + // Tx: TransmitCase, + // { + // self.tx.as_mut().unwrap().send_case::(choice).await?; + // Ok(self.unchecked_cast()) + // } + + // /// Identical to [`Chan::choose`], but allows you to send the carrier's case value by mutable + // /// reference. Useful for custom carrier types. + // pub async fn choose_mut( + // mut self, + // choice: &mut >::Case, + // ) -> Result< + // Chan< as ToUnary>::AsUnary>>::Selected, Tx, Rx>, + // Tx::Error, + // > + // where + // Carrier: Case, + // Number: ToUnary, + // Choices::AsList: Select< as ToUnary>::AsUnary>, + // as ToUnary>::AsUnary>>::Selected: Session, + // Tx: TransmitCase, + // { + // self.tx.as_mut().unwrap().send_case::(choice).await?; + // Ok(self.unchecked_cast()) + // } } -impl Chan +impl Chan where - S: Session>, + S: Session>, Choices: Tuple + 'static, Choices::AsList: HasLength + EachScoped + EachHasDual, - ::Length: ToConstant>, Z: LessThan<::Length>, Tx: marker::Send + 'static, - Rx: Receiver + marker::Send + 'static, + Rx: Receiver + ReceiveCase + marker::Send + 'static, + Carrier: Match, { /// Offer the choice of one or more protocols to the other party, and wait for them to indicate /// which protocol they'd like to proceed with. Returns a [`Branches`] structure representing - /// all the possible channel types which could be returned, which must be eliminated using - /// [`case`](Branches::case). + /// all the possible channel types which could be returned, which must be eliminated through the + /// [`Match`] and [`Case`] traits. /// ///💡 **Where possible, prefer the [`offer!`](crate::offer) macro**. This has the benefit of /// ensuring at compile time that no case is left unhandled; it's also more succinct. @@ -413,6 +451,7 @@ where /// ``` /// use dialectic::prelude::*; /// use dialectic_tokio_mpsc as mpsc; + /// use vesta::{Match, CaseExt, case}; /// /// # #[tokio::main] /// # async fn main() -> Result<(), Box> { @@ -427,18 +466,20 @@ where /// /// // Spawn a thread to offer a choice /// let t1 = tokio::spawn(async move { - /// match c2.offer().await?.case::<0>() { - /// Ok(c2) => { c2.recv().await?; }, - /// Err(rest) => match rest.case::<0>() { - /// Ok(c2) => { c2.send("Hello!".to_string()).await?; }, - /// Err(rest) => rest.empty_case(), + /// let branches = c2.offer().await?; + /// case!(branches { + /// 0(c2, ()) => { + /// c2.recv().await?; /// } - /// } + /// 1(c2, ()) => { + /// c2.send("Hello!".to_string()).await?; + /// } + /// }); /// Ok::<_, mpsc::Error>(()) /// }); /// /// // Choose to send an integer - /// c1.choose::<0>().await?.send(42).await?; + /// c1.choose::<0>(()).await?.send(42).await?; /// /// // Wait for the offering thread to finish /// t1.await??; @@ -474,16 +515,16 @@ where /// # }); /// # /// # // Choose to send an integer - /// # c1.choose::<0>().await?.send(42).await?; + /// # c1.choose::<0>(()).await?.send(42).await?; /// # /// # // Wait for the offering thread to finish /// # t1.await??; /// # Ok(()) /// # } /// ``` - pub async fn offer(self) -> Result, Rx::Error> { + pub async fn offer(self) -> Result, Rx::Error> { let (tx, mut rx, drop_tx, drop_rx) = self.unwrap_contents(); - let variant = rx.as_mut().unwrap().recv_choice::().await?.into(); + let variant = rx.as_mut().unwrap().recv_case().await?.into(); Ok(Branches { variant, tx, @@ -886,14 +927,14 @@ where #[derive(Derivative)] #[derivative(Debug)] #[must_use] -pub struct Branches +pub struct Branches where Tx: marker::Send + 'static, Rx: marker::Send + 'static, Choices: Tuple + 'static, Choices::AsList: EachScoped + EachHasDual + HasLength, { - variant: u8, + variant: Option, tx: Option, rx: Option, drop_tx: Arc>>>, @@ -901,7 +942,7 @@ where protocols: PhantomData Choices>, } -impl Drop for Branches +impl Drop for Branches where Tx: marker::Send + 'static, Rx: marker::Send + 'static, @@ -918,74 +959,72 @@ where } } -impl Branches +unsafe impl Match + for Branches where Choices: Tuple + 'static, Choices::AsList: EachScoped + EachHasDual + HasLength, ::Length: ToConstant>, Tx: marker::Send + 'static, Rx: marker::Send + 'static, + Carrier: Match>, { - /// Check if the selected protocol in this [`Branches`] was the `N`th protocol in its type. If - /// so, return the corresponding channel; otherwise, return all the other possibilities. - pub fn case( - mut self, - ) -> Result< + type Range = Exhaustive; + + fn tag(&self) -> Option { + self.variant.as_ref().unwrap().tag() + } +} + +impl Case + for Branches +where + Number: ToUnary, + Choices: Tuple + 'static, + Choices::AsList: EachScoped + EachHasDual + HasLength + Select< as ToUnary>::AsUnary>, + as ToUnary>::AsUnary>>::Selected: Session, + ::Length: ToConstant>, + Tx: marker::Send + 'static, + Rx: marker::Send + 'static, + Carrier: Match> + Case, +{ + type Case = ( Chan< as ToUnary>::AsUnary>>::Selected, Tx, Rx>, - Branches<< as ToUnary>::AsUnary>>::Remainder as List>::AsTuple, Tx, Rx>, - > - where - Number: ToUnary, - Choices::AsList: Select< as ToUnary>::AsUnary>, - as ToUnary>::AsUnary>>::Selected: Session, - as ToUnary>::AsUnary>>::Remainder: EachScoped + EachHasDual + HasLength + List, - { - let variant = self.variant; - let tx = self.tx.take(); - let rx = self.rx.take(); - let drop_tx = self.drop_tx.clone(); - let drop_rx = self.drop_rx.clone(); - let branch: u8 = N - .try_into() - .expect("branch discriminant exceeded u8::MAX in `case`"); - if variant == branch { - Ok(Chan { - tx, - rx, - drop_tx, - drop_rx, - session: PhantomData, - }) - } else { - Err(Branches { - // Subtract 1 from variant if we've eliminated a branch with a lower discriminant - variant: if variant > branch { - variant - 1 - } else { - variant - }, - tx, - rx, - drop_tx, - drop_rx, - protocols: PhantomData, - }) - } + Carrier::Case, + ); + + unsafe fn case(mut this: Self) -> Self::Case { + // FIXME(sleffy): unwrap necessary? + let carrier_case = Case::::case(this.variant.take().unwrap()); + let tx = this.tx.take(); + let rx = this.rx.take(); + let drop_tx = this.drop_tx.clone(); + let drop_rx = this.drop_rx.clone(); + let chan = Chan { + tx, + rx, + drop_tx, + drop_rx, + session: PhantomData, + }; + + (chan, carrier_case) } - /// Determine the [`Choice`] which was made by the other party, indicating which of these - /// [`Branches`] should be taken. - /// - /// Ordinarily, you should prefer the [`offer!`](crate::offer) macro in situations where you - /// need to know this value. - pub fn choice(&self) -> Choice { - self.variant - .try_into() - .expect("internal variant for `Branches` exceeds number of choices") + fn uncase((chan, carrier_case): Self::Case) -> Self { + let (tx, rx, drop_tx, drop_rx) = chan.unwrap_contents(); + Branches { + variant: Some(Carrier::uncase(carrier_case)), + tx, + rx, + drop_tx, + drop_rx, + protocols: PhantomData, + } } } -impl<'a, Tx, Rx> Branches<(), Tx, Rx> +impl<'a, Tx, Rx, Carrier> Branches<(), Carrier, Tx, Rx> where Tx: marker::Send + 'static, Rx: marker::Send + 'static, diff --git a/dialectic/src/lib.rs b/dialectic/src/lib.rs index 93311c34..114d8bf4 100644 --- a/dialectic/src/lib.rs +++ b/dialectic/src/lib.rs @@ -173,7 +173,10 @@ pub(crate) mod dialectic { /// all the bits and pieces you need to start writing programs with Dialectic. pub mod prelude { #[doc(no_inline)] - pub use crate::backend::{Choice, Receive, Receiver, Transmit, Transmitter}; + pub use crate::backend::{ + Choice, Receive, ReceiveCase, ReceiveChoice, Receiver, Transmit, TransmitCase, + TransmitChoice, Transmitter, + }; #[doc(no_inline)] pub use crate::session::Session; #[doc(no_inline)] @@ -182,4 +185,6 @@ pub mod prelude { pub use call_by::{Mut, Ref, Val}; #[doc(no_inline)] pub use dialectic_macro::{offer, Receiver, Session, Transmitter}; + #[doc(no_inline)] + pub use vesta::Match; } diff --git a/dialectic/src/tutorial.rs b/dialectic/src/tutorial.rs index 60a3fc5b..bc2cd5c3 100644 --- a/dialectic/src/tutorial.rs +++ b/dialectic/src/tutorial.rs @@ -338,7 +338,7 @@ let t1 = tokio::spawn(async move { // Make a choice let t2 = tokio::spawn(async move { - let c2 = c2.choose::<1>().await?; // select to `Send` + let c2 = c2.choose::<1>(()).await?; // select to `Send` c2.send("Hi there!".to_string()).await?; // enact the selected choice Ok::<_, mpsc::Error>(()) }); @@ -475,11 +475,11 @@ tokio::spawn(async move { // Send some numbers to be summed for n in 0..=10 { - c1 = c1.choose::<0>().await?.send(n).await?; + c1 = c1.choose::<0>(()).await?.send(n).await?; } // Get the sum -let (sum, c1) = c1.choose::<1>().await?.recv().await?; +let (sum, c1) = c1.choose::<1>(()).await?.recv().await?; c1.close(); assert_eq!(sum, 55); # Ok(()) @@ -600,12 +600,12 @@ async fn query_all( let mut answers = Vec::with_capacity(questions.len()); for question in questions.into_iter() { let (answer, c) = - chan.choose::<1>().await? + chan.choose::<1>(()).await? .call(|c| query(question, c)).await?; // Call `query` as a subroutine chan = c.unwrap(); answers.push(answer); } - chan.choose::<0>().await?.close(); + chan.choose::<0>(()).await?.close(); Ok(answers) } ``` @@ -808,6 +808,80 @@ macros, see their documentation. Additionally, the code in the is written to be backend-agnostic, and uses these attributes. They may prove an additional resource if you get stuck. +# Custom choice "carrier" types + +Under the hood, Dialectic implements sending/receiving choices from the `offer` and `choose` +constructs using a special type called `Choice`. Usually, this entails sending/receiving a +single byte; however, if you're implementing an existing protocol, you may want to branch on +something else and sending/receiving a single byte may not make any sense at all. This is where the +Vesta crate and its `Match` derive trait and macro come in; the [`offer!`] macro is implemented +under the hood using Vesta's `case!` macro, and accepts the same syntax for pattern matching. *This +is most useful when you're writing your own backend and need **message-for-message parity with +an existing protocol.*** + +As a *very* contrived example, we could rewrite the `QuerySum` example from the branching example +using `Option` to carry our decisions: + +``` +use dialectic::prelude::*; +use dialectic_tokio_mpsc as mpsc; + +type QuerySum = Session! { + loop { + choose Option { + 0 => { // None + recv i64; + break; + } + 1 => continue, // Some(i64), i64 is implicitly sent (and received on the other side) + } + } +}; + +# #[tokio::main] +# async fn main() -> Result<(), Box> { +# +let (mut c1, mut c2) = QuerySum::channel(|| mpsc::channel(1)); + +// Sum all the numbers sent over the channel +tokio::spawn(async move { + let mut sum = 0i64; + let c2 = loop { + c2 = offer!(in c2 { + 0 => break c2, + 1(n) => { + sum += n; + c2 + }, + })?; + }; + c2.send(sum).await?.close(); + Ok::<_, mpsc::Error>(()) +}); + +// Send some numbers to be summed +for n in 0..=10 { + c1 = c1.choose::<1>(n).await?; +} + +// Get the sum +let (sum, c1) = c1.choose::<0>(()).await?.recv().await?; +c1.close(); +assert_eq!(sum, 55); +# Ok(()) +# } +``` + +In general, if you are not looking for message-for-message parity with an existing protocol, there +is no reason to use custom carrier types. They give no other extra functionality over the default +`Choice`, which is more ergonomic in terms of writing bounds and also shows all of the sent data +as part of the session type when written out; none of it is implicit, like the `i64` value which is +implicitly sometimes sent by the `choose` in the above example. To show just how identical they +are, you can actually write out `Session! { choose Choice<1> { 0 => {}, } }` and it will work fine. +*But*, if you are writing a protocol implementation which is generic over its backend types, you +would have to write `Choice<1>` as an explicit bound on the transmitter type, rather than using the +universally quantified `match` keyword in the `Transmitter` macro. + # Wrapping up We've now finished our tour of everything you need to get started programming with Dialectic! ✨ diff --git a/dialectic/src/types.rs b/dialectic/src/types.rs index 7518f085..2aa59bb7 100644 --- a/dialectic/src/types.rs +++ b/dialectic/src/types.rs @@ -50,9 +50,10 @@ pub use split::*; /// ``` /// # use static_assertions::assert_type_eq_all; /// use dialectic::types::*; +/// use dialectic::backend::Choice; /// -/// type Client = Loop, Done>, Recv, Done>, Recv>)>>; -/// type Server = Loop, Call, Done>, Done>, Send>)>>; +/// type Client = Loop, Done>, Recv, Done>, Recv>), Choice<3>>>; +/// type Server = Loop, Call, Done>, Done>, Send>), Choice<3>>>; /// /// assert_type_eq_all!(Client, ::DualSession); /// ``` @@ -315,15 +316,21 @@ mod tests { fn complex_session_zero_size() { type P = Loop< Loop< - Choose<( - Send>, - Recv>, - Offer<( - Send>, - Continue<1>, - Split>, Recv>, Done>, - )>, - )>, + Choose< + ( + Send>, + Recv>, + Offer< + ( + Send>, + Continue<1>, + Split>, Recv>, Done>, + ), + Choice<3>, + >, + ), + Choice<3>, + >, >, >; assert_eq!(std::mem::size_of::

(), 0); diff --git a/dialectic/src/types/choose.rs b/dialectic/src/types/choose.rs index 0536ac70..ac33f3c3 100644 --- a/dialectic/src/types/choose.rs +++ b/dialectic/src/types/choose.rs @@ -11,57 +11,64 @@ use crate::tuple::{List, Tuple}; /// At most 128 choices can be presented to a `Choose` type; to choose from more options, nest /// `Choose`s within each other. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Choose(PhantomData Choices>); +pub struct Choose(PhantomData (Choices, Carrier)>); -impl Default for Choose { +impl Default for Choose { fn default() -> Self { Choose(PhantomData) } } -impl IsSession for Choose {} +impl IsSession for Choose {} -impl HasDual for Choose +impl HasDual for Choose where - Choices: Tuple, + Choices: Any + Tuple, Choices::AsList: EachHasDual, ::Duals: List + EachHasDual, + Carrier: 'static, { - type DualSession = Offer<<::Duals as List>::AsTuple>; + type DualSession = Offer<<::Duals as List>::AsTuple, Carrier>; } -impl Actionable for Choose { +impl Actionable for Choose { type NextAction = Self; } -impl Scoped for Choose where +impl Scoped for Choose where Choices::AsList: EachScoped { } -impl Subst for Choose +impl Subst for Choose where Choices: Tuple + 'static, Choices::AsList: EachSubst, >::Substituted: List, + Carrier: 'static, { - type Substituted = Choose<<>::Substituted as List>::AsTuple>; + type Substituted = + Choose<<>::Substituted as List>::AsTuple, Carrier>; } -impl Then for Choose +impl Then for Choose where Choices: Tuple + 'static, Choices::AsList: EachThen, >::Combined: List, + Carrier: 'static, { - type Combined = Choose<<>::Combined as List>::AsTuple>; + type Combined = + Choose<<>::Combined as List>::AsTuple, Carrier>; } -impl Lift for Choose +impl Lift for Choose where Choices: Tuple + 'static, Choices::AsList: EachLift, >::Lifted: List, + Carrier: 'static, { - type Lifted = Choose<<>::Lifted as List>::AsTuple>; + type Lifted = + Choose<<>::Lifted as List>::AsTuple, Carrier>; } diff --git a/dialectic/src/types/offer.rs b/dialectic/src/types/offer.rs index 5f75000b..8670bc5b 100644 --- a/dialectic/src/types/offer.rs +++ b/dialectic/src/types/offer.rs @@ -10,57 +10,64 @@ use crate::tuple::{List, Tuple}; /// At most 128 choices can be offered in a single `Offer` type; to supply more options, nest /// `Offer`s within each other. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Offer(PhantomData Choices>); +pub struct Offer(PhantomData (Choices, Carrier)>); -impl Default for Offer { +impl Default for Offer { fn default() -> Self { Offer(PhantomData) } } -impl IsSession for Offer {} +impl IsSession for Offer {} -impl HasDual for Offer +impl HasDual for Offer where - Choices: Tuple, + Choices: Any + Tuple, Choices::AsList: EachHasDual, ::Duals: List + EachHasDual, + Carrier: 'static, { - type DualSession = Choose<<::Duals as List>::AsTuple>; + type DualSession = Choose<<::Duals as List>::AsTuple, Carrier>; } -impl Actionable for Offer { +impl Actionable for Offer { type NextAction = Self; } -impl Scoped for Offer where +impl Scoped for Offer where Choices::AsList: EachScoped { } -impl Subst for Offer +impl Subst for Offer where Choices: Tuple + 'static, Choices::AsList: EachSubst, >::Substituted: List, + Carrier: 'static, { - type Substituted = Offer<<>::Substituted as List>::AsTuple>; + type Substituted = + Offer<<>::Substituted as List>::AsTuple, Carrier>; } -impl Then for Offer +impl Then for Offer where Choices: Tuple + 'static, Choices::AsList: EachThen, >::Combined: List, + Carrier: 'static, { - type Combined = Offer<<>::Combined as List>::AsTuple>; + type Combined = + Offer<<>::Combined as List>::AsTuple, Carrier>; } -impl Lift for Offer +impl Lift for Offer where Choices: Tuple + 'static, Choices::AsList: EachLift, >::Lifted: List, + Carrier: 'static, { - type Lifted = Offer<<>::Lifted as List>::AsTuple>; + type Lifted = + Offer<<>::Lifted as List>::AsTuple, Carrier>; }