Skip to content

Commit 87b7051

Browse files
authored
rust: implement RPC-time downsampling (#4403)
Summary: A new `downsample` function performs TensorBoard-style downsampling of a vector: random subsequence, relative order preserved, last element always kept. The RPC server uses this to honor the downsampling specifications for `ReadScalars` RPCs. Test Plan: Unit tests included for both the utility function and the RPC. wchargin-branch: rust-downsample
1 parent 96654d0 commit 87b7051

File tree

4 files changed

+127
-8
lines changed

4 files changed

+127
-8
lines changed

tensorboard/data/server/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ rust_library(
2929
"lib.rs",
3030
"commit.rs",
3131
"data_compat.rs",
32+
"downsample.rs",
3233
"event_file.rs",
3334
"logdir.rs",
3435
"masked_crc.rs",

tensorboard/data/server/downsample.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
//! Downsampling for in-memory sequences.
17+
18+
use rand::SeedableRng;
19+
use rand_chacha::ChaCha20Rng;
20+
21+
/// Downsamples `xs` in place to contain at most `k` elements, always including the last element.
22+
///
23+
/// If `k == 0`, then `xs` is cleared. If `k >= xs.len()`, then `xs` is returned unchanged.
24+
/// Otherwise, a uniformly random sample of `k - 1` of the first `n - 1` elements of `xs` is chosen
25+
/// and retained, and the final element of `xs` is also retained. The relative order of elements of
26+
/// `xs` is unchanged.
27+
///
28+
/// More declaratively: among all subsequences of `xs` of length `min(k, xs.len())` that include
29+
/// the last element, one is selected uniformly at random, and `xs` is updated in place to
30+
/// represent that subsequence.
31+
///
32+
/// The random number generator is initialized with a fixed seed, so this function is
33+
/// deterministic.
34+
pub fn downsample<T>(xs: &mut Vec<T>, k: usize) {
35+
let n = xs.len();
36+
if k == 0 {
37+
xs.clear();
38+
return;
39+
}
40+
if k >= n {
41+
return;
42+
}
43+
44+
let mut rng = ChaCha20Rng::seed_from_u64(0);
45+
46+
// Choose `k - 1` of the `n - 1` indices to keep, and move their elements into place. Then,
47+
// move the last element into place and drop extra elements.
48+
let mut indices = rand::seq::index::sample(&mut rng, n - 1, k - 1).into_vec();
49+
indices.sort_unstable();
50+
for (dst, src) in indices.into_iter().enumerate() {
51+
xs.swap(dst, src);
52+
}
53+
xs.swap(k - 1, n - 1);
54+
xs.truncate(k);
55+
}
56+
57+
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a copy of `xs` that has been [`downsample`]d to at most `k` elements, leaving the
    /// input untouched.
    fn downsample_cloned<T: Clone>(xs: &[T], k: usize) -> Vec<T> {
        let mut ys = xs.to_vec();
        downsample(&mut ys, k);
        ys
    }

    /// Repeated calls on equal inputs must agree, since the RNG seed is fixed.
    #[test]
    fn test_deterministic() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        let expected = downsample_cloned(&xs, 4);
        assert_eq!(expected.len(), 4);
        for _ in 0..100 {
            assert_eq!(downsample_cloned(&xs, 4), expected);
        }
    }

    /// Asking for more elements than exist is a no-op, even for huge `k`.
    #[test]
    fn test_ok_when_k_greater_than_n() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        assert_eq!(downsample_cloned(&xs, 10), xs);
        assert_eq!(downsample_cloned(&xs, usize::MAX), xs);
    }

    /// Asking for exactly as many elements as exist is also a no-op.
    #[test]
    fn test_k_equal_to_n_is_noop() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        assert_eq!(downsample_cloned(&xs, xs.len()), xs);
    }

    /// The sample has exactly `k` elements, preserves input order, and ends with the last input
    /// element.
    #[test]
    fn test_inorder_plus_last() {
        let input: Vec<u32> = (0..10000).collect();
        let xs = downsample_cloned(&input, 100);
        // Without this check, an implementation that never drops anything would pass the
        // order and last-element assertions below.
        assert_eq!(xs.len(), 100);
        let mut ys = xs.clone();
        ys.sort();
        assert_eq!(xs, ys);
        assert_eq!(xs.last(), Some(&9999));
    }

    /// With `k == 1`, the only element retained is the last one.
    #[test]
    fn test_k_one_keeps_only_last() {
        let xs: Vec<char> = "abcdefg".chars().collect();
        assert_eq!(downsample_cloned(&xs, 1), vec!['g']);
    }

    /// `k == 0` always yields an empty vector, whatever the input length.
    #[test]
    fn test_zero_k() {
        for n in 0..3 {
            let xs: Vec<u32> = (0..n).collect();
            assert_eq!(downsample_cloned(&xs, 0), Vec::<u32>::new());
        }
    }

    /// An empty input stays empty for any `k`.
    #[test]
    fn test_zero_n() {
        let xs: Vec<u32> = vec![];
        for k in 0..3 {
            assert_eq!(downsample_cloned(&xs, k), Vec::<u32>::new());
        }
    }
}

tensorboard/data/server/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919

2020
pub mod commit;
2121
pub mod data_compat;
22+
pub mod downsample;
2223
pub mod event_file;
2324
pub mod logdir;
2425
pub mod masked_crc;

tensorboard/data/server/server.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ use futures_core::Stream;
1717
use std::borrow::Borrow;
1818
use std::collections::HashMap;
1919
use std::collections::HashSet;
20+
use std::convert::TryInto;
2021
use std::hash::Hash;
2122
use std::pin::Pin;
2223
use std::sync::{RwLock, RwLockReadGuard};
2324
use tonic::{Request, Response, Status};
2425

2526
use crate::commit::{self, Commit};
27+
use crate::downsample;
2628
use crate::proto::tensorboard::data;
2729
use crate::types::{Run, Tag, WallTime};
2830
use data::tensor_board_data_provider_server::TensorBoardDataProvider;
@@ -158,7 +160,7 @@ impl TensorBoardDataProvider for DataProviderHandler {
158160
let req = req.into_inner();
159161
let want_plugin = parse_plugin_filter(req.plugin_filter)?;
160162
let (run_filter, tag_filter) = parse_rtf(req.run_tag_filter);
161-
let _downsample = parse_downsample(req.downsample)?; // TODO(@wchargin): Use `downsample`.
163+
let num_points = parse_downsample(req.downsample)?;
162164
let runs = self.read_runs()?;
163165

164166
let mut res: data::ReadScalarsResponse = Default::default();
@@ -183,11 +185,13 @@ impl TensorBoardDataProvider for DataProviderHandler {
183185
continue;
184186
}
185187

186-
let n = ts.valid_values().count();
188+
let mut points = ts.valid_values().collect::<Vec<_>>();
189+
downsample::downsample(&mut points, num_points);
190+
let n = points.len();
187191
let mut steps = Vec::with_capacity(n);
188192
let mut wall_times = Vec::with_capacity(n);
189193
let mut values = Vec::with_capacity(n);
190-
for (step, wall_time, &commit::ScalarValue(value)) in ts.valid_values() {
194+
for (step, wall_time, &commit::ScalarValue(value)) in points {
191195
steps.push(step.into());
192196
wall_times.push(wall_time.into());
193197
values.push(value);
@@ -278,7 +282,7 @@ fn parse_rtf(rtf: Option<data::RunTagFilter>) -> (Filter<Run>, Filter<Tag>) {
278282
}
279283

280284
/// Parses `Downsample.num_points` from a request, failing if it's not given or invalid.
281-
fn parse_downsample(downsample: Option<data::Downsample>) -> Result<i64, Status> {
285+
fn parse_downsample(downsample: Option<data::Downsample>) -> Result<usize, Status> {
282286
let num_points = downsample
283287
.ok_or_else(|| Status::invalid_argument("must specify downsample"))?
284288
.num_points;
@@ -288,7 +292,13 @@ fn parse_downsample(downsample: Option<data::Downsample>) -> Result<i64, Status>
288292
num_points
289293
)));
290294
}
291-
Ok(num_points)
295+
num_points.try_into().map_err(|_| {
296+
Status::out_of_range(format!(
297+
"num_points ({}) is too large for this system; max: {}",
298+
num_points,
299+
usize::MAX
300+
))
301+
})
292302
}
293303

294304
/// A predicate that accepts either all values or just an explicit set of values.
@@ -575,8 +585,6 @@ mod tests {
575585
let map = run_tag_map!(res.runs);
576586
let train_run = &map[&Run("train".to_string())];
577587
let xent_data = &train_run[&Tag("xent".to_string())].data.as_ref().unwrap();
578-
// TODO(@wchargin): Enable once downsampling is implemented.
579-
// assert_eq!(xent_data.value.len(), 0);
580-
let _ = xent_data;
588+
assert_eq!(xent_data.value, Vec::new());
581589
}
582590
}

0 commit comments

Comments
 (0)