|
| 1 | +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +//! Downsampling for in-memory sequences. |
| 17 | +
|
| 18 | +use rand::SeedableRng; |
| 19 | +use rand_chacha::ChaCha20Rng; |
| 20 | + |
| 21 | +/// Downsamples `xs` in place to contain at most `k` elements, always including the last element. |
| 22 | +/// |
| 23 | +/// If `k == 0`, then `xs` is cleared. If `k >= xs.len()`, then `xs` is returned unchanged. |
| 24 | +/// Otherwise, a uniformly random sample of `k - 1` of the first `n - 1` elements of `xs` is chosen |
| 25 | +/// and retained, and the final element of `xs` is also retained. The relative order of elements of |
| 26 | +/// `xs` is unchanged. |
| 27 | +/// |
| 28 | +/// More declaratively: among all subsequences of `xs` of length `min(k, xs.len())` that include |
| 29 | +/// the last element, one is selected uniformly at random, and `xs` is updated in place to |
| 30 | +/// represent that subsequence. |
| 31 | +/// |
| 32 | +/// The random number generator is initialized with a fixed seed, so this function is |
| 33 | +/// deterministic. |
| 34 | +pub fn downsample<T>(xs: &mut Vec<T>, k: usize) { |
| 35 | + let n = xs.len(); |
| 36 | + if k == 0 { |
| 37 | + xs.clear(); |
| 38 | + return; |
| 39 | + } |
| 40 | + if k >= n { |
| 41 | + return; |
| 42 | + } |
| 43 | + |
| 44 | + let mut rng = ChaCha20Rng::seed_from_u64(0); |
| 45 | + |
| 46 | + // Choose `k - 1` of the `n - 1` indices to keep, and move their elements into place. Then, |
| 47 | + // move the last element into place and drop extra elements. |
| 48 | + let mut indices = rand::seq::index::sample(&mut rng, n - 1, k - 1).into_vec(); |
| 49 | + indices.sort_unstable(); |
| 50 | + for (dst, src) in indices.into_iter().enumerate() { |
| 51 | + xs.swap(dst, src); |
| 52 | + } |
| 53 | + xs.swap(k - 1, n - 1); |
| 54 | + xs.truncate(k); |
| 55 | +} |
| 56 | + |
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies `xs` into a fresh vector and [`downsample`]s that copy to `k` elements, leaving
    /// the input untouched.
    fn downsample_cloned<T: Clone>(xs: &[T], k: usize) -> Vec<T> {
        let mut ys = xs.to_vec();
        downsample(&mut ys, k);
        ys
    }

    /// Repeated calls with the same input must produce the same output.
    #[test]
    fn test_deterministic() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        let expected = downsample_cloned(&xs, 4);
        assert_eq!(expected.len(), 4);
        for _ in 0..100 {
            assert_eq!(downsample_cloned(&xs, 4), expected);
        }
    }

    /// Asking for at least as many elements as exist leaves the input unchanged.
    #[test]
    fn test_ok_when_k_greater_than_n() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        assert_eq!(downsample_cloned(&xs, xs.len()), xs);
        assert_eq!(downsample_cloned(&xs, 10), xs);
        assert_eq!(downsample_cloned(&xs, usize::MAX), xs);
    }

    /// The sample has exactly `k` elements, preserves input order without duplicates, and ends
    /// with the last input element.
    #[test]
    fn test_inorder_plus_last() {
        let input: Vec<u32> = (0..10000).collect();
        let xs = downsample_cloned(&input, 100);
        assert_eq!(xs.len(), 100);
        // The input is strictly increasing, so an order-preserving sample of distinct
        // positions must be strictly increasing as well.
        assert!(xs.windows(2).all(|pair| pair[0] < pair[1]));
        assert_eq!(xs.last(), Some(&9999));
    }

    /// With `k == 1` no random picks are made, so only the last element remains.
    #[test]
    fn test_k_one_keeps_only_last() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        assert_eq!(downsample_cloned(&xs, 1), vec!['g']);
    }

    /// `k == 0` empties the vector regardless of its length.
    #[test]
    fn test_zero_k() {
        for n in 0..3 {
            let xs: Vec<u32> = (0..n).collect();
            assert_eq!(downsample_cloned(&xs, 0), Vec::<u32>::new());
        }
    }

    /// An empty input stays empty for any `k`.
    #[test]
    fn test_zero_n() {
        let xs: Vec<u32> = vec![];
        for k in 0..3 {
            assert_eq!(downsample_cloned(&xs, k), Vec::<u32>::new());
        }
    }
}
0 commit comments