From 2077805384eb8f19aa96fcbd4a2beaa9f806845e Mon Sep 17 00:00:00 2001
From: Steven Allen <steven@stebalien.com>
Date: Thu, 14 Oct 2021 21:32:40 -0700
Subject: [PATCH] fix: correctly encode multihashes with serde

The number of digest bytes written needs to match the size.

See the upstream issue here: https://github.com/multiformats/rust-multihash/pull/140

All credit goes to @koushiro. My one change is to not allocate.
---
 src/multihash.rs | 40 ++++++++++++++++++++++++++++++----------
 1 file changed, 30 insertions(+), 10 deletions(-)

diff --git a/src/multihash.rs b/src/multihash.rs
index 55f4fde..f8636a0 100644
--- a/src/multihash.rs
+++ b/src/multihash.rs
@@ -194,11 +194,9 @@ impl<const S: usize> parity_scale_codec::Encode for Multihash<S> {
     &self,
     dest: &mut EncOut,
   ) {
-    let mut digest = [0; S];
-    digest.copy_from_slice(&self.digest);
     self.code.encode_to(dest);
     self.size.encode_to(dest);
-    digest.encode_to(dest);
+    dest.write(self.digest());
   }
 }
 
@@ -210,11 +208,17 @@ impl<const S: usize> parity_scale_codec::Decode for Multihash<S> {
   fn decode<DecIn: parity_scale_codec::Input>(
     input: &mut DecIn,
   ) -> Result<Self, parity_scale_codec::Error> {
-    Ok(Multihash {
+    let mut mh = Multihash {
       code: parity_scale_codec::Decode::decode(input)?,
       size: parity_scale_codec::Decode::decode(input)?,
-      digest: <[u8; S]>::decode(input)?,
-    })
+      digest: [0; S],
+    };
+    if mh.size as usize > S {
+      Err(parity_scale_codec::Error::from("invalid size"))
+    } else {
+      input.read(&mut mh.digest[..mh.size as usize])?;
+      Ok(mh)
+    }
   }
 }
 
@@ -318,15 +322,31 @@ mod tests {
   #[test]
   #[cfg(feature = "scale-codec")]
   fn test_scale() {
+    use crate::{Hasher, Sha2_256};
     use parity_scale_codec::{
       Decode,
       Encode,
     };
 
-    let mh = Multihash::<32>::default();
-    let bytes = mh.encode();
-    let mh2: Multihash<32> = Decode::decode(&mut &bytes[..]).unwrap();
-    assert_eq!(mh, mh2);
+    let mh1 = Multihash::<32>::wrap(
+      Code::Sha2_256.into(),
+      Sha2_256::digest(b"hello world").as_ref(),
+    )
+    .unwrap();
+    // println!("mh1: code = {}, size = {}, digest = {:?}", mh1.code(), mh1.size(), mh1.digest());
+    let mh1_bytes = mh1.encode();
+    // println!("Multihash<32>: {}", hex::encode(&mh1_bytes));
+    let mh2: Multihash<32> = Decode::decode(&mut &mh1_bytes[..]).unwrap();
+    assert_eq!(mh1, mh2);
+
+    let mh3: Multihash<64> = Code::Sha2_256.digest(b"hello world");
+    // println!("mh3: code = {}, size = {}, digest = {:?}", mh3.code(), mh3.size(), mh3.digest());
+    let mh3_bytes = mh3.encode();
+    // println!("Multihash<64>: {}", hex::encode(&mh3_bytes));
+    let mh4: Multihash<64> = Decode::decode(&mut &mh3_bytes[..]).unwrap();
+    assert_eq!(mh3, mh4);
+
+    assert_eq!(mh1_bytes, mh3_bytes);
   }
 
   #[test]