Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/features/reproducibleTraining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
reproducibleTraining.py
Run the pipeline BaseTrainer to create a training session with an existing Dataset.
"""

# Python related imports
import os

from DeepPhysX.Core.Environment.BaseEnvironmentConfig import BaseEnvironmentConfig
from DeepPhysX.Core.Utils.yamlUtils import unpack_pipeline_config, BaseYamlExporter, BaseYamlLoader
from DeepPhysX.Core.Visualizer.VedoVisualizer import VedoVisualizer
from torch.nn import MSELoss
from torch.optim import Adam

# DeepPhysX related imports
from DeepPhysX.Core.Pipelines.BaseTrainer import BaseTrainer
from DeepPhysX.Core.Dataset.BaseDatasetConfig import BaseDatasetConfig
from DeepPhysX.Torch.FC.FCConfig import FCConfig
from Environment import MeanEnvironment


def set_pipeline_config():
    """Build the pipeline configuration for a reproducible training session.

    Each nested config is stored as a ``(config_class, kwargs)`` tuple so the
    whole structure stays serializable by BaseYamlExporter; the classes are
    only instantiated later by ``unpack_pipeline_config``.

    :return: Dict of keyword arguments for BaseTrainer (after unpacking).
    """
    # Problem size: number of points and their spatial dimension.
    nb_points = 30
    dimension = 3

    # Environment config kept as (class, kwargs) so it can be exported to YAML.
    environment_config = (BaseEnvironmentConfig,
                          dict(environment_class=MeanEnvironment,
                               visualizer=VedoVisualizer,
                               param_dict={'constant': False,
                                           'data_size': [nb_points, dimension],
                                           'sleep': False,
                                           'allow_requests': True},
                               as_tcp_ip_client=True,
                               number_of_thread=10))

    # Fully connected network: loss, optimizer and layer dimensions.
    network_config = (FCConfig,
                      dict(loss=MSELoss,
                           lr=1e-3,
                           optimizer=Adam,
                           dim_layers=[nb_points * dimension, nb_points * dimension, dimension],
                           dim_output=dimension))

    # Dataset: shuffled, no normalization.
    dataset_config = (BaseDatasetConfig,
                      dict(shuffle_dataset=True,
                           normalize=False))

    # This is the config that will be passed to the Trainer.
    # The nested configs will be instantiated using unpack_pipeline_config.
    pipeline_config = dict(session_dir='sessions',
                           session_name='offline_training',
                           environment_config=environment_config,
                           dataset_config=dataset_config,
                           network_config=network_config,
                           nb_epochs=1,
                           nb_batches=500,
                           batch_size=10)
    return pipeline_config


# Backward-compatible alias: the original name contained a typo ("pipepline")
# and may still be referenced by existing callers.
set_pipepline_config = set_pipeline_config


if __name__ == '__main__':
    pipeline_config = set_pipepline_config()
    # Instantiate the nested (class, kwargs) configs, then build the Trainer.
    pipeline = BaseTrainer(**unpack_pipeline_config(pipeline_config))
    # Save the pipeline config in the current session dir so the run can be
    # reproduced later from the exported file.
    config_export_path = os.path.join(pipeline.manager.session_dir, 'conf.yml')
    BaseYamlExporter(config_export_path, pipeline_config)
    # Reload the config and check that the export/import round-trip is lossless.
    loaded_pipeline_config = BaseYamlLoader(config_export_path)
    if loaded_pipeline_config == pipeline_config:
        print("The loaded config is equal to the original config.")
    else:
        # Surface the mismatch instead of failing silently.
        print("WARNING: the loaded config differs from the original config.")
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ SimulationSimpleDatabase
tensorboard
tensorboardX
pyDataverse
torch
torch
pyyaml
112 changes: 112 additions & 0 deletions src/Utils/yamlUtils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import copy
import yaml
import importlib


def BaseYamlExporter(filename: str = None, var_dict: dict = None):
    """Export variables to a yaml file, with type objects converted to their
    string representation first.

    Keys listed under the special 'excluded' key of a (sub-)dict are removed
    before export (see recursive_convert_type_to_type_str).

    :param str filename: Path of the yaml file to write. If None, nothing is
                         written and only the converted dict is returned.
    :param dict var_dict: Dictionary containing the key:val pairs to be saved.
                          Key is a variable name and val its value.
    :return: The filtered and converted copy of var_dict.
    """
    # Work on a deep copy so the caller's dict is left untouched.
    export_dict = copy.deepcopy(var_dict)
    # export_dict is a dict, so it is converted in place.
    recursive_convert_type_to_type_str(export_dict)
    if filename is not None:
        # Export to yaml file.
        with open(filename, 'w') as f:
            # Fixed: the message previously printed "(unknown)" instead of
            # interpolating the destination filename.
            print(f"[BaseYamlExporter] Saving conf to {filename}.")
            yaml.dump(export_dict, f)
    return export_dict


def BaseYamlLoader(filename: str):
    """Load a yaml file and convert the str representations of types back to
    type objects using recursive_convert_type_str_to_type.

    :param str filename: Path to the yaml file to read.
    :return: The loaded dict with {'type': 'module.Name'} markers replaced by
             the corresponding type objects.
    """
    with open(filename, 'r') as f:
        # SECURITY NOTE(review): yaml.Loader can construct arbitrary Python
        # objects. Only load files produced by BaseYamlExporter or other
        # trusted sources, or switch to yaml.SafeLoader if the exported dicts
        # only contain plain data.
        loaded_dict = yaml.load(f, yaml.Loader)
    # loaded_dict is a dict, so it is converted in place.
    recursive_convert_type_str_to_type(loaded_dict)
    return loaded_dict


def recursive_convert_type_to_type_str(var_container):
    """Recursively convert type objects in a nested dict/tuple/list/set to
    their dotted-string representation, wrapped as {'type': 'module.Name'}.

    Dicts are converted in place; tuples/lists/sets are rebuilt, so callers
    must use the return value for non-dict containers.

    :param var_container: dict, tuple, list or set to convert.
    :return: The converted container.
    :raises ValueError: If var_container is not a supported container type.
    """
    var_container_type = type(var_container)
    if isinstance(var_container, dict):
        keys = list(var_container.keys())
        # Special keyword: 'excluded' lists keys that must not be exported.
        # NOTE(review): the 'excluded' key itself is kept in the output.
        if 'excluded' in keys:
            for exclude_key in var_container['excluded']:
                if exclude_key in var_container:
                    var_container.pop(exclude_key)  # Remove the keys listed in excluded
            keys = list(var_container.keys())  # Update the keys after removal
    elif isinstance(var_container, (tuple, list, set)):
        # Not a dict but iterable: work on a mutable list copy.
        keys = range(len(var_container))
        var_container = list(var_container)
    else:
        # Fixed: the message previously named the wrong function, omitted set
        # and did not report the offending type.
        raise ValueError(f"recursive_convert_type_to_type_str: unsupported container "
                         f"type {var_container_type}; expected dict, tuple, list or set.")
    for k in keys:
        v = var_container[k]
        if isinstance(v, type):
            # v is a class object (not an instance): store as a {'type': str} marker.
            new_val = dict(type=convert_type_to_type_str(v))
        elif hasattr(v, '__iter__') and not isinstance(v, str):
            # Nested container: convert recursively.
            new_val = recursive_convert_type_to_type_str(v)
        else:
            # Plain value, assumed to not contain other objects.
            new_val = v
        var_container[k] = new_val
    if var_container_type in (tuple, set):
        var_container = var_container_type(var_container)  # Convert back to original type
    return var_container


def recursive_convert_type_str_to_type(var_container):
    """Recursively convert {'type': 'module.Name'} markers in a nested
    dict/tuple/list/set back to actual type objects.

    Dicts are converted in place; tuples/lists/sets are rebuilt, so callers
    must use the return value for non-dict containers.

    :param var_container: dict, tuple, list or set to convert.
    :return: The converted container.
    :raises ValueError: If var_container is not a supported container type.
    """
    var_container_type = type(var_container)
    if isinstance(var_container, dict):
        keys = list(var_container.keys())
    elif isinstance(var_container, (tuple, list, set)):
        # Not a dict but iterable: work on a mutable list copy.
        keys = range(len(var_container))
        var_container = list(var_container)
    else:
        # Fixed: the message previously was an f-string with no placeholders,
        # omitted set and did not report the offending type.
        raise ValueError(f"recursive_convert_type_str_to_type: unsupported container "
                         f"type {var_container_type}; expected dict, tuple, list or set.")
    for k in keys:
        v = var_container[k]
        if isinstance(v, dict) and len(v) == 1 and 'type' in v and isinstance(v['type'], str):
            # {'type': 'module.Name'} marker produced by the exporter.
            new_val = convert_type_str_to_type(v['type'])
        elif hasattr(v, '__iter__') and not isinstance(v, str):
            # Nested container: convert recursively.
            new_val = recursive_convert_type_str_to_type(v)
        else:
            new_val = v
        var_container[k] = new_val
    if var_container_type in (tuple, set):
        var_container = var_container_type(var_container)  # Convert back to original type
    return var_container


def convert_type_str_to_type(name: str):
    """Convert a dotted string representation of a type back to the type.

    :param str name: Fully qualified name, e.g. 'torch.nn.modules.loss.MSELoss'.
    :return: The type object resolved by importing its module.
    :raises ValueError: If name contains no module prefix.
    """
    # Single rpartition instead of splitting the name twice.
    module_name, _, object_name = name.rpartition('.')
    if not module_name:
        # A bare name cannot be imported; builtins round-trip as
        # 'builtins.<name>' (see convert_type_to_type_str).
        raise ValueError(f"convert_type_str_to_type: '{name}' has no module prefix.")
    module = importlib.import_module(module_name)
    return getattr(module, object_name)


def convert_type_to_type_str(type_to_convert: type):
    """Convert a type to its fully qualified dotted string representation.

    :param type type_to_convert: The class object to convert.
    :return: A 'module.QualName' string, e.g. 'collections.OrderedDict'.
    :raises ValueError: If the repr is not of the form "<class '...'>".
    """
    repr_str = repr(type_to_convert)
    if "<class " not in repr_str:  # Class object required, not an instance
        raise ValueError(f"BaseYamlExporter: {repr_str} could not be converted to an object name.")
    name = repr_str.split("<class '")[1].split("'>")[0]
    # Fixed: builtins repr as "<class 'int'>" with no module prefix, which
    # convert_type_str_to_type could not re-import; qualify them explicitly.
    if '.' not in name:
        name = f'builtins.{name}'
    return name


def unpack_pipeline_config(pipeline_config):
    """Instantiate the nested (config_class, kwargs) entries of a pipeline
    config, leaving every other entry untouched.

    :param dict pipeline_config: Pipeline config whose 'network_config',
                                 'environment_config' and 'dataset_config'
                                 entries may be (class, kwargs) tuples.
    :return: A new dict with the nested config objects instantiated.
    """
    nested_configs_keys = ['network_config', 'environment_config', 'dataset_config']

    def build(key, value):
        # A nested config is stored as (config_class, kwargs): instantiate it.
        if key in nested_configs_keys and value is not None:
            config_class, config_kwargs = value
            return config_class(**config_kwargs)
        return value

    return {key: build(key, value) for key, value in pipeline_config.items()}