diff --git a/Cargo.lock b/Cargo.lock index b6575dd..009c39b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,13 +108,23 @@ checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" [[package]] name = "bincode" -version = "1.3.3" +version = "2.0.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +checksum = "f11ea1a0346b94ef188834a65c068a03aec181c94896d481d7a0a40d85b0ce95" dependencies = [ + "bincode_derive", "serde", ] +[[package]] +name = "bincode_derive" +version = "2.0.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e30759b3b99a1b802a7a3aa21c85c3ded5c28e1c83170d82d70f08bbf7f3e4c" +dependencies = [ + "virtue", +] + [[package]] name = "bindgen" version = "0.65.1" @@ -1068,6 +1078,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "virtue" +version = "0.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dcc60c0624df774c82a0ef104151231d37da4962957d691c011c852b2473314" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 5c6fd2e..ee03c21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ exclude = ["**/tests/**", "**/examples/**", "**/benchmarks/**", "docs/**", ".hoo [dependencies] actix-rt = "2.8.0" base64 = "0.20.0" -bincode = "1.3.3" +bincode = { version = "2.0.0-rc.3", features = ["serde"] } bytes = "1.4.0" colored = { version = "2.1.0", optional = true } hex = "0.4.3" diff --git a/src/crypto.rs b/src/crypto.rs index 38b51d0..25e2bd6 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,9 +1,50 @@ pub use ring; -use std::convert::TryInto; use tracing::warn; +macro_rules! fixed_bytes_wrapper { + ($vis:vis struct $name:ident, $n:expr, $doc:literal) => { + #[doc = $doc] + #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] + $vis struct $name(crate::utils::serialize_utils::FixedByteArray<$n>); + + impl $name { + pub fn from_slice(slice: &[u8]) -> Option { + Some(Self(slice.try_into().ok()?)) + } + } + + impl AsRef<[u8]> for $name { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } + } + + impl std::str::FromStr for $name { + type Err = as std::str::FromStr>::Err; + + fn from_str(s: &str) -> Result { + as std::str::FromStr>::from_str(s).map(Self) + } + } + + #[cfg(test)] + impl crate::utils::PlaceholderSeed for $name { + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self { + Self(crate::utils::placeholder_bytes( + [ concat!(stringify!($name), ":").as_bytes() ].iter().copied().chain(seed_parts) + ).into()) + } + } + }; +} + pub mod sign_ed25519 { - use super::deserialize_slice; pub use ring::signature::Ed25519KeyPair as SecretKeyBase; use ring::signature::KeyPair; pub use ring::signature::Signature as SignatureBase; @@ -11,9 +52,7 @@ pub mod sign_ed25519 { pub use ring::signature::{ED25519, ED25519_PUBLIC_KEY_LEN}; use serde::{Deserialize, Serialize}; use std::convert::TryInto; - use tracing::warn; - - pub type PublicKeyBase = ::PublicKey; + use crate::crypto::generate_random; // Constants copied from the ring library const SCALAR_LEN: usize = 32; @@ -21,57 +60,61 @@ pub mod sign_ed25519 { const SIGNATURE_LEN: usize = ELEM_LEN + SCALAR_LEN; pub const ED25519_SIGNATURE_LEN: usize = SIGNATURE_LEN; - /// Signature data + pub const ED25519_SEED_LEN: usize = 32; + + fixed_bytes_wrapper!(pub struct Signature, ED25519_SIGNATURE_LEN, "Signature data"); + fixed_bytes_wrapper!(pub struct PublicKey, ED25519_PUBLIC_KEY_LEN, "Public key data"); + + /// PKCS8 encoded secret key pair /// We used sodiumoxide serialization before (treated it as slice with 64 bit length prefix). - #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct Signature( - #[serde(serialize_with = "<[_]>::serialize")] - #[serde(deserialize_with = "deserialize_slice")] - [u8; ED25519_SIGNATURE_LEN], - ); + /// Slice and vector are serialized the same. + #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] + pub struct SecretKey(#[serde(with = "crate::utils::serialize_utils::vec_codec")] Vec); - impl Signature { + impl SecretKey { + /// Constructs a `SecretKey` from the given PKCS8 document. + /// + /// ### Arguments + /// + /// * `slice` - a slice containing the encoded PKCS8 document pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.try_into().ok()?)) + match SecretKeyBase::from_pkcs8(slice) { + Ok(_) => Some(Self(slice.to_vec())), + Err(_) => None, + } } - } - impl AsRef<[u8]> for Signature { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() + /// Gets the public key corresponding to this secret key. + pub fn get_public_key(&self) -> PublicKey { + let keypair = SecretKeyBase::from_pkcs8(&self.0) + .expect("SecretKey contains invalid PKCS8 document?!?"); + + PublicKey::from_slice(keypair.public_key().as_ref()) + .expect("Keypair public key length is invalid?!?") } } - /// Public key data - /// We used sodiumoxide serialization before (treated it as slice with 64 bit length prefix). - #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct PublicKey( - #[serde(serialize_with = "<[_]>::serialize")] - #[serde(deserialize_with = "deserialize_slice")] - [u8; ED25519_PUBLIC_KEY_LEN], - ); - - impl PublicKey { - pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.try_into().ok()?)) + impl From for SecretKey { + fn from(value: ring::pkcs8::Document) -> Self { + Self(value.as_ref().to_vec()) } } - impl AsRef<[u8]> for PublicKey { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() + #[cfg(test)] + impl crate::utils::PlaceholderSeed for SecretKey { + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self { + gen_keypair_from_seed(&crate::utils::placeholder_bytes( + [ "SecretKey:".as_bytes() ].iter().copied().chain(seed_parts) + )).1 } } - /// PKCS8 encoded secret key pair - /// We used sodiumoxide serialization before (treated it as slice with 64 bit length prefix). - /// Slice and vector are serialized the same. - #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct SecretKey(Vec); - - impl SecretKey { - pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.to_vec())) + #[cfg(test)] + impl crate::utils::PlaceholderSeed for (PublicKey, SecretKey) { + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self { + gen_keypair_from_seed(&crate::utils::placeholder_bytes( + [ "(PublicKey, SecretKey):".as_bytes() ].iter().copied().chain(seed_parts) + )) } } @@ -87,89 +130,47 @@ pub mod sign_ed25519 { } pub fn sign_detached(msg: &[u8], sk: &SecretKey) -> Signature { - let secret = match SecretKeyBase::from_pkcs8(sk.as_ref()) { - Ok(secret) => secret, - Err(_) => { - warn!("Invalid secret key"); - return Signature([0; ED25519_SIGNATURE_LEN]); - } - }; + let keypair = SecretKeyBase::from_pkcs8(sk.as_ref()) + .expect("Invalid PKCS8 secret key?!?"); - let signature = match secret.sign(msg).as_ref().try_into() { - Ok(signature) => signature, - Err(_) => { - warn!("Invalid signature"); - return Signature([0; ED25519_SIGNATURE_LEN]); - } - }; + let signature = keypair.sign(msg).as_ref().try_into() + .expect("Invalid signature?!?"); Signature(signature) } - pub fn verify_append(sm: &[u8], pk: &PublicKey) -> bool { - if sm.len() > ED25519_SIGNATURE_LEN { - let start = sm.len() - ED25519_SIGNATURE_LEN; - let sig = Signature(match sm[start..].try_into() { - Ok(sig) => sig, - Err(_) => { - warn!("Invalid signature"); - return false; - } - }); - let msg = &sm[..start]; - verify_detached(&sig, msg, pk) - } else { - false - } - } - - pub fn sign_append(msg: &[u8], sk: &SecretKey) -> Vec { - let sig = sign_detached(msg, sk); - let mut sm = msg.to_vec(); - sm.extend_from_slice(sig.as_ref()); - sm + /// Generates a completely random Ed25519 keypair. + pub fn gen_keypair() -> (PublicKey, SecretKey) { + let seed = generate_random(); + gen_keypair_from_seed(&seed) } - pub fn gen_keypair() -> (PublicKey, SecretKey) { - let rand = ring::rand::SystemRandom::new(); - let pkcs8 = match SecretKeyBase::generate_pkcs8(&rand) { - Ok(pkcs8) => pkcs8, - Err(_) => { - warn!("Failed to generate secret key base for pkcs8"); - return (PublicKey([0; ED25519_PUBLIC_KEY_LEN]), SecretKey(vec![])); - } + /// Generates an Ed25519 keypair based on the given seed. + /// + /// ### Arguments + /// + /// * `seed` - the seed to generate the keypair from + fn gen_keypair_from_seed(seed: &[u8; ED25519_SEED_LEN]) -> (PublicKey, SecretKey) { + let rand = ring::test::rand::FixedSliceSequenceRandom { + bytes: &[ seed ], + current: core::cell::UnsafeCell::new(0), }; - let secret = match SecretKeyBase::from_pkcs8(pkcs8.as_ref()) { - Ok(secret) => secret, - Err(_) => { - warn!("Invalid secret key base"); - return (PublicKey([0; ED25519_PUBLIC_KEY_LEN]), SecretKey(vec![])); - } - }; + let pkcs8 = SecretKeyBase::generate_pkcs8(&rand) + .expect("Failed to generate secret key base for pkcs8"); - let pub_key_gen = match secret.public_key().as_ref().try_into() { - Ok(pub_key_gen) => pub_key_gen, - Err(_) => { - warn!("Invalid public key generation"); - return (PublicKey([0; ED25519_PUBLIC_KEY_LEN]), SecretKey(vec![])); - } - }; - let public = PublicKey(pub_key_gen); - let secret = match SecretKey::from_slice(pkcs8.as_ref()) { - Some(secret) => secret, - None => { - warn!("Invalid secret key"); - return (PublicKey([0; ED25519_PUBLIC_KEY_LEN]), SecretKey(vec![])); - } - }; + let keypair = SecretKeyBase::from_pkcs8(pkcs8.as_ref()) + .expect("Generated PKCS8 document is invalid?!?"); - (public, secret) + let public_key = PublicKey(keypair.public_key().as_ref().try_into() + .expect("Generated keypair contains an invalid public key?!?")); + let secret_key = pkcs8.into(); + (public_key, secret_key) } } pub mod secretbox_chacha20_poly1305 { // Use key and nonce separately like rust-tls does - use super::{deserialize_slice, generate_random}; + use super::generate_random; pub use ring::aead::LessSafeKey as KeyBase; pub use ring::aead::Nonce as NonceBase; pub use ring::aead::NONCE_LEN; @@ -179,45 +180,8 @@ pub mod secretbox_chacha20_poly1305 { pub const KEY_LEN: usize = 256 / 8; - /// key data - #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct Key( - #[serde(serialize_with = "<[_]>::serialize")] - #[serde(deserialize_with = "deserialize_slice")] - [u8; KEY_LEN], - ); - - impl Key { - pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.try_into().ok()?)) - } - } - - impl AsRef<[u8]> for Key { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } - } - - /// Nonce data - #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct Nonce( - #[serde(serialize_with = "<[_]>::serialize")] - #[serde(deserialize_with = "deserialize_slice")] - [u8; NONCE_LEN], - ); - - impl Nonce { - pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.try_into().ok()?)) - } - } - - impl AsRef<[u8]> for Nonce { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } - } + fixed_bytes_wrapper!(pub struct Key, KEY_LEN, "Key data"); + fixed_bytes_wrapper!(pub struct Nonce, NONCE_LEN, "Nonce data"); pub fn seal(mut plain_text: Vec, nonce: &Nonce, key: &Key) -> Option> { let key = get_keybase(key)?; @@ -249,20 +213,20 @@ pub mod secretbox_chacha20_poly1305 { } fn get_noncebase(nonce: &Nonce) -> NonceBase { - NonceBase::assume_unique_for_key(nonce.0) + NonceBase::assume_unique_for_key(*nonce.0) } pub fn gen_key() -> Key { - Key(generate_random()) + Key(generate_random().into()) } pub fn gen_nonce() -> Nonce { - Nonce(generate_random()) + Nonce(generate_random().into()) } } pub mod pbkdf2 { - use super::{deserialize_slice, generate_random}; + use super::generate_random; use ring::pbkdf2::{derive, PBKDF2_HMAC_SHA256}; use serde::{Deserialize, Serialize}; use std::convert::TryInto; @@ -272,24 +236,7 @@ pub mod pbkdf2 { pub const SALT_LEN: usize = 256 / 8; pub const OPSLIMIT_INTERACTIVE: u32 = 100_000; - #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)] - pub struct Salt( - #[serde(serialize_with = "<[_]>::serialize")] - #[serde(deserialize_with = "deserialize_slice")] - [u8; SALT_LEN], - ); - - impl Salt { - pub fn from_slice(slice: &[u8]) -> Option { - Some(Self(slice.try_into().ok()?)) - } - } - - impl AsRef<[u8]> for Salt { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } - } + fixed_bytes_wrapper!(pub struct Salt, SALT_LEN, "Salt data"); pub fn derive_key(key: &mut [u8], passwd: &[u8], salt: &Salt, iterations: u32) { let iterations = match NonZeroU32::new(iterations) { @@ -303,36 +250,78 @@ pub mod pbkdf2 { } pub fn gen_salt() -> Salt { - Salt(generate_random()) + Salt(generate_random().into()) } } pub mod sha3_256 { + use std::convert::TryInto; + use std::fmt::{Display, Formatter}; + use std::ops::Deref; + use std::str::FromStr; + pub use sha3::digest::Output; pub use sha3::Digest; pub use sha3::Sha3_256; - pub fn digest(data: &[u8]) -> Output { - Sha3_256::digest(data) + pub const HASH_LEN : usize = 256 / 8; + + /// A SHA3-256 hash. + #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] + pub struct Hash( + [u8; HASH_LEN], + ); + + impl Hash { + pub fn from_slice(slice: &[u8]) -> Option { + Some(Self(slice.try_into().ok()?)) + } } - pub fn digest_all<'a>(data: impl Iterator) -> Output { + impl Display for Hash { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.write_str(&hex::encode(&self.0)) + } + } + + impl FromStr for Hash { + type Err = hex::FromHexError; + + fn from_str(s: &str) -> Result { + let mut buf = [0u8; HASH_LEN]; + match hex::decode_to_slice(s, &mut buf) { + Ok(_) => Ok(Self(buf)), + Err(e) => Err(e), + } + } + } + + impl AsRef<[u8]> for Hash { + fn as_ref(&self) -> &[u8] { + &self.0 + } + } + + impl Deref for Hash { + type Target = [u8; HASH_LEN]; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + pub fn digest(data: &[u8]) -> Hash { + Hash(Sha3_256::digest(data).try_into().unwrap()) + } + + pub fn digest_all<'a>(data: impl Iterator) -> Hash { let mut hasher = Sha3_256::new(); data.for_each(|v| hasher.update(v)); - hasher.finalize() + Hash(hasher.finalize().try_into().unwrap()) } } -fn deserialize_slice<'de, D: serde::Deserializer<'de>, const N: usize>( - deserializer: D, -) -> Result<[u8; N], D::Error> { - let value: &[u8] = serde::Deserialize::deserialize(deserializer)?; - value - .try_into() - .map_err(|_| serde::de::Error::custom("Invalid array in deserialization".to_string())) -} - -pub fn generate_random() -> [u8; N] { +fn generate_random() -> [u8; N] { let mut value: [u8; N] = [0; N]; use ring::rand::SecureRandom; @@ -344,3 +333,82 @@ pub fn generate_random() -> [u8; N] { value } + +#[cfg(test)] +mod test { + use std::fmt::{Debug, Display}; + use std::str::FromStr; + use serde::Serialize; + use serde::de::DeserializeOwned; + use crate::utils::PlaceholderSeed; + use super::*; + + fn test_placeholders_different_seed() { + let [v0, v1] = W::placeholder_array_seed::<2>([]); + assert_eq!(v0, v0); + assert_eq!(v1, v1); + assert_ne!(v0, v1); + } + + fn test_fixed_bytes_wrapper + PlaceholderSeed + Serialize + DeserializeOwned>( + expected_placeholder_hex: &str, + ) { + let placeholder = W::placeholder_seed([]); + assert_eq!(placeholder.to_string(), expected_placeholder_hex); + assert_eq!(W::from_str(expected_placeholder_hex).unwrap(), placeholder); + + let expected_json = format!("\"{}\"", expected_placeholder_hex); + assert_eq!(serde_json::to_string(&placeholder).unwrap(), expected_json); + assert_eq!(serde_json::from_str::(&expected_json).unwrap(), placeholder); + + test_placeholders_different_seed::(); + } + + #[test] + fn test_ed25519_signature() { + test_fixed_bytes_wrapper::( + "9c4e2259fc9b47b4c4cf672c7436dc16ace2970955a002b69a495ca96d9dfaf026dbee622284a1cf306a1189af8a462d2ea498d10f14b637c848168b0ba698a7", + ); + } + + #[test] + fn test_ed25519_public_key() { + test_fixed_bytes_wrapper::( + "1d67f7de4c59192568f8e0381fcd1eb9ce044568e4670038e0b42e421540b4f6", + ); + } + + #[test] + fn test_ed25519_secret_key() { + assert_eq!(hex::encode(sign_ed25519::SecretKey::placeholder_seed([])), + "3053020101300506032b6570042204203651dccde39be8697d8e0690acd90e3b8ce7f596c5f205fbd0b3b3e3a68629e1a12303210092bc778f74110b3fcbcf8a4df71ed9a33c62faa8d01417d381745ef700ef6b73"); + + test_placeholders_different_seed::(); + } + + #[test] + fn test_ed25519_keypair() { + test_placeholders_different_seed::<(sign_ed25519::PublicKey, sign_ed25519::SecretKey)>(); + } + + #[test] + fn test_chacha20_key() { + test_fixed_bytes_wrapper::( + "d2678ac6abff79fa16d2a8f762a3c33b227a519ac1830aee33a19605b7f9cd35", + ); + } + + #[test] + fn test_chacha20_nonce() { + test_fixed_bytes_wrapper::( + "b0980a0a073d6b828fb48ad5", + ); + } + + #[test] + fn test_pbkdf2_salt() { + test_fixed_bytes_wrapper::( + "81c4a8cde605d6b51857eb6ebaead0de98cf254d4855725db7aec45a98699e9c", + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 99760c1..3404532 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,6 @@ +#[macro_use] +mod macros; + pub mod constants; pub mod crypto; pub mod primitives; diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 0000000..99f75b3 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,212 @@ +/// Generate an enum to be used as an error type. +/// +/// This will automatically generate the enum, but also implement `Display` with given format +/// strings and optionally indicate the source error for errors which wrap another error. +/// +/// Example usage: +/// ```ignore +/// make_error_type!(pub enum MyError { +/// Unknown; "Unknown error", +/// IncorrectLength(length: usize); "Incorrect length {length}, expected 123", +/// InvalidHexData(source: hex::FromHexError); "Cannot convert hex string: {source}"; source, +/// }); +/// ``` +macro_rules! make_error_type { + (@fmt_source) => { None }; + (@fmt_source $sourcen:expr) => { Some($sourcen) }; + + ( + $( #[$attr:meta] )* + $vis:vis enum $name:ident { + $( $( + #[$tattr:meta] )* + $tname:ident + $(( $( $( #[$t_tuple_arg_attr:meta] )* $t_tuple_arg_name:ident : $t_tuple_arg_ty:ty),+ $(,)? ))? + $({ $( $( #[$t_struct_arg_attr:meta] )* $t_struct_arg_name:ident : $t_struct_arg_ty:ty),+ $(,)? })? + ; $tmsg:literal $(( $($tmsgarg:expr),* ))? + $( ; $sourcen:expr )? + ),+ $(,)? + } + ) => { + $( #[$attr] )* + #[derive(::std::fmt::Debug)] + $vis enum $name { + $( + $( #[$tattr] )* + $tname + $(( $( $( #[$t_tuple_arg_attr] )* $t_tuple_arg_ty),+ ))? + $({ $( $( #[$t_struct_arg_attr] )* $t_struct_arg_name : $t_struct_arg_ty),+ })? + ),+ + } + + impl std::error::Error for $name { + #[allow(unused_variables)] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + $( + Self::$tname + $(( $($t_tuple_arg_name),+ ))? + $({ $($t_struct_arg_name),+ })? + => + make_error_type!(@fmt_source $($sourcen)?) + ),+ + } + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, _f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + $( + Self::$tname + $(( $($t_tuple_arg_name),+ ))? + $({ $($t_struct_arg_name),+ })? + => + write!(_f, $tmsg $($(, $tmsgarg)*)? ) + ),+ + } + } + } + }; +} + +/// Generate an enum where each variant represents a unique named value. +/// +/// This will automatically generate the enum, but also implement the following traits: +/// * `ToName` and `FromName` with the variant names +/// * `Display` with the variant names +/// +/// Additionally, a constant array containing every enum variant will be generated with the +/// given name and visibility. +/// +/// Example usage: +/// ```ignore +/// make_trivial_enum!(pub enum MyError { +/// Hello, +/// World, +/// } +/// all_variants=pub(crate) ALL_VARIANTS); +/// ``` +macro_rules! make_trivial_enum { + ( + @impl_toname_fromname_display + $ename:ident { + $( $vname:ident ),+ + } + $all_variants_vis:vis $all_variants:ident + ) => { + impl $ename { + #[allow(unused)] + $all_variants_vis const $all_variants : &'static [Self] = + &[ $( Self::$vname ),+ ]; + } + + impl crate::utils::ToName for $ename { + fn to_name(&self) -> &'static str { + match self { + $( Self::$vname => stringify!($vname) ),+ + } + } + } + + impl crate::utils::FromName for $ename { + const ALL_NAMES: &'static [&'static str] = &[ $( stringify!($vname) ),+ ]; + + fn from_name(name: &str) -> Result { + match name { + $( stringify!($vname) => Ok(Self::$vname), )+ + _ => Err(name), + } + } + } + + impl std::fmt::Display for $ename { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(crate::utils::ToName::to_name(self)) + } + } + }; + + ( + $( #[$eattr:meta] )* + $evis:vis enum $ename:ident { + $( + $( #[$vattr:meta] )* + $vname:ident, + )+ + } + all_variants=$all_variants_vis:vis $all_variants:ident + ) => { + $( #[$eattr] )* + $evis enum $ename { + $( $( #[$vattr] )* $vname ),+ + } + + make_trivial_enum!( + @impl_toname_fromname_display $ename { + $( $vname ),+ + } + $all_variants_vis $all_variants); + }; +} + +/// Generate an enum where each variant simply wraps a numeric ordinal number. +/// +/// This will automatically generate the enum, but also implement the following traits: +/// * `ToOrdinal` and `FromOrdinal` with the given ordinal numbers +/// * `ToName` and `FromName` with the variant names +/// * `Display` with the variant names +/// +/// Additionally, a constant array containing every enum variant will be generated with the +/// given name and visibility. +/// +/// Example usage: +/// ```ignore +/// make_ordinal_enum!(pub enum MyError { +/// Hello = 2, +/// World = 3, +/// } +/// all_variants=pub(crate) ALL_VARIANTS); +/// ``` +macro_rules! make_ordinal_enum { + ( + $( #[$eattr:meta] )* + $evis:vis enum $ename:ident { + $( + $( #[$vattr:meta] )* + $vname:ident = $vord:literal, + )+ + } + all_variants=$all_variants_vis:vis $all_variants:ident + ) => { + $( #[$eattr] )* + $evis enum $ename { + $( $( #[$vattr] )* $vname = $vord, )+ + } + + make_trivial_enum!( + @impl_toname_fromname_display $ename { + $( $vname ),+ + } + $all_variants_vis $all_variants); + + impl crate::utils::ToOrdinal for $ename { + fn to_ordinal(&self) -> u32 { + match self { + $( Self::$vname => $vord ),+ + } + } + } + + impl crate::utils::FromOrdinal for $ename { + const ALL_ORDINALS: &'static [u32] = &[ $( $vord ),+ ]; + + fn from_ordinal(ordinal: u32) -> Result { + match ordinal { + $( $vord => Ok(Self::$vname), )+ + _ => Err(ordinal), + } + } + } + }; +} diff --git a/src/primitives/block.rs b/src/primitives/block.rs index 72e07e9..18cb611 100644 --- a/src/primitives/block.rs +++ b/src/primitives/block.rs @@ -4,7 +4,6 @@ use crate::crypto::sha3_256::{self, Sha3_256}; use crate::crypto::sign_ed25519::PublicKey; use crate::primitives::asset::Asset; use crate::primitives::transaction::{Transaction, TxIn, TxOut}; -use bincode::{deserialize, serialize}; use bytes::Bytes; use serde::{Deserialize, Serialize}; use std::convert::TryInto; @@ -82,25 +81,25 @@ impl Block { /// Sets the internal number of bits based on length pub fn set_bits(&mut self) { - let bytes = Bytes::from(match serialize(&self) { + let bytes = match bincode::serde::encode_to_vec(&self, bincode::config::legacy()) { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize block: {:?}", e); return; } - }); + }; self.header.bits = bytes.len(); } /// Checks whether a block has hit its maximum size pub fn is_full(&self) -> bool { - let bytes = Bytes::from(match serialize(&self) { + let bytes = match bincode::serde::encode_to_vec(&self, bincode::config::legacy()) { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize block: {:?}", e); return false; } - }); + }; bytes.len() >= MAX_BLOCK_SIZE } @@ -143,7 +142,8 @@ pub fn gen_random_hash() -> String { /// /// * `transactions` - Transactions to construct a merkle tree for pub fn build_hex_txs_hash(transactions: &[String]) -> String { - let txs = match serialize(transactions) { + // TODO: This is bad, it won't produce the same result on 32-bit systems + let txs = match bincode::serde::encode_to_vec(transactions, bincode::config::legacy()) { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize transactions: {:?}", e); diff --git a/src/primitives/transaction.rs b/src/primitives/transaction.rs index 08435fd..ab8ac2b 100644 --- a/src/primitives/transaction.rs +++ b/src/primitives/transaction.rs @@ -1,4 +1,6 @@ #![allow(unused)] + +use std::convert::TryInto; use crate::constants::*; use crate::crypto::sign_ed25519::{PublicKey, Signature}; use crate::primitives::{ @@ -8,10 +10,11 @@ use crate::primitives::{ use crate::script::lang::Script; use crate::script::{OpCodes, StackEntry}; use crate::utils::is_valid_amount; -use bincode::serialize; -use bytes::Bytes; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; +use std::str::FromStr; +use bincode::{Decode, Encode}; +use crate::crypto::sha3_256; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum GenesisTxHashSpec { @@ -38,11 +41,125 @@ pub struct TxConstructor { pub address_version: Option, } +const TX_HASH_LENGTH_BYTES : usize = TX_HASH_LENGTH / 2; + +/// Compact transaction hash representation. +/// +/// For legacy reasons, this wraps 31 hexadecimal digits worth of data, equivalent to 15.5 bytes. +/// Because of this, the 4 least significant bits of the last byte are unused. While awkward, this +/// actually means we have a convenient location to squeeze in a version indicator if we decide to +/// extend the transaction hash size in the future. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Encode, Decode)] +pub struct TxHash([u8; TX_HASH_LENGTH_BYTES]); + +make_error_type!(pub enum TxHashError { + BadByteCount(size: usize); "Transaction hash needs {TX_HASH_LENGTH_BYTES} bytes, got {size}", + BadZeroBits; "Transaction hash must end with four zero bits", + + InvalidStringLength(input: String); "Transaction hash \"{input}\" has incorrect length", + InvalidPrefix(input: String); "Transaction hash \"{input}\" has incorrect prefix", + InvalidHexData(input: String, cause: hex::FromHexError); "Transaction hash \"{input}\" is invalid: {cause}"; cause, +}); + +impl TxHash { + /// Constructs a new `TransactionHash` from the given bytes. + /// + /// Fails if the given slice does not contain a valid encoded `TransactionHash`. + pub fn from_slice(slice: &[u8]) -> Result { + let bytes : [u8; TX_HASH_LENGTH_BYTES] = slice.try_into() + .map_err(|_| TxHashError::BadByteCount(slice.len()))?; + + // The four least significant bits of the last byte must be zero, as a transaction + // hash consists of an odd number of hexadecimal digits. + if (bytes[TX_HASH_LENGTH_BYTES - 1] & 0xF) != 0 { + return Err(TxHashError::BadZeroBits); + } + + Ok(Self(bytes)) + } + + /// Constructs a `TransactionHash` based on the given SHA3-256 hash. + pub fn from_hash(hash: sha3_256::Hash) -> Self { + let mut chunk = (*hash).first_chunk::().unwrap().clone(); + chunk[TX_HASH_LENGTH_BYTES - 1] &= 0xF0; + Self::from_slice(&chunk).unwrap() + } +} + +#[cfg(test)] +impl crate::utils::PlaceholderSeed for TxHash { + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self { + let mut bytes = crate::utils::placeholder_bytes::( + [ "TxHash:".as_bytes() ].iter().copied().chain(seed_parts) + ); + bytes[TX_HASH_LENGTH_BYTES - 1] &= 0xF0; + Self::from_slice(&bytes).unwrap() + } +} + +impl fmt::Display for TxHash { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Encode the binary data as hex, and add the prefix character. + // The buffer is one character larger than necessary because of the trailing four zero bits. + let mut chars = [0u8; {TX_HASH_LENGTH + 1}]; + chars[0] = TX_PREPEND; + hex::encode_to_slice(self.0, &mut chars[1..]).unwrap(); + f.write_str(std::str::from_utf8(&chars[0..TX_HASH_LENGTH]).unwrap()) + } +} + +impl FromStr for TxHash { + type Err = TxHashError; + + fn from_str(input: &str) -> Result { + if input.len() != TX_HASH_LENGTH { + return Err(TxHashError::InvalidStringLength(input.to_string())); + } else if input.as_bytes()[0] != TX_PREPEND { + return Err(TxHashError::InvalidPrefix(input.to_string())); + } + + // Strip the leading TX_PREPEND character, and then pad the string by adding an + // additional trailing '0' character so that the hex string is parseable. + let mut chars = [0u8; TX_HASH_LENGTH]; + *chars.first_chunk_mut::<{TX_HASH_LENGTH - 1}>().unwrap() = input.as_bytes()[1..].try_into().unwrap(); + chars[TX_HASH_LENGTH - 1] = '0' as u8; + + // Parse the hex string + let mut bytes = [0u8; TX_HASH_LENGTH_BYTES]; + hex::decode_to_slice(&chars, &mut bytes) + .map_err(|e| TxHashError::InvalidHexData(input.to_string(), e))?; + Self::from_slice(&bytes) + } +} + +impl AsRef<[u8]> for TxHash { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl Serialize for TxHash { + fn serialize(&self, serializer: S) -> Result { + assert!(serializer.is_human_readable(), "serializer must be human-readable!"); + + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for TxHash { + fn deserialize>(deserializer: D) -> Result { + assert!(deserializer.is_human_readable(), "deserializer must be human-readable!"); + + let text : String = serde::Deserialize::deserialize(deserializer)?; + text.parse().map_err(::custom) + } +} + /// An outpoint - a combination of a transaction hash and an index n into its vout #[derive(Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize)] pub struct OutPoint { pub t_hash: String, - pub n: i32, + pub n: u32, } impl fmt::Display for OutPoint { @@ -53,14 +170,26 @@ impl fmt::Display for OutPoint { impl OutPoint { /// Creates a new outpoint instance - pub fn new(t_hash: String, n: i32) -> OutPoint { + // TODO: jrabil: remove this + pub fn new(t_hash: String, n: u32) -> OutPoint { OutPoint { t_hash, n } } + + /// Creates a new outpoint instance + pub fn new_hash(t_hash: TxHash, n: u32) -> OutPoint { + OutPoint { t_hash: t_hash.to_string(), n } + } } -impl Default for OutPoint { - fn default() -> Self { - Self::new(String::new(), 0) +#[cfg(test)] +impl crate::utils::PlaceholderSeed for OutPoint { + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self { + Self { + t_hash: TxHash::placeholder_seed_parts( + ["OutPoint:".as_bytes()].iter().copied().chain(seed_parts) + ).to_string(), + n: 0, + } } } @@ -205,15 +334,6 @@ impl Transaction { } } - /// Get the total transaction size in bytes - pub fn get_total_size(&self) -> usize { - let bytes = match serialize(self) { - Ok(bytes) => bytes, - Err(_) => vec![], - }; - bytes.len() - } - /// Gets the create asset assigned to this transaction, if it exists fn get_create_asset(&self) -> Option<&Asset> { let is_create = self.inputs.len() == 1 @@ -251,3 +371,46 @@ impl Transaction { false } } + +/*---- TESTS ----*/ + +#[cfg(test)] +mod tests { + use crate::utils::PlaceholderSeed; + use super::*; + + #[test] + fn test_tx_hash_string() { + let hash = TxHash::placeholder_indexed(0); + let string = hash.to_string(); + assert_eq!(string, "g1a30d8257870b5d077fc55d1faa63aa"); + assert_eq!(TxHash::from_str(&string).unwrap(), hash); + } + + #[test] + fn test_tx_hash_slice() { + let hash = TxHash::placeholder_indexed(0); + let bytes = hash.as_ref().to_vec(); + assert_eq!(hex::encode(&bytes), "1a30d8257870b5d077fc55d1faa63aa0"); + assert_eq!(TxHash::from_slice(&bytes).unwrap(), hash); + } + + #[test] + fn test_tx_hash_bincode() { + let config = bincode::config::standard(); + let hash = TxHash::placeholder_indexed(0); + + let serialized = bincode::encode_to_vec(&hash, config.clone()).unwrap(); + assert_eq!(&serialized, hash.as_ref()); + let deserialized: TxHash = bincode::decode_from_slice(&serialized, config.clone()).unwrap().0; + assert_eq!(deserialized, hash); + } + + #[test] + fn test_tx_hash_serdejson() { + let hash = TxHash::placeholder_indexed(0); + let json = serde_json::to_string(&hash).unwrap(); + assert_eq!(json, "\"g1a30d8257870b5d077fc55d1faa63aa\""); + assert_eq!(serde_json::from_str::(&json).unwrap(), hash); + } +} diff --git a/src/script/interface_ops.rs b/src/script/interface_ops.rs index 7b9875b..ba2a63c 100644 --- a/src/script/interface_ops.rs +++ b/src/script/interface_ops.rs @@ -11,8 +11,6 @@ use crate::utils::error_utils::*; use crate::utils::transaction_utils::{ construct_address, construct_address_temp, construct_address_v0, }; -use bincode::de; -use bincode::serialize; use bytes::Bytes; use hex::encode; use std::collections::BTreeMap; diff --git a/src/script/lang.rs b/src/script/lang.rs index 259465e..2498e5c 100644 --- a/src/script/lang.rs +++ b/src/script/lang.rs @@ -8,8 +8,6 @@ use crate::script::interface_ops::*; use crate::script::{OpCodes, StackEntry}; use crate::utils::error_utils::*; use crate::utils::transaction_utils::{construct_address, construct_address_for}; -use bincode::serialize; -use bytes::Bytes; use hex::encode; use serde::{Deserialize, Serialize}; use tracing::{error, warn}; diff --git a/src/utils/mod.rs b/src/utils/mod.rs index adb6987..3864f3b 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -8,6 +8,7 @@ use crate::primitives::asset::TokenAmount; pub mod druid_utils; pub mod error_utils; pub mod script_utils; +pub mod serialize_utils; pub mod test_utils; pub mod transaction_utils; @@ -51,3 +52,143 @@ pub fn add_btreemap( }); m1 } + +/// A trait which indicates that it is possible to acquire a "placeholder" value +/// of a type, which can be used for test purposes. +#[cfg(test)] +pub trait Placeholder : Sized { + /// Gets a placeholder value of this type which can be used for test purposes. + fn placeholder() -> Self; + + /// Gets an array of placeholder values of this type which can be used for test purposes. + fn placeholder_array() -> [Self; N] { + core::array::from_fn(|_| Self::placeholder()) + } +} + +/// A trait which indicates that it is possible to acquire a "placeholder" value +/// of a type, which can be used for test purposes. These placeholder values are consistent +/// across program runs. +#[cfg(test)] +pub trait PlaceholderSeed: Sized + PartialEq { + /// Gets a dummy valid of this type which can be used for test purposes. + /// + /// This allows acquiring multiple distinct placeholder values which are still consistent + /// between runs. + /// + /// ### Arguments + /// + /// * `seed_parts` - the parts of the seed for the placeholder value to obtain. Two placeholder + /// values generated from the same seed are guaranteed to be equal (even + /// across multiple test runs, so long as the value format doesn't change). + fn placeholder_seed_parts<'a>(seed_parts: impl IntoIterator) -> Self; + + /// Gets a dummy valid of this type which can be used for test purposes. + /// + /// This allows acquiring multiple distinct placeholder values which are still consistent + /// between runs. + /// + /// ### Arguments + /// + /// * `seed` - the seed for the placeholder value to obtain. Two placeholder + /// values generated from the same seed are guaranteed to be equal (even + /// across multiple test runs, so long as the value format doesn't change). + fn placeholder_seed(seed: impl AsRef<[u8]>) -> Self { + Self::placeholder_seed_parts([ seed.as_ref() ]) + } + + /// Gets a dummy valid of this type which can be used for test purposes. + /// + /// This allows acquiring multiple distinct placeholder values which are still consistent + /// between runs. + /// + /// ### Arguments + /// + /// * `index` - the index of the placeholder value to obtain. Two placeholder values generated + /// from the same index are guaranteed to be equal (even across multiple test runs, + /// so long as the value format doesn't change). + fn placeholder_indexed(index: u64) -> Self { + Self::placeholder_seed_parts([ index.to_le_bytes().as_slice() ]) + } + + /// Gets an array of placeholder values of this type which can be used for test purposes. + fn placeholder_array_seed(seed: impl AsRef<[u8]>) -> [Self; N] { + core::array::from_fn(|n| Self::placeholder_seed_parts( + [ seed.as_ref(), &(n as u64).to_le_bytes() ] + )) + } + + /// Gets an array of placeholder values of this type which can be used for test purposes. + fn placeholder_array_indexed(base_index: u64) -> [Self; N] { + Self::placeholder_array_seed(base_index.to_le_bytes()) + } +} + +#[cfg(test)] +impl Placeholder for T { + fn placeholder() -> Self { + ::placeholder_seed_parts([]) + } +} + +/// Generates the given number of pseudorandom bytes based on the given seed. +/// +/// This is intended to be used in tests, where random but reproducible placeholder values are often +/// required. +/// +/// ### Arguments +/// +/// * `seed_parts` - the parts of the seed, which will be concatenated to form the RNG seed +#[cfg(test)] +pub fn placeholder_bytes<'a, const N: usize>( + seed_parts: impl IntoIterator +) -> [u8; N] { + // Use Shake-256 to generate an arbitrarily large number of random bytes based on the given seed. + let mut shake256 = sha3::Shake256::default(); + for slice in seed_parts { + sha3::digest::Update::update(&mut shake256, slice); + } + let mut reader = sha3::digest::ExtendableOutput::finalize_xof(shake256); + + let mut res = [0u8; N]; + sha3::digest::XofReader::read(&mut reader, &mut res); + res +} + +/// A trait which indicates that a type can be represented by an ordinal number. +pub trait ToOrdinal { + /// Gets the ordinal number from a value. + fn to_ordinal(&self) -> u32; +} + +/// A trait which indicates that a type can be instantiated from an ordinal number. +pub trait FromOrdinal : Sized { + /// A slice containing every valid ordinal number. + const ALL_ORDINALS : &'static [u32]; + + /// Gets the value corresponding to the given ordinal number. + /// + /// ### Arguments + /// + /// * `ordinal` - The ordinal number + fn from_ordinal(ordinal: u32) -> Result; +} + +/// A trait which indicates that a type can be represented by a string name. +pub trait ToName { + /// Gets a value's string name. + fn to_name(&self) -> &'static str; +} + +/// A trait which indicates that a type can be instantiated from a string name. +pub trait FromName : Sized { + /// A slice containing every valid name. + const ALL_NAMES : &'static [&'static str]; + + /// Gets the value corresponding to the given name. + /// + /// ### Arguments + /// + /// * `name` - The name + fn from_name(name: &str) -> Result; +} \ No newline at end of file diff --git a/src/utils/script_utils.rs b/src/utils/script_utils.rs index a2cd085..fca79da 100644 --- a/src/utils/script_utils.rs +++ b/src/utils/script_utils.rs @@ -15,8 +15,6 @@ use crate::utils::transaction_utils::{ construct_address, construct_tx_hash, construct_tx_in_out_signable_hash, construct_tx_in_signable_asset_hash, construct_tx_in_signable_hash, }; -use bincode::serialize; -use bytes::Bytes; use hex::encode; use ring::error; use std::collections::{BTreeMap, BTreeSet}; diff --git a/src/utils/serialize_utils.rs b/src/utils/serialize_utils.rs new file mode 100644 index 0000000..9053c02 --- /dev/null +++ b/src/utils/serialize_utils.rs @@ -0,0 +1,605 @@ +use std::any::TypeId; +use std::convert::{TryFrom, TryInto}; +use std::fmt; +use std::io::Write; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::str::FromStr; +use bincode::{BorrowDecode, Decode, Encode}; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::de::{SeqAccess, Visitor}; +use serde::ser::{SerializeTuple}; + +/// Implements `Write` by simply counting the number of bytes written to it. +#[derive(Copy, Clone, Debug)] +pub struct ByteCountingWriter { + pub count: usize, +} + +impl ByteCountingWriter { + /// Creates a new `ByteCountingWriter` with a `count` of 0. + pub fn new() -> Self { + Self { + count: 0, + } + } +} + +impl Write for ByteCountingWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.count += buf.len(); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +/// Simple wrapper around a fixed-length byte array. +/// +/// This can be formatted to and parsed from a hexadecimal string using `Display` and `FromStr`. +/// When serialized as JSON, it is also represented as a hexadecimal string. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Encode, Decode)] +pub struct FixedByteArray( + #[serde(with = "fixed_array_codec")] + [u8; N], +); + +impl FixedByteArray { + pub fn new(arr: [u8; N]) -> Self { + Self(arr) + } +} + +impl AsRef<[u8]> for FixedByteArray { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl AsMut<[u8]> for FixedByteArray { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.0 + } +} + +impl Deref for FixedByteArray { + type Target = [u8; N]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FixedByteArray { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From<[u8; N]> for FixedByteArray { + fn from(value: [u8; N]) -> Self { + Self(value) + } +} + +impl From<&[u8; N]> for FixedByteArray { + fn from(value: &[u8; N]) -> Self { + Self(*value) + } +} + +impl TryFrom<&[u8]> for FixedByteArray { + type Error = std::array::TryFromSliceError; + + fn try_from(value: &[u8]) -> Result { + value.try_into().map(Self) + } +} + +impl TryFrom<&Vec> for FixedByteArray { + type Error = std::array::TryFromSliceError; + + fn try_from(value: &Vec) -> Result { + value.as_slice().try_into().map(Self) + } +} + +impl fmt::LowerHex for FixedByteArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // This is hacky because we can't make an array of type [u8; {N * 2}] due to + // generic parameters not being allowed in constant expressions on stable rust + assert_eq!(std::mem::size_of::<[u16; N]>(), std::mem::size_of::<[u8; N]>() * 2); + let mut buf = [0u16; N]; + let slice = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, N * 2) }; + hex::encode_to_slice(&self.0, slice).unwrap(); + f.write_str(std::str::from_utf8(slice).unwrap()) + } +} + +impl fmt::Display for FixedByteArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::LowerHex::fmt(self, f) + } +} + +impl fmt::Debug for FixedByteArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "FixedByteArray<{N}>({self:x})") + } +} + +impl FromStr for FixedByteArray { + type Err = hex::FromHexError; + + fn from_str(s: &str) -> Result { + let mut buf = [0u8; N]; + hex::decode_to_slice(s, &mut buf)?; + Ok(Self(buf)) + } +} + +/// A serde codec for fixed-size arrays. +pub mod fixed_array_codec { + use super::*; + + pub fn serialize( + values: &[T; N], + serializer: S, + ) -> Result { + if TypeId::of::() == TypeId::of::() && serializer.is_human_readable() { + // We're serializing a byte array for a human-readable format, make it a hex string + vec_codec::serialize(values, serializer) + } else { + // Serialize the array as a tuple, to avoid adding a length prefix + let mut tuple = serializer.serialize_tuple(N)?; + for e in values { + tuple.serialize_element(e)?; + } + tuple.end() + } + } + + pub fn deserialize<'de, T: Deserialize<'de> + 'static, D: Deserializer<'de>, const N: usize>( + deserializer: D, + ) -> Result<[T; N], D::Error> { + if TypeId::of::() == TypeId::of::() && deserializer.is_human_readable() { + // We're deserializing a byte array for a human-readable format, we'll accept two different + // representations: + // - A hexadecimal string + // - An array of byte literals (this format should never be produced by the serializer + // for human-readable formats, but it was in the past, so we'll still support reading + // it for backwards-compatibility). + vec_to_fixed_array(vec_codec::deserialize(deserializer)?) + } else { + // We're deserializing a binary format, read the array as a tuple + // (to avoid adding a length prefix) + + struct FixedArrayVisitor(PhantomData); + impl<'de, T: Deserialize<'de>, const N: usize> Visitor<'de> for FixedArrayVisitor { + type Value = [T; N]; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a sequence") + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut vec = Vec::with_capacity(N); + while let Some(val) = seq.next_element::()? { + vec.push(val) + } + vec_to_fixed_array(vec) + } + } + + deserializer.deserialize_tuple(N, FixedArrayVisitor(Default::default())) + } + } +} + +/// A serde codec for variable-length `Vec`s. +pub mod vec_codec { + use super::*; + + pub fn serialize( + values: &[T], + serializer: S, + ) -> Result { + if TypeId::of::() == TypeId::of::() && serializer.is_human_readable() { + // We're serializing a byte array for a human-readable format, make it a hex string + let bytes = unsafe { std::slice::from_raw_parts(values.as_ptr() as *const u8, values.len()) }; + serializer.serialize_str(&hex::encode(bytes)) + } else { + // Serialize the array as a length-prefixed sequence + values.serialize(serializer) + } + } + + pub fn deserialize<'de, T: Deserialize<'de> + 'static, D: Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + if TypeId::of::() == TypeId::of::() && deserializer.is_human_readable() { + // We're deserializing a byte array for a human-readable format, we'll accept two different + // representations: + // - A hexadecimal string + // - An array of byte literals (this format should never be produced by the serializer + // for human-readable formats, but it was in the past, so we'll still support reading + // it for backwards-compatibility). + + struct HexStringOrBytesVisitor(); + impl<'de> Visitor<'de> for HexStringOrBytesVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("hex string or byte array") + } + + fn visit_str(self, value: &str) -> Result { + hex::decode(value).map_err(E::custom) + } + + fn visit_seq(self, mut seq: A) -> Result where A: SeqAccess<'de> { + let mut vec = Vec::new(); + while let Some(elt) = seq.next_element::()? { + vec.push(elt); + } + Ok(vec) + } + } + + Ok(deserializer.deserialize_any(HexStringOrBytesVisitor())?.into_iter() + // This is a hack to convert the Vec into a Vec, even though we already know + // that T = u8. This could be done in a much nicer way if trait specialization were + // a thing, but unfortunately it's still only available on nightly :( + .map(|b| unsafe { std::mem::transmute_copy::(&b) }) + .collect::>()) + } else { + // Read a length-prefixed sequence as a Vec + >::deserialize(deserializer) + } + } +} + +fn vec_to_fixed_array( + vec: Vec, +) -> Result<[T; N], E> { + <[T; N]>::try_from(vec) + .map_err(|vec| E::custom(format!("expected exactly {} elements, but read {}", N, vec.len()))) +} + +/// Encodes an object into a `Vec` using bincode 2's standard configuration. +/// +/// This allows using the turbofish operator to explicitly specify the encode type without also +/// having to specify the config type. +/// +/// ### Arguments +/// +/// * `value` - the value to encode +#[inline(always)] +pub fn bincode_encode_to_vec_standard( + value: &T, +) -> Result, bincode::error::EncodeError> { + bincode::encode_to_vec(value, bincode::config::standard()) +} + +/// Encodes an object into the given `Write` using bincode 2's standard configuration. +/// +/// This allows using the turbofish operator to explicitly specify the encode type without also +/// having to specify the config type. +/// +/// ### Arguments +/// +/// * `value` - the value to encode +/// * `write` - the `Write` to encode into +#[inline(always)] +pub fn bincode_encode_to_write_standard( + value: &T, + write: &mut impl Write, +) -> Result { + bincode::encode_into_std_write(value, write, bincode::config::standard()) +} + +/// Calculates the encoded size of the given object using bincode 2's standard configuration. +/// +/// ### Arguments +/// +/// * `value` - the value to encode +#[inline(always)] +pub fn bincode_encoded_size_standard( + value: &T, +) -> Result { + let mut writer = ByteCountingWriter::new(); + bincode::encode_into_std_write(value, &mut writer, bincode::config::standard())?; + Ok(writer.count) +} + +/// Decodes an object from a slice using bincode 2's standard configuration. +/// +/// This allows using the turbofish operator to explicitly specify the decode type without also +/// having to specify the config type. +/// +/// ### Arguments +/// +/// * `slice` - the slice to decode from +#[inline(always)] +pub fn bincode_decode_from_slice_standard( + slice: &[u8], +) -> Result<(T, usize), bincode::error::DecodeError> { + bincode::decode_from_slice(slice, bincode::config::standard()) +} + +/// Decodes an object from a slice using bincode 2's standard configuration. +/// +/// This allows using the turbofish operator to explicitly specify the decode type without also +/// having to specify the config type. +/// +/// ### Arguments +/// +/// * `slice` - the slice to decode from +pub fn bincode_decode_from_slice_standard_full( + slice: &[u8], +) -> Result { + let (result, read_bytes) = bincode_decode_from_slice_standard::(slice)?; + if read_bytes == slice.len() { + Ok(result) + } else { + Err(bincode::error::DecodeError::OtherString( + format!("{} bytes left over after decoding", slice.len() - read_bytes))) + } +} + +/// Decodes an object from a slice using bincode 2's standard configuration. +/// +/// This allows using the turbofish operator to explicitly specify the decode type without also +/// having to specify the config type. +/// +/// ### Arguments +/// +/// * `slice` - the slice to decode from +#[inline(always)] +pub fn bincode_borrow_decode_from_slice_standard<'a, T: BorrowDecode<'a>>( + slice: &'a [u8], +) -> Result<(T, usize), bincode::error::DecodeError> { + bincode::borrow_decode_from_slice(slice, bincode::config::standard()) +} + +/*---- TESTS ----*/ + +#[cfg(test)] +mod tests { + use std::fmt::{Debug, Display}; + + use serde::{Deserialize, Serialize}; + use serde::de::DeserializeOwned; + use super::*; + + fn repeat(orig: &str, n: usize) -> String { + let mut res = String::with_capacity(orig.len() * n); + for _ in 0..n { + res.push_str(orig) + } + res + } + + fn test_bin_codec( + config: fn() -> C, + obj: T, + expect: &str, + ) { + let bytes = bincode::serde::encode_to_vec(&obj, config()).unwrap(); + assert_eq!(hex::encode(&bytes), expect); + assert_eq!(bincode::serde::decode_from_slice::(&bytes, config()).unwrap().0, obj); + } + + fn test_json_codec( + obj: T, + expect: &str, + ) { + let json = serde_json::to_string(&obj).unwrap(); + assert_eq!(json, expect); + assert_eq!(serde_json::from_str::(&json).unwrap(), obj); + } + + fn test_json_deserialize( + obj: T, + json: &str, + ) { + assert_eq!(serde_json::from_str::(&json).unwrap(), obj); + } + + fn test_display_fromstr( + obj: T, + expect: &str, + ) { + let string = obj.to_string(); + assert_eq!(string, expect); + assert_eq!(::from_str(&string).ok().unwrap(), obj); + } + + macro_rules! test_fixed_array { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, [VAL; $n], &repeat(HEX, $n)); + test_json_codec([VAL; $n], &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + macro_rules! test_fixed_array_wrapper { + ($e:ty, $t:ident, $n:literal) => { + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct $t([$e; $n]); + test_bin_codec(bincode::config::legacy, $t([VAL; $n]), &repeat(HEX, $n)); + test_json_codec($t([VAL; $n]), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + #[test] + fn test_fixed_u32_arrays() { + const VAL : u32 = 0xDEADBEEF; + const HEX : &str = "efbeadde"; + + test_fixed_array!(0); + test_fixed_array!(1); + test_fixed_array!(32); + + test_fixed_array_wrapper!(u32, FixedArrayWrapper0, 0); + test_fixed_array_wrapper!(u32, FixedArrayWrapper1, 1); + test_fixed_array_wrapper!(u32, FixedArrayWrapper32, 32); + + macro_rules! test_fixed_array_wrapper_codec { + ($t:ident, $n:literal) => { + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct $t(#[serde(with = "fixed_array_codec")] [u32; $n]); + test_bin_codec(bincode::config::legacy, $t([VAL; $n]), &repeat(HEX, $n)); + test_json_codec($t([VAL; $n]), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper0, 0); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper1, 1); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper32, 32); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper33, 33); + } + + #[test] + fn test_fixed_u8_arrays() { + const VAL : u8 = 123; + const HEX : &str = "7b"; + + test_fixed_array!(0); + test_fixed_array!(1); + test_fixed_array!(32); + + test_fixed_array_wrapper!(u8, FixedArrayWrapper0, 0); + test_fixed_array_wrapper!(u8, FixedArrayWrapper1, 1); + test_fixed_array_wrapper!(u8, FixedArrayWrapper32, 32); + + macro_rules! test_fixed_array_wrapper_codec { + ($t:ident, $n:literal) => { + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct $t(#[serde(with = "fixed_array_codec")] [u8; $n]); + test_bin_codec(bincode::config::legacy, $t([VAL; $n]), &repeat(HEX, $n)); + test_json_codec($t([VAL; $n]), &format!("\"{}\"", hex::encode(&[VAL; $n].to_vec()))); + test_json_deserialize($t([VAL; $n]), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper0, 0); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper1, 1); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper32, 32); + test_fixed_array_wrapper_codec!(CodecFixedArrayWrapper33, 33); + } + + fn size_to_hex_default(n: usize) -> String { + hex::encode(&(n as u64).to_le_bytes()) + } + + macro_rules! test_vec { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, [VAL; $n].to_vec(), &format!("{}{}", size_to_hex_default($n), repeat(HEX, $n))); + test_json_codec([VAL; $n].to_vec(), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + macro_rules! test_vec_wrapper { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, VecWrapper([VAL; $n].to_vec()), &format!("{}{}", size_to_hex_default($n), repeat(HEX, $n))); + test_json_codec(VecWrapper([VAL; $n].to_vec()), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + #[test] + fn test_u32_vecs() { + const VAL : u32 = 0xDEADBEEF; + const HEX : &str = "efbeadde"; + + test_vec!(0); + test_vec!(1); + test_vec!(32); + test_vec!(33); + + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct VecWrapper(Vec); + + test_vec_wrapper!(0); + test_vec_wrapper!(1); + test_vec_wrapper!(32); + test_vec_wrapper!(33); + + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct CodecVecWrapper(#[serde(with = "vec_codec")] Vec); + macro_rules! test_vec_wrapper_codec { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, CodecVecWrapper([VAL; $n].to_vec()), &format!("{}{}", size_to_hex_default($n), repeat(HEX, $n))); + test_json_codec(CodecVecWrapper([VAL; $n].to_vec()), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + test_vec_wrapper_codec!(0); + test_vec_wrapper_codec!(1); + test_vec_wrapper_codec!(32); + test_vec_wrapper_codec!(33); + } + + #[test] + fn test_u8_vecs() { + const VAL : u8 = 123; + const HEX : &str = "7b"; + + test_vec!(0); + test_vec!(1); + test_vec!(32); + test_vec!(33); + + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct VecWrapper(Vec); + + test_vec_wrapper!(0); + test_vec_wrapper!(1); + test_vec_wrapper!(32); + test_vec_wrapper!(33); + + #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] + struct CodecVecWrapper(#[serde(with = "vec_codec")] Vec); + macro_rules! test_vec_wrapper_codec { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, CodecVecWrapper([VAL; $n].to_vec()), &format!("{}{}", size_to_hex_default($n), repeat(HEX, $n))); + test_json_codec(CodecVecWrapper([VAL; $n].to_vec()), &format!("\"{}\"", hex::encode(&[VAL; $n].to_vec()))); + test_json_deserialize(CodecVecWrapper([VAL; $n].to_vec()), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + }; + } + + test_vec_wrapper_codec!(0); + test_vec_wrapper_codec!(1); + test_vec_wrapper_codec!(32); + test_vec_wrapper_codec!(33); + } + + #[test] + fn test_fixed_byte_array() { + const VAL : u8 = 123; + const HEX : &str = "7b"; + + macro_rules! test_fixed_byte_array { + ($n:literal) => { + test_bin_codec(bincode::config::legacy, FixedByteArray::<$n>([VAL; $n]), &repeat(HEX, $n)); + test_json_codec(FixedByteArray::<$n>([VAL; $n]), &format!("\"{}\"", repeat(HEX, $n))); + test_json_deserialize(FixedByteArray::<$n>([VAL; $n]), &serde_json::to_string(&[VAL; $n].to_vec()).unwrap()); + test_display_fromstr(FixedByteArray::<$n>([VAL; $n]), &repeat(HEX, $n)); + assert_eq!(format!("{:x}", FixedByteArray::<$n>([VAL; $n])), repeat(HEX, $n)); + assert_eq!( + format!("{:?}", FixedByteArray::<$n>([VAL; $n])), + format!("FixedByteArray<{}>({})", $n, repeat(HEX, $n))); + assert_eq!( + format!("{:x?}", FixedByteArray::<$n>([VAL; $n])), + format!("FixedByteArray<{}>({})", $n, repeat(HEX, $n))); + }; + } + + test_fixed_byte_array!(0); + test_fixed_byte_array!(1); + test_fixed_byte_array!(32); + test_fixed_byte_array!(33); + } +} diff --git a/src/utils/test_utils.rs b/src/utils/test_utils.rs index 0c8fcef..0106644 100644 --- a/src/utils/test_utils.rs +++ b/src/utils/test_utils.rs @@ -7,6 +7,7 @@ use crate::primitives::{ use crate::script::lang::Script; use crate::utils::transaction_utils::{construct_address, construct_tx_in_out_signable_hash}; use std::collections::BTreeMap; +use std::convert::TryInto; /// Generate a transaction with valid Script values /// and accompanying UTXO set for testing a set of @@ -48,7 +49,7 @@ pub fn generate_tx_with_ins_and_outs_assets( // Generate inputs for (input_amount, genesis_hash, md) in input_assets { - let tx_previous_out = OutPoint::new("tx_hash".to_owned(), tx.inputs.len() as i32); + let tx_previous_out = OutPoint::new("tx_hash".to_owned(), tx.inputs.len().try_into().unwrap()); let tx_in_previous_out = match genesis_hash { Some(drs) => { let item = Asset::item(*input_amount, Some(drs.to_string()), md.clone()); diff --git a/src/utils/transaction_utils.rs b/src/utils/transaction_utils.rs index 8e9fd99..9fa9cdc 100644 --- a/src/utils/transaction_utils.rs +++ b/src/utils/transaction_utils.rs @@ -6,8 +6,8 @@ use crate::primitives::druid::{DdeValues, DruidExpectation}; use crate::primitives::transaction::*; use crate::script::lang::Script; use crate::script::{OpCodes, StackEntry}; -use bincode::serialize; use std::collections::BTreeMap; +use std::convert::TryInto; use tracing::debug; pub struct ReceiverInfo { @@ -21,10 +21,7 @@ pub struct ReceiverInfo { /// /// * `script` - Script to build address for pub fn construct_p2sh_address(script: &Script) -> String { - let bytes = match serialize(script) { - Ok(bytes) => bytes, - Err(_) => vec![], - }; + let bytes = bincode::serde::encode_to_vec(script, bincode::config::legacy()).unwrap(); let mut addr = hex::encode(sha3_256::digest(&bytes)); addr.insert(ZERO, P2SH_PREPEND as char); addr.truncate(STANDARD_ADDRESS_LENGTH); @@ -260,7 +257,7 @@ pub fn get_tx_with_out_point<'a>( ) -> impl Iterator { txs.map(|(hash, tx)| (hash, tx, &tx.outputs)) .flat_map(|(hash, tx, outs)| outs.iter().enumerate().map(move |(idx, _)| (hash, idx, tx))) - .map(|(hash, idx, tx)| (OutPoint::new(hash.clone(), idx as i32), tx)) + .map(|(hash, idx, tx)| (OutPoint::new(hash.clone(), idx.try_into().unwrap()), tx)) } /// Get all the OutPoint and Transaction from the (hash,transactions) @@ -284,7 +281,7 @@ pub fn get_tx_out_with_out_point<'a>( ) -> impl Iterator { txs.map(|(hash, tx)| (hash, tx.outputs.iter())) .flat_map(|(hash, outs)| outs.enumerate().map(move |(idx, txo)| (hash, idx, txo))) - .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx as i32), txo)) + .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx.try_into().unwrap()), txo)) } /// Get all fee outputs from the (hash,transactions) @@ -297,7 +294,7 @@ pub fn get_fees_with_out_point<'a>( ) -> impl Iterator { txs.map(|(hash, tx)| (hash, tx.fees.iter())) .flat_map(|(hash, outs)| outs.enumerate().map(move |(idx, txo)| (hash, idx, txo))) - .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx as i32), txo)) + .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx.try_into().unwrap()), txo)) } /// Get all fee outputs from the (hash,transactions) @@ -310,7 +307,7 @@ pub fn get_fees_with_out_point_cloned<'a>( ) -> impl Iterator + 'a { txs.map(|(hash, tx)| (hash, tx.fees.iter())) .flat_map(|(hash, outs)| outs.enumerate().map(move |(idx, txo)| (hash, idx, txo))) - .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx as i32), txo.clone())) + .map(|(hash, idx, txo)| (OutPoint::new(hash.clone(), idx.try_into().unwrap()), txo.clone())) } /// Get all the OutPoint and TxOut from the (hash,transactions) @@ -344,7 +341,7 @@ pub fn update_utxo_set(current_utxo: &mut BTreeMap) { /// /// * `tx` - Transaction to hash pub fn construct_tx_hash(tx: &Transaction) -> String { - let bytes = match serialize(tx) { + let bytes = match bincode::serde::encode_to_vec(tx, bincode::config::legacy()) { Ok(bytes) => bytes, Err(_) => vec![], }; @@ -1179,7 +1176,7 @@ mod tests { ..Default::default() }]; - let bytes = match serialize(&tx_ins) { + let bytes = match bincode::serde::encode_to_vec(&tx_ins, bincode::config::legacy()) { Ok(bytes) => bytes, Err(_) => vec![], };