Commit 33977d3

Merge pull request #399 from ramon-garcia/checkpoint
Add support for saving to and restoring from checkpoints
2 parents b40f8a7 + d04c852 commit 33977d3

File tree

4 files changed: +400 -0 lines changed
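
In rough terms, the new API is used as follows (a minimal sketch distilled from the doc-comment example and the test in src/checkpoint.rs below; `scope`, `w`, and `b` are assumed to be a Scope and two Variables already defined elsewhere, and error handling is elided):

    // Build a CheckpointMaker for the variables we want to persist.
    let mut checkpoint_maker = CheckpointMaker::new(
        scope.new_sub_scope("checkpoint"),
        vec![w.clone(), b.clone()].into_boxed_slice(),
    );
    let session = Session::new(&SessionOptions::new(), &scope.graph())?;
    // ... run some training ...
    checkpoint_maker.save(&session, "data/checkpoint")?;
    // Later, restore into another session created from the same graph.
    let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
    checkpoint_maker.restore(&new_session, "data/checkpoint")?;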

Cargo.toml

+1

@@ -37,6 +37,7 @@ rustversion = "1.0.9"
 [dev-dependencies]
 rand = "0.8.5"
 serial_test = "0.9.0"
+tempdir = "0.3"
 
 [features]
 default = ["tensorflow-sys"]
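
The only manifest change is a new dev-dependency: the tempdir crate is used by the test added in src/checkpoint.rs to write the checkpoint into a throwaway directory, roughly like this (a sketch of the pattern, not part of the diff):

    // TempDir creates a uniquely named directory under the system temp dir
    // and removes it (with its contents) when the value is dropped.
    let temp_dir = tempdir::TempDir::new("test-tensorflow")?;
    let checkpoint_path = temp_dir.path().join("checkpoint-vars");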

src/checkpoint.rs

+376 (new file)

@@ -0,0 +1,376 @@
use crate::option_insert_result::OptionInsertWithResult;
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor, Variable};

#[derive(Debug)]
struct SaveRestoreOps {
    prefix_save: Operation,
    prefix_restore: Operation,
    save_op: Operation,
    restore_op: Operation,
}

/// This struct supports saving and restoring variables using TensorFlow checkpoints in SaveV2 format.
/// First, the user creates a [`CheckpointMaker`] attached to a [`Scope`], indicating the list of variables to be saved/restored.
/// The CheckpointMaker lazily modifies the graph, creating the nodes needed for saving/restoring.
/// When one wants to save from or restore into a session, one calls the save/restore methods.
/// # Example
/// ```
/// let mut scope = Scope::new_root_scope();
/// // add operations to define the graph
/// // ...
/// // let w and b be the variables that we wish to save
/// let mut checkpoint_maker = CheckpointMaker::new(
///     scope.new_sub_scope("checkpoint"),
///     vec![w.clone(), b.clone()].into_boxed_slice(),
/// );
/// let session = Session::new(&SessionOptions::new(), &scope.graph())?;
/// // run some training
/// // ...
/// // to save the training
/// checkpoint_maker.save(&session, "data/checkpoint")?;
/// // then we restore in a different session to continue there
/// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
/// checkpoint_maker.restore(&new_session, "data/checkpoint")?;
/// ```
///
#[derive(Debug)]
pub struct CheckpointMaker {
    scope: Scope,
    variables: Box<[Variable]>,
    save_restore_ops: Option<SaveRestoreOps>,
}

impl CheckpointMaker {
    /// Creates a new CheckpointMaker for a Scope, with a list of variables to save/restore.
    /// The scope is used to modify the graph to add the save and restore ops.
    ///
    /// To provide a scope for the CheckpointMaker one can use scope.new_sub_scope("checkpoint")
    /// so that the nodes are created with scoped names.
    pub fn new(scope: Scope, variables: Box<[Variable]>) -> CheckpointMaker {
        CheckpointMaker {
            scope,
            variables,
            save_restore_ops: None,
        }
    }

    // Add save and restore ops to the graph.
    fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
        let mut all_variable_ops_opt: Option<Vec<Operation>> = None;

        let existing_save_op = self.scope.graph().operation_by_name("save")?;
        let (prefix_save, save_op) = if let Some(op) = existing_save_op {
            let prefix_save_op = self
                .scope
                .graph()
                .operation_by_name_required("prefix_save")?;
            (prefix_save_op, op)
        } else {
            let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
                self.variables
                    .iter()
                    .map(|v| v.output.operation.clone())
                    .collect::<Vec<_>>()
            });
            let prefix_save = ops::Placeholder::new()
                .dtype(crate::DataType::String)
                .build(&mut self.scope.with_op_name("prefix_save"))?;
            let tensor_names = ops::constant(
                self.variables
                    .iter()
                    .map(|v| String::from(v.name()))
                    .collect::<Vec<_>>()
                    .as_slice(),
                &mut self.scope,
            )?;
            let shape_and_slices = ops::constant(
                &self
                    .variables
                    .iter()
                    .map(|_| "".to_string())
                    .collect::<Vec<_>>()[..],
                &mut self.scope,
            )?;
            let tensors = all_variable_ops
                .iter()
                .map(|v| v.output(0).clone())
                .collect::<Vec<_>>();

            let mut g = self.scope.graph_mut();
            let mut nd = g.new_operation("SaveV2", "save")?;
            nd.add_input(prefix_save.clone());
            nd.add_input(tensor_names);
            nd.add_input(shape_and_slices);
            nd.add_input_list(&tensors[..]);

            let dtypes = all_variable_ops
                .iter()
                .map(|v| v.get_attr_type("dtype"))
                .collect::<Result<Vec<_>, Status>>()?;
            nd.set_attr_type_list("dtypes", &dtypes[..])?;
            let save_op = nd.finish()?;
            (prefix_save, save_op)
        };
        let opt_restore_op = self.scope.graph().operation_by_name("restore")?;
        let (prefix_restore, restore_op) = if let Some(op) = opt_restore_op {
            let the_prefix_restore = self
                .scope
                .graph()
                .operation_by_name_required("prefix_restore")?;
            (the_prefix_restore, op)
        } else {
            let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
                self.variables
                    .iter()
                    .map(|v| v.output.operation.clone())
                    .collect::<Vec<_>>()
            });
            let prefix_restore = ops::Placeholder::new()
                .dtype(crate::DataType::String)
                .build(&mut self.scope.with_op_name("prefix_restore"))?;
            let all_var_names = self
                .variables
                .iter()
                .map(|v| v.name.clone())
                .collect::<Vec<_>>();
            let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
            let shape_and_slices = ops::constant(
                &self
                    .variables
                    .iter()
                    .map(|_| "".to_string())
                    .collect::<Vec<_>>()[..],
                &mut self.scope,
            )?;
            let mut g = self.scope.graph_mut();
            let mut nd = g.new_operation("RestoreV2", "restore")?;
            nd.add_input(prefix_restore.clone());
            nd.add_input(tensor_names);
            nd.add_input(shape_and_slices);
            let dtypes = all_variable_ops
                .iter()
                .map(|v| v.get_attr_type("dtype"))
                .collect::<Result<Vec<_>, Status>>()?;
            nd.set_attr_type_list("dtypes", &dtypes[..])?;
            let restore_op = nd.finish()?;
            drop(g);
            let mut restore_var_ops = Vec::<Operation>::new();
            for (i, var) in self.variables.iter().enumerate() {
                let var_op = var.output.operation.clone();
                restore_var_ops.push(ops::assign(
                    var_op,
                    crate::Output {
                        operation: restore_op.clone(),
                        index: i as i32,
                    },
                    &mut self.scope.new_sub_scope(format!("restore{}", i).as_str()),
                )?);
            }
            let mut no_op = ops::NoOp::new();
            for op in restore_var_ops {
                no_op = no_op.add_control_input(op);
            }
            (prefix_restore, no_op.build(&mut self.scope)?)
        };
        Ok(SaveRestoreOps {
            prefix_save,
            prefix_restore,
            save_op,
            restore_op,
        })
    }

    fn get_save_operation(&mut self) -> Result<&SaveRestoreOps, Status> {
        if self.save_restore_ops.is_none() {
            self.save_restore_ops = Some(self.build_save_ops()?);
        }
        let save_r_op_ref = self.save_restore_ops.as_ref();
        // SAFETY: the condition above has ensured that self.save_restore_ops is Some(_)
        let save_r_op = unsafe { save_r_op_ref.unwrap_unchecked() };
        Ok(save_r_op)
    }

    /// Save the variables listed in this CheckpointMaker to the checkpoint with base filename backup_filename_base.
    pub fn save(&mut self, session: &Session, backup_filename_base: &str) -> Result<(), Status> {
        let save_restore_ops = self.get_save_operation()?;
        let prefix_arg = Tensor::from(backup_filename_base.to_string());
        let mut run_args = SessionRunArgs::new();
        run_args.add_feed(&save_restore_ops.prefix_save, 0, &prefix_arg);
        run_args.add_target(&save_restore_ops.save_op);
        session.run(&mut run_args)?;
        Ok(())
    }

    /// Restore into the session the variables listed in this CheckpointMaker from the checkpoint
    /// in path_base.
    pub fn restore(&mut self, session: &Session, path_base: &str) -> Result<(), Status> {
        let save_restore_ops = self.get_save_operation()?;
        let prefix_arg = Tensor::from(path_base.to_string());
        let mut run_args = SessionRunArgs::new();
        run_args.add_feed(&save_restore_ops.prefix_restore, 0, &prefix_arg);
        run_args.add_target(&save_restore_ops.restore_op);
        session.run(&mut run_args)?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use crate::ops::Placeholder;
    use crate::{
        ops, CheckpointMaker, Code, DataType, FetchToken, Operation, Scope, Session,
        SessionOptions, SessionRunArgs, Status, Tensor, Variable,
    };

    fn make_variable(
        scope: &mut Scope,
        name: &str,
        dims: &[u64],
        values: &[f32],
    ) -> Result<Variable, Status> {
        Ok(Variable::builder()
            .const_initial_value(Tensor::new(dims).with_values(values)?)
            .data_type(DataType::Float)
            .build(&mut scope.with_op_name(name))?)
    }

    fn create_assignment(
        var: &Variable,
        scope: &mut Scope,
    ) -> Result<(Operation, Operation), Status> {
        let placeholder = Placeholder::new()
            .dtype(DataType::Float)
            .shape(var.shape.clone())
            .build(&mut scope.with_op_name(var.name.as_str()))?;
        Ok((
            placeholder.clone(),
            ops::assign(var.output.clone(), placeholder, scope)?,
        ))
    }

    struct MyScopeData {
        scope: Scope,
        variables: [Variable; 3],
    }

    // Initialize a scope and place the same variables in it
    fn create_scope() -> Result<MyScopeData, Status> {
        let mut scope = Scope::new_root_scope();
        let var_w = make_variable(&mut scope, "w", &[], &[2.2])?;
        let var_b = make_variable(&mut scope, "b", &[3], &[1.0, 2.0, 4.5])?;
        let var_a = make_variable(&mut scope, "a", &[3, 2], &[1.0, 2.0, 3.3, 7.0, 8.0, 8.5])?;
        Ok(MyScopeData {
            scope,
            variables: [var_w, var_b, var_a],
        })
    }

    struct AssignData {
        pub placeholder_ops: Box<[Operation]>,
        pub assign_op: Operation,
    }
    fn add_assign_op(scope_data: &mut MyScopeData) -> Result<AssignData, Status> {
        let mut placeholder_scope = scope_data.scope.new_sub_scope("placeholder");
        let mut placeholders: Vec<Operation> = Vec::new();
        let mut no_op_bld = ops::NoOp::new();
        for var in scope_data.variables.as_ref() {
            let (placeholder, assign_op) = create_assignment(&var, &mut placeholder_scope)?;
            placeholders.push(placeholder);
            no_op_bld = no_op_bld.add_control_input(assign_op);
        }
        let assign_op = no_op_bld.build(&mut scope_data.scope)?;
        Ok(AssignData {
            placeholder_ops: placeholders.into_boxed_slice(),
            assign_op,
        })
    }

    fn assign_variables(
        session: &Session,
        scope_data: &MyScopeData,
        assign_data: &AssignData,
        values: &[&[f32]],
    ) -> Result<(), Status> {
        let mut values_fed: Vec<Tensor<f32>> =
            Vec::with_capacity(assign_data.placeholder_ops.len());
        let mut session_run = SessionRunArgs::new();
        for i_var in 0..assign_data.placeholder_ops.len() {
            let value_fed_as_tensor = Tensor::new(
                &scope_data.variables[i_var]
                    .shape()
                    .0
                    .as_ref()
                    .ok_or(Status::new_set(Code::Internal, "Shape not present")?)?
                    .iter()
                    .map(|o| {
                        o.map(|i| i as u64)
                            .ok_or(Status::new_set(Code::Internal, "Shape item not present")?)
                    })
                    .collect::<Result<Vec<u64>, Status>>()?
                    .as_ref(),
            )
            .with_values(&values[i_var])?;
            values_fed.push(value_fed_as_tensor);
        }
        for i_var in 0..assign_data.placeholder_ops.len() {
            session_run.add_feed(&assign_data.placeholder_ops[i_var], 0, &values_fed[i_var]);
        }
        session_run.add_target(&assign_data.assign_op);
        session.run(&mut session_run)?;
        Ok(())
    }

    fn check_variables(
        session: &Session,
        variables: &[Variable],
        values: &[&[f32]],
    ) -> Result<(), Status> {
        let mut session_run = SessionRunArgs::new();
        let mut tokens: Vec<FetchToken> = Vec::with_capacity(variables.len());
        for i in 0..variables.len() {
            tokens.push(session_run.request_fetch(
                &variables[i].output().operation,
                variables[i].output().index,
            ));
        }
        session.run(&mut session_run)?;
        for i in 0..variables.len() {
            let got_tensor: Tensor<f32> = session_run.fetch(tokens[i])?;
            assert_eq!(values[i], got_tensor.as_ref());
        }
        Ok(())
    }

    #[test]
    fn simple_save() -> Result<(), Box<dyn std::error::Error>> {
        let mut first_scope_data = create_scope()?;
        let assign_data = add_assign_op(&mut first_scope_data)?;
        let first_session = Session::new(&SessionOptions::new(), &first_scope_data.scope.graph())?;
        let new_values: [&[f32]; 3] = [
            &[5.1],
            &[4.0, 2.2, 6.0],
            &[11.0, 12.0, 13.6, 17.1, 18.4, 19.5],
        ];
        assign_variables(&first_session, &first_scope_data, &assign_data, &new_values)?;
        let mut checkpoint = CheckpointMaker::new(
            first_scope_data.scope.new_sub_scope("checkpoint"),
            Box::from(first_scope_data.variables.clone()),
        );
        let temp_dir = tempdir::TempDir::new("test-tensorflow")?;
        let checkpoint_path = temp_dir.path().join("checkpoint-vars");
        let checkpoint_path_str = checkpoint_path
            .into_os_string()
            .into_string()
            .map_err(|_| "Cannot convert checkpoint path")?;
        checkpoint.save(&first_session, checkpoint_path_str.as_str())?;
        let MyScopeData {
            scope: second_scope,
            variables: second_variables,
        } = create_scope()?;
        let second_session = Session::new(&SessionOptions::new(), &second_scope.graph())?;
        let mut second_checkpoint =
            CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
        second_checkpoint.restore(&second_session, checkpoint_path_str.as_str())?;
        check_variables(&second_session, &second_variables, &new_values)?;
        Ok(())
    }
}
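
The graph nodes built above are TensorFlow's standard SaveV2/RestoreV2 ops, so the string passed to save() and restore() is a path prefix rather than a single file name: assuming the usual single-shard checkpoint layout, saving to <prefix> produces <prefix>.index and <prefix>.data-00000-of-00001. A quick way to inspect this in the test would be something along these lines (a hypothetical snippet, not in the diff):

    // List what the save step wrote into the temporary directory.
    for entry in std::fs::read_dir(temp_dir.path())? {
        println!("{}", entry?.file_name().to_string_lossy());
    }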

src/lib.rs

+5

@@ -207,6 +207,11 @@ pub mod train;
 mod saved_model;
 pub use saved_model::*;
 
+mod checkpoint;
+pub use checkpoint::*;
+
+mod option_insert_result;
+
 #[cfg(feature = "eager")]
 pub mod eager;
 
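
Because lib.rs re-exports the module with `pub use checkpoint::*;`, downstream code can pull the new type straight from the crate root, e.g. (assuming the crate is used under its published name `tensorflow`):

    use tensorflow::CheckpointMaker;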
