use crate::option_insert_result::OptionInsertWithResult;
use crate::{ops, Operation, Scope, Session, SessionRunArgs, Status, Tensor, Variable};

#[derive(Debug)]
struct SaveRestoreOps {
    prefix_save: Operation,
    prefix_restore: Operation,
    save_op: Operation,
    restore_op: Operation,
}

/// This struct supports saving and restoring variables using TensorFlow checkpoints in the
/// SaveV2 format.
///
/// First, the user creates a [`CheckpointMaker`] attached to a [`Scope`], indicating the list of
/// variables to be saved and restored. The `CheckpointMaker` lazily modifies the graph, creating
/// the nodes needed for saving and restoring. When one wants to save from or restore into a
/// session, one calls the corresponding `save` or `restore` method.
///
/// # Example
/// ```ignore
/// let mut scope = Scope::new_root_scope();
/// // Add operations to define the graph.
/// // ...
/// // Let w and b be the variables that we wish to save.
/// let mut checkpoint_maker = CheckpointMaker::new(
///     scope.new_sub_scope("checkpoint"),
///     vec![w.clone(), b.clone()].into_boxed_slice(),
/// );
/// let session = Session::new(&SessionOptions::new(), &scope.graph())?;
/// // Run some training.
/// // ...
/// // To save the training:
/// checkpoint_maker.save(&session, "data/checkpoint")?;
/// // Then we restore in a different session to continue there.
/// let new_session = Session::new(&SessionOptions::new(), &scope.graph())?;
/// checkpoint_maker.restore(&new_session, "data/checkpoint")?;
/// ```
#[derive(Debug)]
pub struct CheckpointMaker {
    scope: Scope,
    variables: Box<[Variable]>,
    save_restore_ops: Option<SaveRestoreOps>,
}

impl CheckpointMaker {
    /// Creates a new `CheckpointMaker` for a [`Scope`], with a list of variables to save and
    /// restore. The scope is used to modify the graph, adding the save and restore ops.
    ///
    /// To give the `CheckpointMaker` its own namespace, one can pass
    /// `scope.new_sub_scope("checkpoint")` so that the added nodes get scoped names.
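    ///
    /// For example (a sketch; `w` and `b` stand for [`Variable`]s already built in this scope):
    ///
    /// ```ignore
    /// let maker = CheckpointMaker::new(
    ///     scope.new_sub_scope("checkpoint"),
    ///     vec![w, b].into_boxed_slice(),
    /// );
    /// ```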
    pub fn new(scope: Scope, variables: Box<[Variable]>) -> CheckpointMaker {
        CheckpointMaker {
            scope,
            variables,
            save_restore_ops: None,
        }
    }

    // Adds the save and restore ops to the graph, reusing any that already exist.
    fn build_save_ops(&mut self) -> Result<SaveRestoreOps, Status> {
        let mut all_variable_ops_opt: Option<Vec<Operation>> = None;

        let existing_save_op = self.scope.graph().operation_by_name("save")?;
        let (prefix_save, save_op) = if let Some(op) = existing_save_op {
            let prefix_save_op = self
                .scope
                .graph()
                .operation_by_name_required("prefix_save")?;
            (prefix_save_op, op)
        } else {
            let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
                self.variables
                    .iter()
                    .map(|v| v.output.operation.clone())
                    .collect::<Vec<_>>()
            });
            let prefix_save = ops::Placeholder::new()
                .dtype(crate::DataType::String)
                .build(&mut self.scope.with_op_name("prefix_save"))?;
            let tensor_names = ops::constant(
                self.variables
                    .iter()
                    .map(|v| String::from(v.name()))
                    .collect::<Vec<_>>()
                    .as_slice(),
                &mut self.scope,
            )?;
            let shape_and_slices = ops::constant(
                &self
                    .variables
                    .iter()
                    .map(|_| "".to_string())
                    .collect::<Vec<_>>()[..],
                &mut self.scope,
            )?;
            let tensors = all_variable_ops
                .iter()
                .map(|v| v.output(0).clone())
                .collect::<Vec<_>>();

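            // SaveV2 takes the checkpoint prefix, the tensor names, the
            // shape-and-slice specs (empty strings mean "save the whole tensor"),
            // and the tensors themselves; its "dtypes" attr must match the tensor
            // list element for element.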
            let mut g = self.scope.graph_mut();
            let mut nd = g.new_operation("SaveV2", "save")?;
            nd.add_input(prefix_save.clone());
            nd.add_input(tensor_names);
            nd.add_input(shape_and_slices);
            nd.add_input_list(&tensors[..]);

            let dtypes = all_variable_ops
                .iter()
                .map(|v| v.get_attr_type("dtype"))
                .collect::<Result<Vec<_>, Status>>()?;
            nd.set_attr_type_list("dtypes", &dtypes[..])?;
            let save_op = nd.finish()?;
            (prefix_save, save_op)
        };
        let opt_restore_op = self.scope.graph().operation_by_name("restore")?;
        let (prefix_restore, restore_op) = if let Some(op) = opt_restore_op {
            let the_prefix_restore = self
                .scope
                .graph()
                .operation_by_name_required("prefix_restore")?;
            (the_prefix_restore, op)
        } else {
            let all_variable_ops = all_variable_ops_opt.get_or_insert_with(|| {
                self.variables
                    .iter()
                    .map(|v| v.output.operation.clone())
                    .collect::<Vec<_>>()
            });
            let prefix_restore = ops::Placeholder::new()
                .dtype(crate::DataType::String)
                .build(&mut self.scope.with_op_name("prefix_restore"))?;
            let all_var_names = self
                .variables
                .iter()
                .map(|v| v.name.clone())
                .collect::<Vec<_>>();
            let tensor_names = ops::constant(&all_var_names[..], &mut self.scope)?;
            let shape_and_slices = ops::constant(
                &self
                    .variables
                    .iter()
                    .map(|_| "".to_string())
                    .collect::<Vec<_>>()[..],
                &mut self.scope,
            )?;
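            // RestoreV2 takes the checkpoint prefix, the tensor names, and the
            // shape-and-slice specs, and yields one output tensor per requested
            // name; the "dtypes" attr declares the expected element types.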
            let mut g = self.scope.graph_mut();
            let mut nd = g.new_operation("RestoreV2", "restore")?;
            nd.add_input(prefix_restore.clone());
            nd.add_input(tensor_names);
            nd.add_input(shape_and_slices);
            let dtypes = all_variable_ops
                .iter()
                .map(|v| v.get_attr_type("dtype"))
                .collect::<Result<Vec<_>, Status>>()?;
            nd.set_attr_type_list("dtypes", &dtypes[..])?;
            let restore_op = nd.finish()?;
            drop(g);
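            // The mutable graph borrow is released above so the scope can be used
            // again. RestoreV2 only produces tensors, so assign each output back
            // into its variable and group the assigns under a single NoOp (via
            // control inputs) to serve as the restore target.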
            let mut restore_var_ops = Vec::<Operation>::new();
            for (i, var) in self.variables.iter().enumerate() {
                let var_op = var.output.operation.clone();
                restore_var_ops.push(ops::assign(
                    var_op,
                    crate::Output {
                        operation: restore_op.clone(),
                        index: i as i32,
                    },
                    &mut self.scope.new_sub_scope(format!("restore{}", i).as_str()),
                )?);
            }
            let mut no_op = ops::NoOp::new();
            for op in restore_var_ops {
                no_op = no_op.add_control_input(op);
            }
            (prefix_restore, no_op.build(&mut self.scope)?)
        };
        Ok(SaveRestoreOps {
            prefix_save,
            prefix_restore,
            save_op,
            restore_op,
        })
    }

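    // Lazily builds the save and restore ops the first time they are needed and
    // caches them for later calls.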
    fn get_save_operation(&mut self) -> Result<&SaveRestoreOps, Status> {
        if self.save_restore_ops.is_none() {
            self.save_restore_ops = Some(self.build_save_ops()?);
        }
        // The branch above guarantees that save_restore_ops is Some, so this cannot panic.
        Ok(self.save_restore_ops.as_ref().unwrap())
    }

    /// Saves the variables listed in this `CheckpointMaker` to the checkpoint whose base
    /// filename is `backup_filename_base`.
    pub fn save(&mut self, session: &Session, backup_filename_base: &str) -> Result<(), Status> {
        let save_restore_ops = self.get_save_operation()?;
        let prefix_arg = Tensor::from(backup_filename_base.to_string());
        let mut run_args = SessionRunArgs::new();
        run_args.add_feed(&save_restore_ops.prefix_save, 0, &prefix_arg);
        run_args.add_target(&save_restore_ops.save_op);
        session.run(&mut run_args)?;
        Ok(())
    }

    /// Restores into the session the variables listed in this `CheckpointMaker` from the
    /// checkpoint with base filename `path_base`.
    pub fn restore(&mut self, session: &Session, path_base: &str) -> Result<(), Status> {
        let save_restore_ops = self.get_save_operation()?;
        let prefix_arg = Tensor::from(path_base.to_string());
        let mut run_args = SessionRunArgs::new();
        run_args.add_feed(&save_restore_ops.prefix_restore, 0, &prefix_arg);
        run_args.add_target(&save_restore_ops.restore_op);
        session.run(&mut run_args)?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use crate::ops::Placeholder;
    use crate::{
        ops, CheckpointMaker, Code, DataType, FetchToken, Operation, Scope, Session,
        SessionOptions, SessionRunArgs, Status, Tensor, Variable,
    };

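    // Helper that builds a float variable with the given shape and constant initial value.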
    fn make_variable(
        scope: &mut Scope,
        name: &str,
        dims: &[u64],
        values: &[f32],
    ) -> Result<Variable, Status> {
        Ok(Variable::builder()
            .const_initial_value(Tensor::new(dims).with_values(values)?)
            .data_type(DataType::Float)
            .build(&mut scope.with_op_name(name))?)
    }

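    // Builds a float placeholder with the variable's shape plus an assign op so the
    // variable can later be overwritten with values fed at run time.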
    fn create_assignment(
        var: &Variable,
        scope: &mut Scope,
    ) -> Result<(Operation, Operation), Status> {
        let placeholder = Placeholder::new()
            .dtype(DataType::Float)
            .shape(var.shape.clone())
            .build(&mut scope.with_op_name(var.name.as_str()))?;
        Ok((
            placeholder.clone(),
            ops::assign(var.output.clone(), placeholder, scope)?,
        ))
    }

    struct MyScopeData {
        scope: Scope,
        variables: [Variable; 3],
    }

    // Initializes a scope and places some variables in it.
    fn create_scope() -> Result<MyScopeData, Status> {
        let mut scope = Scope::new_root_scope();
        let var_w = make_variable(&mut scope, "w", &[], &[2.2])?;
        let var_b = make_variable(&mut scope, "b", &[3], &[1.0, 2.0, 4.5])?;
        let var_a = make_variable(&mut scope, "a", &[3, 2], &[1.0, 2.0, 3.3, 7.0, 8.0, 8.5])?;
        Ok(MyScopeData {
            scope,
            variables: [var_w, var_b, var_a],
        })
    }

    struct AssignData {
        pub placeholder_ops: Box<[Operation]>,
        pub assign_op: Operation,
    }
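
    // Creates one placeholder per variable and a single NoOp that, through control
    // inputs, runs all of the per-variable assign ops at once.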
    fn add_assign_op(scope_data: &mut MyScopeData) -> Result<AssignData, Status> {
        let mut placeholder_scope = scope_data.scope.new_sub_scope("placeholder");
        let mut placeholders: Vec<Operation> = Vec::new();
        let mut no_op_bld = ops::NoOp::new();
        for var in scope_data.variables.as_ref() {
            let (placeholder, assign_op) = create_assignment(var, &mut placeholder_scope)?;
            placeholders.push(placeholder);
            no_op_bld = no_op_bld.add_control_input(assign_op);
        }
        let assign_op = no_op_bld.build(&mut scope_data.scope)?;
        Ok(AssignData {
            placeholder_ops: placeholders.into_boxed_slice(),
            assign_op,
        })
    }

    fn assign_variables(
        session: &Session,
        scope_data: &MyScopeData,
        assign_data: &AssignData,
        values: &[&[f32]],
    ) -> Result<(), Status> {
        let mut values_fed: Vec<Tensor<f32>> =
            Vec::with_capacity(assign_data.placeholder_ops.len());
        let mut session_run = SessionRunArgs::new();
        for i_var in 0..assign_data.placeholder_ops.len() {
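            // Convert the variable's statically known shape into the u64 dimensions
            // that Tensor::new expects, erroring out if any dimension is unknown.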
            let value_fed_as_tensor = Tensor::new(
                &scope_data.variables[i_var]
                    .shape()
                    .0
                    .as_ref()
                    .ok_or(Status::new_set(Code::Internal, "Shape not present")?)?
                    .iter()
                    .map(|o| {
                        o.map(|i| i as u64)
                            .ok_or(Status::new_set(Code::Internal, "Shape item not present")?)
                    })
                    .collect::<Result<Vec<u64>, Status>>()?
                    .as_ref(),
            )
            .with_values(&values[i_var])?;
            values_fed.push(value_fed_as_tensor);
        }
        for i_var in 0..assign_data.placeholder_ops.len() {
            session_run.add_feed(&assign_data.placeholder_ops[i_var], 0, &values_fed[i_var]);
        }
        session_run.add_target(&assign_data.assign_op);
        session.run(&mut session_run)?;
        Ok(())
    }

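    // Fetches each variable from the session and asserts that it holds the expected values.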
    fn check_variables(
        session: &Session,
        variables: &[Variable],
        values: &[&[f32]],
    ) -> Result<(), Status> {
        let mut session_run = SessionRunArgs::new();
        let mut tokens: Vec<FetchToken> = Vec::with_capacity(variables.len());
        for i in 0..variables.len() {
            tokens.push(session_run.request_fetch(
                &variables[i].output().operation,
                variables[i].output().index,
            ));
        }
        session.run(&mut session_run)?;
        for i in 0..variables.len() {
            let got_tensor: Tensor<f32> = session_run.fetch(tokens[i])?;
            assert_eq!(values[i], got_tensor.as_ref());
        }
        Ok(())
    }

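    // End-to-end test: assign fresh values to the variables in one session and save a
    // checkpoint, then restore it into a second session built from a new graph and
    // check that the restored variables hold the saved values.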
    #[test]
    fn simple_save() -> Result<(), Box<dyn std::error::Error>> {
        let mut first_scope_data = create_scope()?;
        let assign_data = add_assign_op(&mut first_scope_data)?;
        let first_session = Session::new(&SessionOptions::new(), &first_scope_data.scope.graph())?;
        let new_values: [&[f32]; 3] = [
            &[5.1],
            &[4.0, 2.2, 6.0],
            &[11.0, 12.0, 13.6, 17.1, 18.4, 19.5],
        ];
        assign_variables(&first_session, &first_scope_data, &assign_data, &new_values)?;
        let mut checkpoint = CheckpointMaker::new(
            first_scope_data.scope.new_sub_scope("checkpoint"),
            Box::from(first_scope_data.variables.clone()),
        );
        let temp_dir = tempdir::TempDir::new("test-tensorflow")?;
        let checkpoint_path = temp_dir.path().join("checkpoint-vars");
        let checkpoint_path_str = checkpoint_path
            .into_os_string()
            .into_string()
            .map_err(|_| "Cannot convert checkpoint path")?;
        checkpoint.save(&first_session, checkpoint_path_str.as_str())?;
        let MyScopeData {
            scope: second_scope,
            variables: second_variables,
        } = create_scope()?;
        let second_session = Session::new(&SessionOptions::new(), &second_scope.graph())?;
        let mut second_checkpoint =
            CheckpointMaker::new(second_scope, Box::new(second_variables.clone()));
        second_checkpoint.restore(&second_session, checkpoint_path_str.as_str())?;
        check_variables(&second_session, &second_variables, &new_values)?;
        Ok(())
    }
}