diff --git a/src/Core/Database/BaseDatabaseConfig.py b/src/Core/Database/BaseDatabaseConfig.py
index 0ffe5934..3ebdc527 100644
--- a/src/Core/Database/BaseDatabaseConfig.py
+++ b/src/Core/Database/BaseDatabaseConfig.py
@@ -31,7 +31,7 @@ def __init__(self,
             if not isdir(existing_dir):
                 raise ValueError(f"[{self.name}] The given 'existing_dir'={existing_dir} does not exist.")
             if len(existing_dir.split(sep)) > 1 and existing_dir.split(sep)[-1] == 'dataset':
-                existing_dir = join(*existing_dir.split(sep)[:-1])
+                existing_dir = sep.join(existing_dir.split(sep)[:-1])
 
         # Check storage variables
         if mode is not None:
diff --git a/src/Core/Manager/DataManager.py b/src/Core/Manager/DataManager.py
index 602d1369..70f168ad 100644
--- a/src/Core/Manager/DataManager.py
+++ b/src/Core/Manager/DataManager.py
@@ -168,6 +168,12 @@ def get_prediction(self,
         self.pipeline.network_manager.compute_online_prediction(instance_id=instance_id,
                                                                 normalization=self.normalization)
 
+    def set_eval(self):
+        self.database_manager.set_eval()
+
+    def set_train(self):
+        self.database_manager.set_train()
+
     def close(self) -> None:
         """
         Launch the closing procedure of the DataManager.
diff --git a/src/Core/Manager/DatabaseManager.py b/src/Core/Manager/DatabaseManager.py
index 9a96f6b0..0e22191c 100644
--- a/src/Core/Manager/DatabaseManager.py
+++ b/src/Core/Manager/DatabaseManager.py
@@ -92,7 +92,8 @@ def __init__(self,
                 self.create_partition()
             # Complete a Database in a new session --> copy and load the existing directory
             else:
-                copy_dir(src_dir=database_config.existing_dir, dest_dir=session, sub_folders='dataset')
+                copy_dir(src_dir=database_config.existing_dir, dest_dir=session,
+                         sub_folders='dataset')
                 self.load_directory(rename_partitions=True)
         # Complete a Database in the same session --> load the directory
         else:
@@ -111,7 +112,8 @@ def __init__(self,
                 self.create_partition()
             # Complete a Database in a new session --> copy and load the existing directory
             else:
-                copy_dir(src_dir=database_config.existing_dir, dest_dir=session, sub_folders='dataset')
+                copy_dir(src_dir=database_config.existing_dir, dest_dir=session,
+                         sub_folders='dataset')
                 self.load_directory()
         # Complete a Database in the same directory --> load the directory
         else:
@@ -285,8 +287,8 @@ def change_mode(self,
 
         :param mode: Name of the Database mode.
         """
-
-        pass
+        self.mode = mode
+        self.index_samples()
 
 ##########################################################################################
 ##########################################################################################
@@ -416,6 +418,7 @@ def index_samples(self) -> None:
         Create a new indexing list of samples. Samples are identified by [partition_id, line_id].
         """
 
+        self.sample_indices = empty((0, 2), dtype=int)
         # Create the indices for each sample such as [partition_id, line_id]
         for i, nb_sample in enumerate(self.json_content['nb_samples'][self.mode]):
             partition_indices = empty((nb_sample, 2), dtype=int)
@@ -507,7 +510,7 @@ def compute_normalization(self) -> Dict[str, List[float]]:
         for field in self.json_content['data_shape']:
             table_name, field_name = field.split('.')
             fields += [field_name] if table_name == 'Training' else []
-        normalization = {field: [0., 0.] for field in fields}
+        normalization = {field: [0., 1.] for field in fields}
 
         # 2. Compute the mean of samples for each field
         means = {field: [] for field in fields}
@@ -607,6 +610,12 @@ def load_partitions_fields(partition: Database,
 ##########################################################################################
 ##########################################################################################
 
+    def set_eval(self):
+        self.change_mode('validation')
+
+    def set_train(self):
+        self.change_mode('training')
+
     def close(self):
         """
         Launch the closing procedure of the DatabaseManager.
diff --git a/src/Core/Manager/NetworkManager.py b/src/Core/Manager/NetworkManager.py
index d00d995a..5457762d 100644
--- a/src/Core/Manager/NetworkManager.py
+++ b/src/Core/Manager/NetworkManager.py
@@ -186,6 +186,9 @@ def compute_prediction_and_loss(self,
                                                          lines_id=data_lines)
         # Apply normalization and convert to tensor
         for field in batch.keys():
+            # batch can contain dicts if fields refer to joined tables
+            if isinstance(batch[field], dict):
+                batch[field] = batch[field][field]
             batch[field] = array(batch[field])
             if field in normalization:
                 batch[field] = self.normalize_data(data=batch[field],
@@ -228,6 +231,8 @@ def compute_online_prediction(self,
 
         # Apply normalization and convert to tensor
         for field in sample.keys():
+            if isinstance(sample[field], dict):
+                sample[field] = sample[field][field]
             sample[field] = array([sample[field]])
             if field in normalization.keys():
                 sample[field] = self.normalize_data(data=sample[field],
diff --git a/src/Core/Manager/StatsManager.py b/src/Core/Manager/StatsManager.py
index 060c986f..fde0d1de 100644
--- a/src/Core/Manager/StatsManager.py
+++ b/src/Core/Manager/StatsManager.py
@@ -37,8 +37,16 @@ def __init__(self,
 
         # Open Tensorboard
         tb = program.TensorBoard()
-        tb.configure(argv=[None, '--logdir', self.log_dir])
-        url = tb.launch()
+        port = 6006
+        tb.configure(argv=[None, '--logdir', self.log_dir, '--port', str(port)])
+        while port < 7000:
+            try:
+                url = tb.launch()
+                break
+            except Exception:
+                port += 1
+                tb.configure(argv=[None, '--logdir', self.log_dir, '--port', str(port)])
+                continue
         w_open(url)
 
         # Values
diff --git a/src/Core/Network/BaseNetwork.py b/src/Core/Network/BaseNetwork.py
index 32830ded..b4463b0a 100644
--- a/src/Core/Network/BaseNetwork.py
+++ b/src/Core/Network/BaseNetwork.py
@@ -1,6 +1,8 @@
 from typing import Any, Dict
 from numpy import ndarray
 from collections import namedtuple
+from torch import as_tensor
+import torch
 
 
 class BaseNetwork:
@@ -115,7 +117,7 @@ def numpy_to_tensor(self,
         :return: Converted tensor.
         """
 
-        return data.astype(self.config.data_type)
+        return as_tensor(data, dtype=getattr(torch, self.config.data_type)).requires_grad_(grad)
 
     def tensor_to_numpy(self,
                         data: Any) -> ndarray:
         """
 
         :return: Converted array.
""" - return data.astype(self.config.data_type) + return data.detach().cpu().numpy() def __str__(self) -> str: diff --git a/src/Core/Network/BaseNetworkConfig.py b/src/Core/Network/BaseNetworkConfig.py index d47df632..4be3fb5c 100644 --- a/src/Core/Network/BaseNetworkConfig.py +++ b/src/Core/Network/BaseNetworkConfig.py @@ -23,7 +23,11 @@ def __init__(self, lr: Optional[float] = None, require_training_stuff: bool = True, loss: Optional[Any] = None, - optimizer: Optional[Any] = None): + loss_parameters: Type[dict] = None, + optimizer: Optional[Any] = None, + optimizer_parameters: Type[dict] = None, + scheduler_class: Any = None, + scheduler_parameters: dict = None): """ BaseNetworkConfig is a configuration class to parameterize and create BaseNetwork, BaseOptimization and BaseTransformation for the NetworkManager. @@ -88,7 +92,11 @@ def __init__(self, configuration_name='optimization_config', loss=loss, lr=lr, - optimizer=optimizer) + optimizer=optimizer, + optimizer_parameters=optimizer_parameters, + loss_parameters=loss_parameters, + scheduler_class=scheduler_class, + scheduler_parameters=scheduler_parameters) self.training_stuff: bool = (loss is not None) and (optimizer is not None) or (not require_training_stuff) # NetworkManager parameterization diff --git a/src/Core/Pipelines/BaseTraining.py b/src/Core/Pipelines/BaseTraining.py index d67b5c46..47dacd23 100644 --- a/src/Core/Pipelines/BaseTraining.py +++ b/src/Core/Pipelines/BaseTraining.py @@ -143,6 +143,7 @@ def epoch_begin(self) -> None: Called one at the beginning of each epoch. """ + self.set_train() self.batch_id = 0 def batch_condition(self) -> bool: @@ -174,6 +175,23 @@ def optimize(self) -> None: normalization=self.data_manager.normalization, optimize=True) + def execute_validation(self): + self.set_eval() + id_batch = 0 + while id_batch < self.nb_validation_batches: + self.validate() + id_batch += 1 + + def validate(self): + """ + | Pulls data from the manager and run a prediction step. + """ + self.data_manager.get_data(epoch=0) + self.loss_dict = self.network_manager.compute_prediction_and_loss( + data_lines=self.data_manager.data_lines, + normalization=self.data_manager.normalization, + optimize=False) + def batch_count(self) -> None: """ Increment the batch counter. 
@@ -210,6 +228,9 @@ def epoch_end(self) -> None:
         if self.stats_manager is not None:
             self.stats_manager.add_train_epoch_loss(self.loss_dict['loss'], self.epoch_id)
         self.network_manager.save_network()
+        if self.do_validation and self.stats_manager is not None:
+            self.execute_validation()
+            self.stats_manager.add_test_loss(self.loss_dict['loss'], self.epoch_id)
 
     def train_end(self) -> None:
         """
@@ -221,6 +242,26 @@ def train_end(self) -> None:
         if self.stats_manager is not None:
             self.stats_manager.close()
 
+    def set_eval(self):
+        # Set DBManager mode, build indices
+        self.data_manager.set_eval()
+        # Set network to eval mode
+        self.network_manager.set_eval()
+        # Connect the handler to the validation partition
+        self.data_manager.connect_handler(self.network_manager.get_database_handler())
+        # Create the links
+        self.network_manager.link_clients(self.data_manager.nb_environment)
+
+    def set_train(self):
+        # Set DBManager mode, build indices
+        self.data_manager.set_train()
+        # Set network to train mode
+        self.network_manager.set_train()
+        # Connect the handler to the training partition
+        self.data_manager.connect_handler(self.network_manager.get_database_handler())
+        # Create the links
+        self.network_manager.link_clients(self.data_manager.nb_environment)
+
     def save_info_file(self) -> None:
         """
         Save a .txt file that provides a template for user notes and the description of all the components.
diff --git a/src/Core/Utils/yamlUtils.py b/src/Core/Utils/yamlUtils.py
new file mode 100644
index 00000000..c9308578
--- /dev/null
+++ b/src/Core/Utils/yamlUtils.py
@@ -0,0 +1,112 @@
+import copy
+import yaml
+import importlib
+
+
+def BaseYamlExporter(filename: str = None, var_dict: dict = None):
+    """
+    | Exports variables to a yaml file, excluding classes, modules and functions. Additionally, variables whose
+    | name is listed in 'excluded' will not be exported.
+
+    :param str filename: Path to the file in which var_dict will be saved after filtering.
+    :param dict var_dict: Dictionary containing the key:val pairs to be saved. Key is a variable name, val its value.
+    """
+
+    export_dict = copy.deepcopy(var_dict)
+    recursive_convert_type_to_type_str(export_dict)
+    if filename is not None:
+        # Export to yaml file
+        with open(filename, 'w') as f:
+            print(f"[BaseYamlExporter] Saving conf to {filename}.")
+            yaml.dump(export_dict, f)
+    return export_dict
+
+
+def BaseYamlLoader(filename: str):
+    """Loads a yaml file and converts the str representations of types back to types using convert_type_str_to_type."""
+
+    with open(filename, 'r') as f:
+        loaded_dict = yaml.load(f, yaml.Loader)
+    recursive_convert_type_str_to_type(loaded_dict)
+    return loaded_dict
+
+
+def recursive_convert_type_to_type_str(var_container):
+    """Recursively converts types in a nested dict or iterable to their str representation using
+    convert_type_to_type_str."""
+
+    var_container_type = type(var_container)
+    if isinstance(var_container, dict):
+        keys = list(var_container.keys())
+        if 'excluded' in keys:  # Special keyword that specifies which keys should be removed
+            for exclude_key in var_container['excluded']:
+                if exclude_key in var_container: var_container.pop(exclude_key)  # Remove the keys listed in 'excluded'
+            keys = list(var_container.keys())  # Update the keys
+    elif isinstance(var_container, (tuple, list, set)):  # Is not a dict but is iterable.
+        keys = range(len(var_container))
+        var_container = list(var_container)  # Allows changing elements in var_container
+    else:
+        raise ValueError("BaseYamlExporter: encountered an object to convert which is not a dict, tuple or list.")
+    for k in keys:
+        v = var_container[k]
+        if isinstance(v, type):  # Object is just a type, not an instance
+            new_val = convert_type_to_type_str(v)
+            new_val = dict(type=new_val)
+        elif hasattr(v, '__iter__') and not isinstance(v, str):  # Object contains other objects
+            new_val = recursive_convert_type_to_type_str(v)
+        else:  # Object is assumed to not contain other objects
+            new_val = v
+        var_container[k] = new_val
+    if var_container_type in (tuple, set):
+        var_container = var_container_type(var_container)  # Convert back to original type
+    return var_container
+
+
+def recursive_convert_type_str_to_type(var_container):
+    """Recursively converts str representations of types in a nested dict or iterable using convert_type_str_to_type."""
+
+    var_container_type = type(var_container)
+    if isinstance(var_container, dict):
+        keys = list(var_container.keys())
+    elif isinstance(var_container, (tuple, list, set)):  # Is not a dict but is iterable.
+        keys = range(len(var_container))
+        var_container = list(var_container)  # Allows changing elements in var_container
+    else:
+        raise ValueError("recursive_convert_type_str_to_type: "
+                         "encountered an object to convert which is not a dict, tuple or list.")
+    for k in keys:
+        v = var_container[k]
+        # Detection of a type object that was converted to str
+        if isinstance(v, dict) and len(v) == 1 and 'type' in v and isinstance(v['type'], str):
+            new_val = convert_type_str_to_type(v['type'])
+        elif hasattr(v, '__iter__') and not isinstance(v, str):  # Object contains other objects
+            new_val = recursive_convert_type_str_to_type(v)
+        else:
+            new_val = v
+        var_container[k] = new_val
+    if var_container_type in (tuple, set):
+        var_container = var_container_type(var_container)  # Convert back to original type
+    return var_container
+
+
+def convert_type_str_to_type(name: str):
+    """Converts a str representation of a type to a type."""
+
+    module = importlib.import_module('.'.join(name.split('.')[:-1]))
+    object_name_in_module = name.split('.')[-1]
+    return getattr(module, object_name_in_module)
+
+
+def convert_type_to_type_str(type_to_convert: type):
+    """Converts a type to its str representation."""
+
+    repr_str = repr(type_to_convert)
+    if repr_str.__contains__("<class '"):
+        return repr_str.split("<class '")[-1].split("'>")[0]
+    else:
+        raise ValueError(f"BaseYamlExporter: {repr_str} could not be converted to an object name.")
+
+
+def unpack_pipeline_config(pipeline_config):
+    """Initializes the network, environment and database config objects from the pipeline config."""
+
+    nested_configs_keys = ['network_config', 'environment_config', 'database_config']
+    unpacked = {}
+    for key in pipeline_config:
+        if key in nested_configs_keys and pipeline_config[key] is not None:
+            unpacked.update({key: pipeline_config[key][0](**pipeline_config[key][1])})
+        else:
+            unpacked.update({key: pipeline_config[key]})
+    return unpacked
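
Usage sketch for the new yamlUtils helpers (illustrative only, not part of the diff; the torch optimizer class, file name, and dict contents are assumptions): the exporter drops keys listed under 'excluded', serializes type objects as dotted-path strings, and the loader re-imports them on the way back.

    # Hypothetical round-trip of a config dict through YAML.
    import torch
    from src.Core.Utils.yamlUtils import BaseYamlExporter, BaseYamlLoader

    conf = {'lr': 1e-4,
            'optimizer': torch.optim.Adam,    # a type: exported as {'type': 'torch.optim.adam.Adam'}
            'session': 'sessions/demo',       # listed in 'excluded', so dropped on export
            'excluded': ['session']}
    BaseYamlExporter('conf.yml', conf)        # writes the filtered, type-stripped dict to conf.yml
    restored = BaseYamlLoader('conf.yml')     # re-imports the optimizer class from its str form
    assert restored['optimizer'] is torch.optim.Adam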