Skip to content

Commit 195ce40

Browse files
committed
feat: Handle account shrinking/expansion in merge_diff_copy
1 parent e8d0393 commit 195ce40

File tree

2 files changed

+220
-30
lines changed

2 files changed

+220
-30
lines changed

src/diff/algorithm.rs

Lines changed: 218 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -232,33 +232,112 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result<Vec<u8>
232232
})
233233
}
234234

235-
/// This function constructs destination by merging original with diff such that destination
236-
/// becomes the changed version of the original.
235+
/// Constructs destination by applying the diff to original, such that destination becomes the
236+
/// post-diff state of the original.
237237
///
238238
/// Precondition:
239-
/// - destination.len() == original.len()
239+
/// - destination.len() == diffset.changed_len()
240+
/// - original.len() may differ from destination.len() to allow Solana
241+
/// account resizing (shrink or expand).
242+
/// Assumption:
243+
/// - destination is assumed to be zero-initialized. That automatically holds true for freshly
244+
/// allocated Solana account data. The function does NOT validate this assumption for performance reason.
245+
/// Returns:
246+
/// - Ok(n) where n is number of bytes written.
247+
/// - If n < destination.len(), then the last (destination.len() - n) bytes are not written by this
248+
/// function, and the caller, if they want, could write to these bytes starting with index `n`.
249+
/// - else n == destination.len().
250+
/// Notes:
251+
/// - Unmodified regions are copied directly from original.
252+
/// - Extra trailing bytes from original (in shrink case) are ignored.
253+
/// - In expansion case, bytes covered by diff segments are written from diffset
254+
/// (even beyond `original.len()`); any remaining bytes beyond both the diff coverage
255+
/// and original.len() are unwritten and are assumed to be zero-initialized.
256+
///
240257
pub fn merge_diff_copy(
241258
destination: &mut [u8],
242259
original: &[u8],
243260
diffset: &DiffSet<'_>,
244-
) -> Result<(), ProgramError> {
245-
if destination.len() != original.len() {
261+
) -> Result<usize, ProgramError> {
262+
if destination.len() != diffset.changed_len() {
246263
return Err(DlpError::MergeDiffError.into());
247264
}
265+
248266
let mut write_index = 0;
249267
for item in diffset.iter() {
250268
let (diff_segment, OffsetInData { start, end }) = item?;
269+
251270
if write_index < start {
271+
if start > original.len() {
272+
return Err(DlpError::InvalidDiff.into());
273+
}
252274
// copy the unchanged bytes
253275
destination[write_index..start].copy_from_slice(&original[write_index..start]);
254276
}
277+
255278
destination[start..end].copy_from_slice(diff_segment);
256279
write_index = end;
257280
}
258-
if write_index < original.len() {
259-
destination[write_index..].copy_from_slice(&original[write_index..]);
260-
}
261-
Ok(())
281+
282+
// Ensure we have overwritten all bytes in destination, otherwise "construction" of destination
283+
// will be considered incomplete.
284+
let num_bytes_written = match write_index.cmp(&destination.len()) {
285+
Ordering::Equal => {
286+
// It means the destination is fully constructed.
287+
// Nothing to do here.
288+
289+
// It is possible that destination.len() <= original.len() i.e destination might have shrunk
290+
// in which case we do not care about those bytes of original which are not part of
291+
// destination anymore.
292+
write_index
293+
}
294+
Ordering::Less => {
295+
// destination is NOT fully constructed yet. Few bytes in the destination are still unwritten.
296+
// Let's say the number of these unwritten bytes is: N.
297+
//
298+
// Now how do we construct these N unwritten bytes? We have already processed the
299+
// diffset, so now where could the values for these N bytes come from?
300+
//
301+
// There are 3 scenarios:
302+
// - All N bytes must be copied from remaining region of the original:
303+
// - that means, destination.len() <= original.len()
304+
// - and the destination might have shrunk, in which case we do not care about
305+
// the extra bytes in the original: they're discarded.
306+
// - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are
307+
// "assumed" to be already zero-initialized.
308+
// - that means, destination.len() > original.len()
309+
// - write_index + (N-M) == original.len()
310+
// - and the destination has expanded.
311+
// - None of these N bytes come from original. It's basically a special case of
312+
// the second scenario: when M = N i.e all N bytes stay unwritten.
313+
// - that means, destination.len() > original.len()
314+
// - and also, write_index == original.len().
315+
// - the destination has expanded just like the above case.
316+
// - all N bytes are "assumed" to be already zero-initialized (by the caller)
317+
318+
if destination.len() <= original.len() {
319+
// case: all n bytes come from original
320+
let dest_len = destination.len();
321+
destination[write_index..].copy_from_slice(&original[write_index..dest_len]);
322+
dest_len
323+
} else if write_index < original.len() {
324+
// case: some bytes come from original and the rest are "assumed" to be
325+
// zero-initialized (by the caller).
326+
destination[write_index..original.len()].copy_from_slice(&original[write_index..]);
327+
original.len()
328+
} else {
329+
// case: all N bytes are "assumed" to be zero-initialized (by the caller).
330+
write_index
331+
}
332+
}
333+
Ordering::Greater => {
334+
// It is an impossible scenario. Even if the diff is corrupt, or the lengths of destinatiare are same
335+
// or different, we'll not encounter this case. It only implies logic error.
336+
return Err(DlpError::InfallibleError.into());
337+
}
338+
};
339+
340+
Ok(num_bytes_written)
262341
}
263342

264343
// private function that does the actual work.
@@ -297,6 +376,58 @@ mod tests {
297376
);
298377
}
299378

379+
fn get_example_expected_diff(
380+
changed_len: usize,
381+
// additional_changes must apply after index 78 (index-in-data) !!
382+
additional_changes: Vec<(u32, &[u8])>,
383+
) -> Vec<u8> {
384+
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
385+
386+
let mut expected_diff = vec![];
387+
388+
// changed_len (u32)
389+
expected_diff.extend_from_slice(&(changed_len as u32).to_le_bytes());
390+
391+
if additional_changes.is_empty() {
392+
// 2 (u32)
393+
expected_diff.extend_from_slice(&2u32.to_le_bytes());
394+
} else {
395+
expected_diff
396+
.extend_from_slice(&(2u32 + additional_changes.len() as u32).to_le_bytes());
397+
}
398+
399+
// -- offsets
400+
401+
// 0 11 (each u32)
402+
expected_diff.extend_from_slice(&0u32.to_le_bytes());
403+
expected_diff.extend_from_slice(&11u32.to_le_bytes());
404+
405+
// 4 71 (each u32)
406+
expected_diff.extend_from_slice(&4u32.to_le_bytes());
407+
expected_diff.extend_from_slice(&71u32.to_le_bytes());
408+
409+
let mut offset_in_diff = 12u32;
410+
for (offset_in_data, diff) in additional_changes.iter() {
411+
expected_diff.extend_from_slice(&offset_in_diff.to_le_bytes());
412+
expected_diff.extend_from_slice(&offset_in_data.to_le_bytes());
413+
offset_in_diff += diff.len() as u32;
414+
}
415+
416+
// -- segments --
417+
418+
// 11 12 13 14 (each u8)
419+
expected_diff.extend_from_slice(&0x01020304u32.to_le_bytes());
420+
// 71 72 ... 78 (each u8)
421+
expected_diff.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
422+
423+
// append diff from additional_changes
424+
for (_, diff) in additional_changes.iter() {
425+
expected_diff.extend_from_slice(diff);
426+
}
427+
428+
expected_diff
429+
}
430+
300431
#[test]
301432
fn test_using_example_data() {
302433
let original = [0; 100];
@@ -311,42 +442,99 @@ mod tests {
311442

312443
let actual_diff = compute_diff(&original, &changed);
313444
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
314-
let expected_diff = {
315-
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
445+
let expected_diff = get_example_expected_diff(changed.len(), vec![]);
316446

317-
let mut serialized = vec![];
447+
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
448+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
318449

319-
// 100 (u32)
320-
serialized.extend_from_slice(&(changed.len() as u32).to_le_bytes());
450+
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
321451

322-
// 2 (u32)
323-
serialized.extend_from_slice(&2u32.to_le_bytes());
452+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
453+
454+
let expected_changed = {
455+
let mut destination = vec![255; original.len()];
456+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
457+
destination
458+
};
459+
460+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
461+
}
324462

325-
// 0 11 (each u32)
326-
serialized.extend_from_slice(&0u32.to_le_bytes());
327-
serialized.extend_from_slice(&11u32.to_le_bytes());
463+
#[test]
464+
fn test_shrunk_account_data() {
465+
// Note that changed_len cannot be lower than 79 because the last "changed" index is
466+
// 78 in the diff.
467+
const CHANGED_LEN: usize = 80;
328468

329-
// 4 71 (each u32)
330-
serialized.extend_from_slice(&4u32.to_le_bytes());
331-
serialized.extend_from_slice(&71u32.to_le_bytes());
469+
let original = vec![0; 100];
470+
let changed = {
471+
let mut copy = original.clone();
472+
copy.truncate(CHANGED_LEN);
332473

333-
// 11 12 13 14 (each u8)
334-
serialized.extend_from_slice(&0x01020304u32.to_le_bytes());
335-
// 71 72 ... 78 (each u8)
336-
serialized.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
337-
serialized
474+
// | 11 | 12 | 13 | 14 |
475+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
476+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
477+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
478+
copy
338479
};
339480

481+
let actual_diff = compute_diff(&original, &changed);
482+
483+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
484+
485+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![]);
486+
340487
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
341488
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
342489

343-
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
490+
let expected_changed = {
491+
let mut destination = vec![255; CHANGED_LEN];
492+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
493+
destination
494+
};
344495

345496
assert_eq!(changed.as_slice(), expected_changed.as_slice());
497+
}
498+
499+
#[test]
500+
fn test_expanded_account_data() {
501+
const CHANGED_LEN: usize = 120;
502+
503+
let original = vec![0; 100];
504+
let changed = {
505+
let mut copy = original.clone();
506+
copy.resize(CHANGED_LEN, 0); // new bytes are zero-initialized
507+
508+
// | 11 | 12 | 13 | 14 |
509+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
510+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
511+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
512+
copy
513+
};
514+
515+
let actual_diff = compute_diff(&original, &changed);
516+
517+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
518+
519+
// When an account expands, the extra bytes at the end become part of the diff, even if
520+
// all of them are zeroes, that is why (100, &[0; 32]) is passed as additional_changes to
521+
// the following function.
522+
//
523+
// TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are
524+
// part of the expansion.
525+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![(100, &[0; 20])]);
526+
527+
assert_eq!(actual_diff.len(), 4 + 4 + (8 + 8) + (4 + 8) + (4 + 4 + 20));
528+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
346529

347530
let expected_changed = {
348-
let mut destination = vec![255; original.len()];
349-
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
531+
let mut destination = vec![255; CHANGED_LEN];
532+
let written = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
533+
534+
// TODO (snawaz): written == 120, is because currently the expanded bytes are part of the diff.
535+
// Once compute_diff is optimized further, written must be 100.
536+
assert_eq!(written, 120);
537+
350538
destination
351539
};
352540

src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub enum DlpError {
4141
InvalidDiffAlignment = 16,
4242
#[error("MergeDiff precondition did not meet")]
4343
MergeDiffError = 17,
44+
#[error("An infallible error is encountered possibly due to logic error")]
45+
InfallibleError = 18,
4446
}
4547

4648
impl From<DlpError> for ProgramError {

0 commit comments

Comments
 (0)