diff --git a/examples/features/reproducibleTraining.py b/examples/features/reproducibleTraining.py
new file mode 100644
index 00000000..d32b7e48
--- /dev/null
+++ b/examples/features/reproducibleTraining.py
@@ -0,0 +1,67 @@
+"""
+reproducibleTraining.py
+Run the pipeline BaseTrainer to create a training session with an existing Dataset,
+then export the pipeline configuration to yaml and reload it to check reproducibility.
+"""
+
+# Python related imports
+import os
+
+from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig
+from DeepPhysX.Core.Utils.yamlUtils import unpack_pipeline_config, BaseYamlExporter, BaseYamlLoader
+from DeepPhysX.Core.Visualizer.VedoVisualizer import VedoVisualizer
+from torch.nn import MSELoss
+from torch.optim import Adam
+
+# DeepPhysX related imports
+from DeepPhysX.Core.Pipelines.BaseTrainer import BaseTrainer
+from DeepPhysX.Core.Dataset.BaseDatasetConfig import BaseDatasetConfig
+from DeepPhysX.Torch.FC.FCConfig import FCConfig
+from Environment import MeanEnvironment
+
+
+def set_pipeline_config():
+    """Build the serializable pipeline config: each nested config is a (config_class, kwargs) tuple."""
+    # Some int parameters
+    nb_points = 30
+    dimension = 3
+    # The config object and its keyword arguments are put in a tuple
+    environment_config = (BaseEnvironmentConfig, dict(environment_class=MeanEnvironment,
+                                                      visualizer=VedoVisualizer,
+                                                      param_dict={'constant': False,
+                                                                  'data_size': [nb_points, dimension],
+                                                                  'sleep': False,
+                                                                  'allow_requests': True},
+                                                      as_tcp_ip_client=True,
+                                                      number_of_thread=10))
+    network_config = (FCConfig, dict(loss=MSELoss,
+                                     lr=1e-3,
+                                     optimizer=Adam,
+                                     dim_layers=[nb_points * dimension, nb_points * dimension, dimension],
+                                     dim_output=dimension))
+    dataset_config = (BaseDatasetConfig, dict(shuffle_dataset=True,
+                                              normalize=False))
+
+    # This is the config that will be passed to the Trainer
+    # The config will be instantiated using unpack_pipeline_config
+    pipeline_config = dict(session_dir='sessions',
+                           session_name='offline_training',
+                           environment_config=environment_config,
+                           dataset_config=dataset_config,
+                           network_config=network_config,
+                           nb_epochs=1,
+                           nb_batches=500,
+                           batch_size=10)
+    return pipeline_config
+
+
+if __name__ == '__main__':
+    pipeline_config = set_pipeline_config()
+    # Initialize the network config and dataset config then initialize the Trainer
+    pipeline = BaseTrainer(**unpack_pipeline_config(pipeline_config))
+    # Save the Pipeline config in the current session dir
+    config_export_path = os.path.join(pipeline.manager.session_dir, 'conf.yml')
+    BaseYamlExporter(config_export_path, pipeline_config)
+    # Loads the pipeline config from the file
+    loaded_pipeline_config = BaseYamlLoader(config_export_path)
+    if loaded_pipeline_config == pipeline_config:
+        print("The loaded config is equal to the original config.")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 859b901f..1398651e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ SimulationSimpleDatabase
 tensorboard
 tensorboardX
 pyDataverse
-torch
\ No newline at end of file
+torch
+PyYAML
\ No newline at end of file
diff --git a/src/Utils/yamlUtils.py b/src/Utils/yamlUtils.py
new file mode 100644
index 00000000..22c3f022
--- /dev/null
+++ b/src/Utils/yamlUtils.py
@@ -0,0 +1,112 @@
+import copy
+import yaml
+import importlib
+
+
+def BaseYamlExporter(filename: str = None, var_dict: dict = None):
+    """
+    | Exports variables in a yaml file, excluding classes, modules and functions. Additionally, variables with a name in
+    | excluded will not be exported.
+
+    :param str filename: Path to the file in which var_dict will be saved after filtering. If None, nothing is written
+                         and only the filtered dict is returned.
+    :param dict var_dict: Dictionary containing the key:val pairs to be saved. Key is a variable name and val its value
+    """
+    # Deep-copy so the caller's dict is never mutated by the type -> str conversion
+    export_dict = copy.deepcopy(var_dict)
+    recursive_convert_type_to_type_str(export_dict)
+    if filename is not None:
+        # Export to yaml file
+        with open(filename, 'w') as f:
+            print(f"[BaseYamlExporter] Saving conf to {filename}.")
+            yaml.dump(export_dict, f)
+    return export_dict
+
+
+def BaseYamlLoader(filename: str):
+    """Loads a yaml file and converts the str representation of types to types using convert_type_str_to_type."""
+    with open(filename, 'r') as f:
+        loaded_dict = yaml.load(f, yaml.Loader)
+    recursive_convert_type_str_to_type(loaded_dict)
+    return loaded_dict
+
+
+def recursive_convert_type_to_type_str(var_container):
+    """Recursively converts types in a nested dict or iterable to their str representation using
+    convert_type_to_type_str. Dicts are modified in place; tuples/sets are rebuilt and returned."""
+    var_container_type = type(var_container)
+    if isinstance(var_container, dict):
+        keys = list(var_container.keys())
+        if 'excluded' in keys:  # Special keyword that specifies which keys should be removed
+            for exclude_key in var_container['excluded']:
+                if exclude_key in var_container: var_container.pop(exclude_key)  # Remove the keys listed in excluded
+            keys = list(var_container.keys())  # Update the keys
+    elif isinstance(var_container, (tuple, list, set)):  # Is not a dict but is iterable
+        keys = range(len(var_container))
+        var_container = list(var_container)  # Allows to change elements in var_container
+    else:
+        raise ValueError(f"BaseYamlExporter: encountered an object to convert which is not a dict, tuple or list.")
+    for k in keys:
+        v = var_container[k]
+        if isinstance(v, type):  # Object is just a type, not an instance
+            new_val = convert_type_to_type_str(v)
+            # Wrap in a one-key dict so the loader can recognize and reverse the conversion
+            new_val = dict(type=new_val)
+        elif hasattr(v, '__iter__') and not isinstance(v, str):  # Object contains other objects
+            new_val = recursive_convert_type_to_type_str(v)
+        else:  # Object is assumed to not contain other objects
+            new_val = v
+        var_container[k] = new_val
+    if var_container_type in (tuple, set):
+        var_container = var_container_type(var_container)  # Convert back to original type
+    return var_container
+
+
+def recursive_convert_type_str_to_type(var_container):
+    """Recursively converts str representation of types in a nested dict or iterable using convert_type_str_to_type.
+    Dicts are modified in place; tuples/sets are rebuilt and returned."""
+    var_container_type = type(var_container)
+    if isinstance(var_container, dict):
+        keys = list(var_container.keys())
+    elif isinstance(var_container, (tuple, list, set)):  # Is not a dict but is iterable
+        keys = range(len(var_container))
+        var_container = list(var_container)  # Allows to change elements in var_container
+    else:
+        raise ValueError(f"recursive_convert_type_str_to_type: "
+                         f"encountered an object to convert which is not a dict, tuple or list.")
+    for k in keys:
+        v = var_container[k]
+        # Detection of a type object that was converted to str ({'type': 'module.Name'})
+        if isinstance(v, dict) and len(v) == 1 and 'type' in v and isinstance(v['type'], str):
+            new_val = convert_type_str_to_type(v['type'])
+        elif hasattr(v, '__iter__') and not isinstance(v, str):  # Object contains other objects
+            new_val = recursive_convert_type_str_to_type(v)
+        else:
+            new_val = v
+        var_container[k] = new_val
+    if var_container_type in (tuple, set):
+        var_container = var_container_type(var_container)  # Convert back to original type
+    return var_container
+
+
+def convert_type_str_to_type(name: str):
+    """Converts a str representation of a type (dotted path 'module.Name') back to the type."""
+    module = importlib.import_module('.'.join(name.split('.')[:-1]))
+    object_name_in_module = name.split('.')[-1]
+    return getattr(module, object_name_in_module)
+
+
+def convert_type_to_type_str(type_to_convert: type):
+    """Converts a type to its str representation (dotted path 'module.Name')."""
+    repr_str = repr(type_to_convert)
+    # repr of a class looks like "<class 'module.Name'>": the dotted path is the first quoted token
+    if repr_str.__contains__("<class '"):
+        return repr_str.split("'")[1]
+    else:
+        raise ValueError(f"BaseYamlExporter: {repr_str} could not be converted to an object name.")
+
+
+def unpack_pipeline_config(pipeline_config):
+    """Initializes the network, environment and dataset config objects from the pipeline config:
+    each (config_class, kwargs) tuple is instantiated, every other entry is passed through unchanged."""
+    nested_configs_keys = ['network_config', 'environment_config', 'dataset_config']
+    unpacked = {}
+    for key in pipeline_config:
+        if key in nested_configs_keys and pipeline_config[key] is not None:
+            unpacked.update({key: pipeline_config[key][0](**pipeline_config[key][1])})
+        else:
+            unpacked.update({key: pipeline_config[key]})
+    return unpacked