@@ -232,33 +232,112 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result<Vec<u8>
232232 } )
233233}
234234
235- /// This function constructs destination by merging original with diff such that destination
236- /// becomes the changed version of the original.
235+ /// Constructs destination by applying the diff to original, such that destination becomes the
236+ /// post-diff state of the original.
237237///
238238/// Precondition:
239- /// - destination.len() == original.len()
239+ /// - destination.len() == diffset.changed_len()
240+ /// - original.len() may differ from destination.len() to allow Solana
241+ /// account resizing (shrink or expand).
242+ /// Assumption:
243+ /// - destination is assumed to be zero-initialized. That automatically holds true for freshly
244+ /// allocated Solana account data. The function does NOT validate this assumption for performance reason.
245+ /// Returns:
246+ /// - Ok(n) where n is number of bytes written.
247+ /// - If n < destination.len(), then the last (destination.len() - n) bytes are not written by this
248+ /// function, and the caller, if they want, could write to these bytes starting with index `n`.
249+ /// - else n == destination.len().
250+ /// Notes:
251+ /// - Unmodified regions are copied directly from original.
252+ /// - Extra trailing bytes from original (in shrink case) are ignored.
253+ /// - In expansion case, bytes covered by diff segments are written from diffset
254+ /// (even beyond `original.len()`); any remaining bytes beyond both the diff coverage
255+ /// and original.len() are unwritten and are assumed to be zero-initialized.
256+ ///
240257pub fn merge_diff_copy (
241258 destination : & mut [ u8 ] ,
242259 original : & [ u8 ] ,
243260 diffset : & DiffSet < ' _ > ,
244- ) -> Result < ( ) , ProgramError > {
245- if destination. len ( ) != original . len ( ) {
261+ ) -> Result < usize , ProgramError > {
262+ if destination. len ( ) != diffset . changed_len ( ) {
246263 return Err ( DlpError :: MergeDiffError . into ( ) ) ;
247264 }
265+
248266 let mut write_index = 0 ;
249267 for item in diffset. iter ( ) {
250268 let ( diff_segment, OffsetInData { start, end } ) = item?;
269+
251270 if write_index < start {
271+ if start > original. len ( ) {
272+ return Err ( DlpError :: InvalidDiff . into ( ) ) ;
273+ }
252274 // copy the unchanged bytes
253275 destination[ write_index..start] . copy_from_slice ( & original[ write_index..start] ) ;
254276 }
277+
255278 destination[ start..end] . copy_from_slice ( diff_segment) ;
256279 write_index = end;
257280 }
258- if write_index < original. len ( ) {
259- destination[ write_index..] . copy_from_slice ( & original[ write_index..] ) ;
260- }
261- Ok ( ( ) )
281+
282+ // Ensure we have overwritten all bytes in destination, otherwise "construction" of destination
283+ // will be considered incomplete.
284+ let num_bytes_written = match write_index. cmp ( & destination. len ( ) ) {
285+ Ordering :: Equal => {
286+ // It means the destination is fully constructed.
287+ // Nothing to do here.
288+
289+ // It is possible that destination.len() <= original.len() i.e destination might have shrunk
290+ // in which case we do not care about those bytes of original which are not part of
291+ // destination anymore.
292+ write_index
293+ }
294+ Ordering :: Less => {
295+ // destination is NOT fully constructed yet. Few bytes in the destination are still unwritten.
296+ // Let's say the number of these unwritten bytes is: N.
297+ //
298+ // Now how do we construct these N unwritten bytes? We have already processed the
299+ // diffset, so now where could the values for these N bytes come from?
300+ //
301+ // There are 3 scenarios:
302+ // - All N bytes must be copied from remaining region of the original:
303+ // - that means, destination.len() <= original.len()
304+ // - and the destination might have shrunk, in which case we do not care about
305+ // the extra bytes in the original: they're discarded.
306+ // - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are
307+ // "assumed" to be already zero-initialized.
308+ // - that means, destination.len() > original.len()
309+ // - write_index + (N-M) == original.len()
310+ // - and the destination has expanded.
311+ // - None of these N bytes come from original. It's basically a special case of
312+ // the second scenario: when M = N i.e all N bytes stay unwritten.
313+ // - that means, destination.len() > original.len()
314+ // - and also, write_index == original.len().
315+ // - the destination has expanded just like the above case.
316+ // - all N bytes are "assumed" to be already zero-initialized (by the caller)
317+
318+ if destination. len ( ) <= original. len ( ) {
319+ // case: all n bytes come from original
320+ let dest_len = destination. len ( ) ;
321+ destination[ write_index..] . copy_from_slice ( & original[ write_index..dest_len] ) ;
322+ dest_len
323+ } else if write_index < original. len ( ) {
324+ // case: some bytes come from original and the rest are "assumed" to be
325+ // zero-initialized (by the caller).
326+ destination[ write_index..original. len ( ) ] . copy_from_slice ( & original[ write_index..] ) ;
327+ original. len ( )
328+ } else {
329+ // case: all N bytes are "assumed" to be zero-initialized (by the caller).
330+ write_index
331+ }
332+ }
333+ Ordering :: Greater => {
334+ // It is an impossible scenario. Even if the diff is corrupt, or the lengths of destinatiare are same
335+ // or different, we'll not encounter this case. It only implies logic error.
336+ return Err ( DlpError :: InfallibleError . into ( ) ) ;
337+ }
338+ } ;
339+
340+ Ok ( num_bytes_written)
262341}
263342
264343// private function that does the actual work.
@@ -297,6 +376,58 @@ mod tests {
297376 ) ;
298377 }
299378
379+ fn get_example_expected_diff (
380+ changed_len : usize ,
381+ // additional_changes must apply after index 78 (index-in-data) !!
382+ additional_changes : Vec < ( u32 , & [ u8 ] ) > ,
383+ ) -> Vec < u8 > {
384+ // expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
385+
386+ let mut expected_diff = vec ! [ ] ;
387+
388+ // changed_len (u32)
389+ expected_diff. extend_from_slice ( & ( changed_len as u32 ) . to_le_bytes ( ) ) ;
390+
391+ if additional_changes. is_empty ( ) {
392+ // 2 (u32)
393+ expected_diff. extend_from_slice ( & 2u32 . to_le_bytes ( ) ) ;
394+ } else {
395+ expected_diff
396+ . extend_from_slice ( & ( 2u32 + additional_changes. len ( ) as u32 ) . to_le_bytes ( ) ) ;
397+ }
398+
399+ // -- offsets
400+
401+ // 0 11 (each u32)
402+ expected_diff. extend_from_slice ( & 0u32 . to_le_bytes ( ) ) ;
403+ expected_diff. extend_from_slice ( & 11u32 . to_le_bytes ( ) ) ;
404+
405+ // 4 71 (each u32)
406+ expected_diff. extend_from_slice ( & 4u32 . to_le_bytes ( ) ) ;
407+ expected_diff. extend_from_slice ( & 71u32 . to_le_bytes ( ) ) ;
408+
409+ let mut offset_in_diff = 12u32 ;
410+ for ( offset_in_data, diff) in additional_changes. iter ( ) {
411+ expected_diff. extend_from_slice ( & offset_in_diff. to_le_bytes ( ) ) ;
412+ expected_diff. extend_from_slice ( & offset_in_data. to_le_bytes ( ) ) ;
413+ offset_in_diff += diff. len ( ) as u32 ;
414+ }
415+
416+ // -- segments --
417+
418+ // 11 12 13 14 (each u8)
419+ expected_diff. extend_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
420+ // 71 72 ... 78 (each u8)
421+ expected_diff. extend_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
422+
423+ // append diff from additional_changes
424+ for ( _, diff) in additional_changes. iter ( ) {
425+ expected_diff. extend_from_slice ( diff) ;
426+ }
427+
428+ expected_diff
429+ }
430+
300431 #[ test]
301432 fn test_using_example_data ( ) {
302433 let original = [ 0 ; 100 ] ;
@@ -311,42 +442,99 @@ mod tests {
311442
312443 let actual_diff = compute_diff ( & original, & changed) ;
313444 let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
314- let expected_diff = {
315- // expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
445+ let expected_diff = get_example_expected_diff ( changed. len ( ) , vec ! [ ] ) ;
316446
317- let mut serialized = vec ! [ ] ;
447+ assert_eq ! ( actual_diff. len( ) , 4 + 4 + 8 + 8 + ( 4 + 8 ) ) ;
448+ assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
318449
319- // 100 (u32)
320- serialized. extend_from_slice ( & ( changed. len ( ) as u32 ) . to_le_bytes ( ) ) ;
450+ let expected_changed = apply_diff_copy ( & original, & actual_diffset) . unwrap ( ) ;
321451
322- // 2 (u32)
323- serialized. extend_from_slice ( & 2u32 . to_le_bytes ( ) ) ;
452+ assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
453+
454+ let expected_changed = {
455+ let mut destination = vec ! [ 255 ; original. len( ) ] ;
456+ merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
457+ destination
458+ } ;
459+
460+ assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
461+ }
324462
325- // 0 11 (each u32)
326- serialized. extend_from_slice ( & 0u32 . to_le_bytes ( ) ) ;
327- serialized. extend_from_slice ( & 11u32 . to_le_bytes ( ) ) ;
463+ #[ test]
464+ fn test_shrunk_account_data ( ) {
465+ // Note that changed_len cannot be lower than 79 because the last "changed" index is
466+ // 78 in the diff.
467+ const CHANGED_LEN : usize = 80 ;
328468
329- // 4 71 (each u32)
330- serialized. extend_from_slice ( & 4u32 . to_le_bytes ( ) ) ;
331- serialized. extend_from_slice ( & 71u32 . to_le_bytes ( ) ) ;
469+ let original = vec ! [ 0 ; 100 ] ;
470+ let changed = {
471+ let mut copy = original. clone ( ) ;
472+ copy. truncate ( CHANGED_LEN ) ;
332473
333- // 11 12 13 14 (each u8)
334- serialized . extend_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
335- // 71 72 ... 78 (each u8)
336- serialized . extend_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
337- serialized
474+ // | 11 | 12 | 13 | 14 |
475+ copy [ 11 ..= 14 ] . copy_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
476+ // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
477+ copy [ 71 ..= 78 ] . copy_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
478+ copy
338479 } ;
339480
481+ let actual_diff = compute_diff ( & original, & changed) ;
482+
483+ let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
484+
485+ let expected_diff = get_example_expected_diff ( CHANGED_LEN , vec ! [ ] ) ;
486+
340487 assert_eq ! ( actual_diff. len( ) , 4 + 4 + 8 + 8 + ( 4 + 8 ) ) ;
341488 assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
342489
343- let expected_changed = apply_diff_copy ( & original, & actual_diffset) . unwrap ( ) ;
490+ let expected_changed = {
491+ let mut destination = vec ! [ 255 ; CHANGED_LEN ] ;
492+ merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
493+ destination
494+ } ;
344495
345496 assert_eq ! ( changed. as_slice( ) , expected_changed. as_slice( ) ) ;
497+ }
498+
499+ #[ test]
500+ fn test_expanded_account_data ( ) {
501+ const CHANGED_LEN : usize = 120 ;
502+
503+ let original = vec ! [ 0 ; 100 ] ;
504+ let changed = {
505+ let mut copy = original. clone ( ) ;
506+ copy. resize ( CHANGED_LEN , 0 ) ; // new bytes are zero-initialized
507+
508+ // | 11 | 12 | 13 | 14 |
509+ copy[ 11 ..=14 ] . copy_from_slice ( & 0x01020304u32 . to_le_bytes ( ) ) ;
510+ // | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
511+ copy[ 71 ..=78 ] . copy_from_slice ( & 0x0102030405060708u64 . to_le_bytes ( ) ) ;
512+ copy
513+ } ;
514+
515+ let actual_diff = compute_diff ( & original, & changed) ;
516+
517+ let actual_diffset = DiffSet :: try_new ( & actual_diff) . unwrap ( ) ;
518+
519+ // When an account expands, the extra bytes at the end become part of the diff, even if
520+ // all of them are zeroes, that is why (100, &[0; 32]) is passed as additional_changes to
521+ // the following function.
522+ //
523+ // TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are
524+ // part of the expansion.
525+ let expected_diff = get_example_expected_diff ( CHANGED_LEN , vec ! [ ( 100 , & [ 0 ; 20 ] ) ] ) ;
526+
527+ assert_eq ! ( actual_diff. len( ) , 4 + 4 + ( 8 + 8 ) + ( 4 + 8 ) + ( 4 + 4 + 20 ) ) ;
528+ assert_eq ! ( actual_diff. as_slice( ) , expected_diff. as_slice( ) ) ;
346529
347530 let expected_changed = {
348- let mut destination = vec ! [ 255 ; original. len( ) ] ;
349- merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
531+ let mut destination = vec ! [ 255 ; CHANGED_LEN ] ;
532+ let written = merge_diff_copy ( & mut destination, & original, & actual_diffset) . unwrap ( ) ;
533+
534+ // TODO (snawaz): written == 120, is because currently the expanded bytes are part of the diff.
535+ // Once compute_diff is optimized further, written must be 100.
536+ assert_eq ! ( written, 120 ) ;
537+
350538 destination
351539 } ;
352540
0 commit comments