68 changes: 5 additions & 63 deletions crates/tx-macros/src/serde.rs
@@ -155,20 +155,18 @@ impl<'a> SerdeGenerator<'a> {
         let serde = &self.serde;
         let serde_str = serde.to_string();

-        let (legacy_variant, legacy_arm, legacy_deserialize) = self.generate_legacy_handling();
+        let (legacy_variant, legacy_arm) = self.generate_legacy_handling();
         let untagged_variants = self.generate_untagged_variants(&legacy_variant);
         let untagged_conversions = self.generate_untagged_conversions(&legacy_arm);
-        let deserialize_impl = self.generate_untagged_deserialize(&legacy_deserialize);

         quote! {
-            #[derive(#serde::Serialize)]
+            #[derive(#serde::Serialize, #serde::Deserialize)]
             #[serde(untagged, bound = #serde_bounds_str, crate = #serde_str)]
             pub(crate) enum UntaggedTxTypes #generics {
                 Tagged(TaggedTxTypes #generics),
                 #untagged_variants
             }

-            #deserialize_impl
             #untagged_conversions
         }
     }
@@ -195,72 +193,16 @@ impl<'a> SerdeGenerator<'a> {
     }

     /// Generate legacy transaction handling for serde.
-    fn generate_legacy_handling(&self) -> (TokenStream, TokenStream, TokenStream) {
+    fn generate_legacy_handling(&self) -> (TokenStream, TokenStream) {
         if let Some(legacy) = self.variants.legacy_variant() {
             let ty = &legacy.ty;
             let name = &legacy.name;
-            let alloy_consensus = self.alloy_consensus;
-
             let variant = quote! { UntaggedLegacy(#ty) };
             let arm = quote! { UntaggedTxTypes::UntaggedLegacy(tx) => Self::#name(tx), };
-            let deserialize = quote! {
-                if let Ok(val) = #alloy_consensus::transaction::untagged_legacy_serde::deserialize(deserializer).map(Self::UntaggedLegacy) {
-                    return Ok(val);
-                }
-            };

-            (variant, arm, deserialize)
+            (variant, arm)
         } else {
-            (quote! {}, quote! {}, quote! {})
-        }
-    }
-
-    /// Generate custom deserialize implementation for untagged types.
-    fn generate_untagged_deserialize(&self, legacy_deserialize: &TokenStream) -> TokenStream {
-        let generics = self.generics;
-        let unwrapped_generics = &generics.params;
-        let serde = &self.serde;
-        let serde_bounds = self.generate_serde_bounds();
-
-        let flattened_names = self.variants.flattened.iter().map(|v| &v.name);
-
-        quote! {
-            // Manually modified derived serde(untagged) to preserve the error of the TaggedTxEnvelope
-            // attempt. Note: This uses private serde API
-            impl<'de, #unwrapped_generics> #serde::Deserialize<'de> for UntaggedTxTypes #generics where #serde_bounds {
-                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-                where
-                    D: #serde::Deserializer<'de>,
-                {
-                    let content = #serde::de::DeserializeSeed::deserialize(
-                        #serde::__private225::de::ContentVisitor::new(),
-                        deserializer
-                    )?;
-                    let deserializer =
-                        #serde::__private225::de::ContentRefDeserializer::<D::Error>::new(&content);
-
-                    let tagged_res =
-                        TaggedTxTypes::<#unwrapped_generics>::deserialize(deserializer).map(Self::Tagged);
-
-                    if tagged_res.is_ok() {
-                        // return tagged if successful
-                        return tagged_res;
-                    }
-
-                    // proceed with flattened variants
-                    #(
-                        if let Ok(val) = #serde::Deserialize::deserialize(deserializer).map(Self::#flattened_names) {
-                            return Ok(val);
-                        }
-                    )*
-
-                    #legacy_deserialize
-
-                    // return the original error, which is more useful than the untagged error
-                    // > "data did not match any variant of untagged enum MaybeTaggedTxEnvelope"
-                    tagged_res
-                }
-            }
+            (quote! {}, quote! {})
         }
     }

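Note on the change, with a rough sketch of the resulting behavior. The macro previously emitted a hand-written Deserialize impl for UntaggedTxTypes that buffered the input via serde's private Content API, tried the tagged enum first, then each flattened variant, then the legacy fallback, and finally returned the tagged attempt's error. After this change it relies on #[derive(Deserialize)] with #[serde(untagged)], so serde's derived untagged machinery does the buffering and variant-by-variant retry, at the cost of reporting the generic "data did not match any variant of untagged enum ..." error instead of the preserved tagged error. The code below is not the macro's actual output: TaggedTxTypes, LegacyTx, the tag values, and the fields are simplified stand-ins (assuming serde and serde_json as dependencies) chosen only to show how a derived untagged wrapper resolves tagged versus legacy payloads.

use serde::{Deserialize, Serialize};

// Simplified stand-in for the generated tagged enum: internally tagged on "type".
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum TaggedTxTypes {
    #[serde(rename = "0x1")]
    Eip2930 { gas_limit: u64 },
    #[serde(rename = "0x2")]
    Eip1559 { gas_limit: u64 },
}

// Simplified stand-in for a legacy transaction, which carries no "type" field on the wire.
#[derive(Debug, Serialize, Deserialize)]
struct LegacyTx {
    gas_limit: u64,
}

// Shape of the wrapper the macro now derives instead of hand-rolling Deserialize:
// serde tries the variants in order against the buffered input.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum UntaggedTxTypes {
    Tagged(TaggedTxTypes),
    UntaggedLegacy(LegacyTx),
}

fn main() {
    // A payload carrying a "type" tag is matched by the tagged enum first.
    let tagged: UntaggedTxTypes =
        serde_json::from_str(r#"{"type":"0x2","gas_limit":21000}"#).unwrap();
    assert!(matches!(tagged, UntaggedTxTypes::Tagged(TaggedTxTypes::Eip1559 { .. })));

    // A payload without a "type" tag fails the tagged attempt and falls through
    // to the untagged legacy variant.
    let legacy: UntaggedTxTypes = serde_json::from_str(r#"{"gas_limit":21000}"#).unwrap();
    assert!(matches!(legacy, UntaggedTxTypes::UntaggedLegacy(_)));
}

The trade-off named in the removed comment still applies: with the plain derive, a payload that matches no variant surfaces serde's generic untagged error rather than the more specific error from the tagged deserialization attempt.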