diff --git a/cpp/memilio/mobility/graph_simulation.h b/cpp/memilio/mobility/graph_simulation.h
index 27f3942689..94000937f3 100644
--- a/cpp/memilio/mobility/graph_simulation.h
+++ b/cpp/memilio/mobility/graph_simulation.h
@@ -22,6 +22,7 @@
 #include "memilio/mobility/graph.h"
 #include "memilio/utils/random_number_generator.h"
+#include <chrono>
 
 namespace mio
 {
@@ -60,7 +61,9 @@ class GraphSimulationBase
     void advance(double t_max = 1.0)
     {
-        auto dt = m_dt;
+        auto dt = m_dt;
+        auto start_time = std::chrono::high_resolution_clock::now(); // record the start time
+
         while (m_t < t_max) {
             if (m_t + dt > t_max) {
                 dt = t_max - m_t;
             }
@@ -77,6 +80,12 @@ class GraphSimulationBase
                                   m_graph.nodes()[e.end_node_idx].property);
             }
         }
+
+        auto end_time = std::chrono::high_resolution_clock::now(); // record the end time
+        std::chrono::duration<double> execution_time = end_time - start_time; // compute the elapsed wall-clock time
+
+        std::cout << "t = " << m_t << " execution time (Graph Simulation): " << execution_time.count() << " sec"
+                  << std::endl;
     }
 
     double get_t() const
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py
new file mode 100644
index 0000000000..0f66e938b6
--- /dev/null
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/GNN_utils.py
@@ -0,0 +1,166 @@
+import numpy as np
+import pandas as pd
+import os
+from sklearn.preprocessing import FunctionTransformer
+
+from memilio.epidata import transformMobilityData as tmd
+from memilio.epidata import getDataIntoPandasDataFrame as gd
+from memilio.simulation.osecir import (ModelGraph, set_edges)
+from memilio.epidata import modifyDataframeSeries as mdfs
+
+
+def remove_confirmed_compartments(dataset_entries, num_groups):
+    """! The compartments that contain confirmed cases are not needed and are
+    therefore removed by merging each confirmed compartment into the
+    corresponding unconfirmed compartment.
+    @param dataset_entries Array that contains the compartmental data with
+            confirmed compartments.
+    @param num_groups Number of age groups.
+    @return Array that contains the compartmental data without confirmed compartments.
+    """
+
+    new_dataset_entries = []
+    for i in dataset_entries:
+        dataset_entries_reshaped = i.reshape(
+            [num_groups, int(np.asarray(dataset_entries).shape[1]/num_groups)]
+        )
+        sum_inf_no_symp = np.sum(dataset_entries_reshaped[:, [2, 3]], axis=1)
+        sum_inf_symp = np.sum(dataset_entries_reshaped[:, [4, 5]], axis=1)
+        dataset_entries_reshaped[:, 2] = sum_inf_no_symp
+        dataset_entries_reshaped[:, 4] = sum_inf_symp
+        new_dataset_entries.append(
+            np.delete(dataset_entries_reshaped, [3, 5], axis=1).flatten()
+        )
+    return new_dataset_entries
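+
+
+# A minimal worked example of the merge above (illustration only; assumes
+# num_groups=1 and the osecir compartment order S, E, InfectedNoSymptoms,
+# InfectedNoSymptomsConfirmed, InfectedSymptoms, InfectedSymptomsConfirmed,
+# InfectedSevere, InfectedCritical, Recovered, Dead):
+def _demo_remove_confirmed_compartments():
+    entry = np.asarray([[100., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
+    merged = remove_confirmed_compartments(entry, num_groups=1)
+    # columns 2+3 and 4+5 are summed; the confirmed columns 3 and 5 are dropped
+    assert np.allclose(merged[0], [100., 1., 5., 9., 6., 7., 8., 9.])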
+
+
+def getBaselineMatrix():
+    """! Loads the baseline contact matrix.
+    """
+
+    baseline_contact_matrix0 = os.path.join(
+        "./data/contacts/baseline_home.txt")
+    baseline_contact_matrix1 = os.path.join(
+        "./data/contacts/baseline_school_pf_eig.txt")
+    baseline_contact_matrix2 = os.path.join(
+        "./data/contacts/baseline_work.txt")
+    baseline_contact_matrix3 = os.path.join(
+        "./data/contacts/baseline_other.txt")
+
+    baseline = np.loadtxt(baseline_contact_matrix0) \
+        + np.loadtxt(baseline_contact_matrix1) + \
+        np.loadtxt(baseline_contact_matrix2) + \
+        np.loadtxt(baseline_contact_matrix3)
+
+    return baseline
+
+
+def getMinimumMatrix():
+    """! Loads the minimum contact matrix.
+    """
+
+    minimum_contact_matrix0 = os.path.join(
+        "./data/contacts/minimum_home.txt")
+    minimum_contact_matrix1 = os.path.join(
+        "./data/contacts/minimum_school_pf_eig.txt")
+    minimum_contact_matrix2 = os.path.join(
+        "./data/contacts/minimum_work.txt")
+    minimum_contact_matrix3 = os.path.join(
+        "./data/contacts/minimum_other.txt")
+
+    minimum = np.loadtxt(minimum_contact_matrix0) \
+        + np.loadtxt(minimum_contact_matrix1) + \
+        np.loadtxt(minimum_contact_matrix2) + \
+        np.loadtxt(minimum_contact_matrix3)
+
+    return minimum
+
+
+def make_graph(directory, num_regions, countykey_list, models):
+    """! Creates a graph-ODE model from the given models and mobility data.
+    @param directory Directory with mobility data.
+    @param num_regions Number (int) of counties that should be added to the
+            graph-ODE model. Equals 400 for the whole of Germany.
+    @param countykey_list List of keys/IDs for each county.
+    @param models List of osecir models, one per county/node.
+    @return graph Graph-ODE model.
+    """
+    graph = ModelGraph()
+    for i in range(num_regions):
+        graph.add_node(int(countykey_list[i]), models[i])
+
+    num_locations = 4
+
+    set_edges(os.path.abspath(os.path.join(directory, os.pardir)),
+              graph, num_locations)
+    return graph
+
+
+def transform_mobility_directory():
+    """! Transforms the mobility data by merging Eisenach and Wartburgkreis.
+    """
+    # get mobility data directory
+    arg_dict = gd.cli("commuter_official")
+
+    directory = arg_dict['out_folder'].split('/pydata')[0]
+    directory = os.path.join(directory, 'mobility/')
+
+    # merge Eisenach and Wartburgkreis in the input data
+    tmd.updateMobility2022(directory, mobility_file='twitter_scaled_1252')
+    tmd.updateMobility2022(
+        directory, mobility_file='commuter_migration_scaled')
+    return directory
+
+
+def get_population():
+    """! Reads the county population data and aggregates it into the age
+    groups used in the simulation.
+    @return List with one entry per county: [county_id, *population_per_age_group].
+    """
+    df_population = pd.read_json(
+        'data/pydata/Germany/county_population.json')
+    age_groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80-130']
+
+    df_population_agegroups = pd.DataFrame(
+        columns=[df_population.columns[0]] + age_groups)
+    for region_id in df_population.iloc[:, 0]:
+        df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list(
+            mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups))
+
+    population = df_population_agegroups.values.tolist()
+    return population
+
+
+def scale_data(data):
+    """! Applies a log(1 + x) transformation to inputs and labels and moves
+    the node axis forward.
+    @param data Dictionary with entries 'inputs' and 'labels' of shape
+            (runs, time, features, nodes).
+    @return Scaled inputs and labels of shape (runs, nodes, time, features).
+    """
+    num_groups = int(np.asarray(data['inputs']).shape[2] / 8)
+    transformer = FunctionTransformer(np.log1p, validate=True)
+
+    # Scale inputs
+    inputs = np.asarray(
+        data['inputs']).transpose(2, 0, 1, 3).reshape(num_groups * 8, -1)
+    scaled_inputs = transformer.transform(inputs)
+    original_shape_input = np.asarray(data['inputs']).shape
+
+    # Reverse the reshape
+    reshaped_back = scaled_inputs.reshape(original_shape_input[2],
+                                          original_shape_input[0],
+                                          original_shape_input[1],
+                                          original_shape_input[3])
+
+    # Reverse the transpose
+    original_inputs = reshaped_back.transpose(1, 2, 0, 3)
+    scaled_inputs = original_inputs.transpose(0, 3, 1, 2)
+
+    # Scale labels
+    labels = np.asarray(
+        data['labels']).transpose(2, 0, 1, 3).reshape(num_groups * 8, -1)
+    scaled_labels = transformer.transform(labels)
+    original_shape_labels = np.asarray(data['labels']).shape
+
+    # Reverse the reshape
+    reshaped_back = scaled_labels.reshape(original_shape_labels[2],
+                                          original_shape_labels[0],
+                                          original_shape_labels[1],
+                                          original_shape_labels[3])
+
+    # Reverse the transpose
+    original_labels = reshaped_back.transpose(1, 2, 0, 3)
+    scaled_labels = original_labels.transpose(0, 3, 1, 2)
+
+    return scaled_inputs, scaled_labels
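+
+
+# A minimal shape check for scale_data (illustration with random dummy data:
+# 10 runs, 5 input / 30 label time steps, 6 age groups * 8 compartments = 48
+# features, 2 nodes):
+def _demo_scale_data():
+    rng = np.random.default_rng(0)
+    dummy = {"inputs": rng.random((10, 5, 48, 2)),
+             "labels": rng.random((10, 30, 48, 2))}
+    scaled_inputs, scaled_labels = scale_data(dummy)
+    # log1p-scaled, with the node axis moved forward:
+    # (runs, nodes, time, features)
+    assert scaled_inputs.shape == (10, 2, 5, 48)
+    assert scaled_labels.shape == (10, 2, 30, 48)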
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_nodamp.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_nodamp.py
new file mode 100644
index 0000000000..c0f84261b5
--- /dev/null
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_nodamp.py
@@ -0,0 +1,250 @@
+import copy
+import os
+import pickle
+import random
+import time
+import memilio.simulation as mio
+import memilio.simulation.osecir as osecir
+import numpy as np
+
+from progress.bar import Bar
+
+from datetime import date
+
+from memilio.simulation import (AgeGroup, LogLevel, set_log_level)
+from memilio.simulation.osecir import (Index_InfectionState,
+                                       interpolate_simulation_result,
+                                       ParameterStudy, InfectionState, Model)
+from memilio.epidata import geoModificationGermany as geoger
+from memilio.surrogatemodel.GNN.GNN_utils import (transform_mobility_directory,
+                                                  make_graph, scale_data)
+from memilio.surrogatemodel.utils_surrogatemodel import (
+    getBaselineMatrix, remove_confirmed_compartments, get_population)
+from enum import Enum
+
+
+class Location(Enum):
+    Home = 0
+    School = 1
+    Work = 2
+    Other = 3
+
+
+start_date = mio.Date(2019, 1, 1)
+end_date = mio.Date(2020, 12, 31)
+
+
+def set_covid_parameters(model, num_groups=6):
+    for i in range(num_groups):
+        # Compartment transition durations
+        model.parameters.TimeExposed[AgeGroup(i)] = 3.2
+        model.parameters.TimeInfectedNoSymptoms[AgeGroup(i)] = 2.
+        model.parameters.TimeInfectedSymptoms[AgeGroup(i)] = 6.
+        model.parameters.TimeInfectedSevere[AgeGroup(i)] = 12.
+        model.parameters.TimeInfectedCritical[AgeGroup(i)] = 8.
+
+        # Compartment transition probabilities
+        model.parameters.RelativeTransmissionNoSymptoms[AgeGroup(i)] = 0.5
+        model.parameters.TransmissionProbabilityOnContact[AgeGroup(
+            i)] = 0.1
+        model.parameters.RecoveredPerInfectedNoSymptoms[AgeGroup(i)] = 0.09
+        model.parameters.RiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.25
+        model.parameters.SeverePerInfectedSymptoms[AgeGroup(i)] = 0.2
+        model.parameters.CriticalPerSevere[AgeGroup(i)] = 0.25
+        model.parameters.DeathsPerCritical[AgeGroup(i)] = 0.3
+        model.parameters.MaxRiskOfInfectionFromSymptomatic[AgeGroup(
+            i)] = 0.5
+
+    # StartDay is the n-th day of the year
+    model.parameters.StartDay = start_date.day_in_year
+
+
+def set_contact_matrices(model, data_dir, num_groups=6):
+    contact_matrices = mio.ContactMatrixGroup(
+        len(list(Location)), num_groups)
+    locations = ["home", "school_pf_eig", "work", "other"]
+
+    for i, location in enumerate(locations):
+        baseline_file = os.path.join(
+            data_dir, "contacts", "baseline_" + location + ".txt")
+        minimum_file = os.path.join(
+            data_dir, "contacts", "minimum_" + location + ".txt")
+        contact_matrices[i] = mio.ContactMatrix(
+            mio.read_mobility_plain(baseline_file),
+            mio.read_mobility_plain(minimum_file)
+        )
+    model.parameters.ContactPatterns.cont_freq_mat = contact_matrices
+
+
+def get_graph(num_groups, data_dir):
+    model = Model(num_groups)
+    set_covid_parameters(model)
+    set_contact_matrices(model, data_dir)
+
+    graph = osecir.ModelGraph()
+
+    scaling_factor_infected = [2.5, 2.5, 2.5, 2.5, 2.5, 2.5]
+    scaling_factor_icu = 1.0
+    tnt_capacity_factor = 7.5 / 100000.
+
+    path_population_data = os.path.join(
+        data_dir, "pydata", "Germany",
+        "county_current_population.json")
+
+    mio.osecir.set_nodes(
+        model.parameters,
+        mio.Date(start_date.year,
+                 start_date.month, start_date.day),
+        mio.Date(end_date.year,
+                 end_date.month, end_date.day), data_dir,
+        path_population_data, True, graph, scaling_factor_infected,
+        scaling_factor_icu, tnt_capacity_factor, 0, False)
+
+    mio.osecir.set_edges(
+        data_dir, graph, len(Location))
+
+    return graph
+
+
+def run_secir_groups_simulation(days, graph, num_groups=6):
+    """! Uses an ODE SECIR model allowing for asymptomatic infection with 6
+    different age groups. The models are connected in a graph of regions
+    coupled by mobility. Virus-specific parameters are fixed, and the initial
+    numbers of persons in the particular infection states are chosen randomly
+    from defined ranges.
+    @param days (int) Describes how many days we simulate within a single run.
+    @param graph Graph initialized for the start_date with the population data
+            which is sampled during the run.
+    @return List containing the compartmental data of every node for each
+            simulated day.
+    """
+    for node_indx in range(graph.num_nodes):
+        model = graph.get_node(node_indx).property
+
+        # Set parameters
+        # TODO: Put this in the draw_sample function in the ParameterStudy
+        for i in range(num_groups):
+            age_group = AgeGroup(i)
+            pop_age_group = model.populations.get_group_total_AgeGroup(
+                age_group)
+
+            # Initial number of people in each compartment with random numbers
+            # Numbers are chosen heuristically based on experience
+            model.populations[age_group, Index_InfectionState(InfectionState.Exposed)] = random.uniform(
+                0.00025, 0.005) * pop_age_group
+            model.populations[age_group, Index_InfectionState(InfectionState.InfectedNoSymptoms)] = random.uniform(
+                0.0001, 0.0035) * pop_age_group
+            model.populations[age_group, Index_InfectionState(
+                InfectionState.InfectedNoSymptomsConfirmed)] = 0
+            model.populations[age_group, Index_InfectionState(InfectionState.InfectedSymptoms)] = random.uniform(
+                0.00007, 0.001) * pop_age_group
+            model.populations[age_group, Index_InfectionState(
+                InfectionState.InfectedSymptomsConfirmed)] = 0
+            model.populations[age_group, Index_InfectionState(InfectionState.InfectedSevere)] = random.uniform(
+                0.00003, 0.0006) * pop_age_group
+            model.populations[age_group, Index_InfectionState(InfectionState.InfectedCritical)] = random.uniform(
+                0.00001, 0.0002) * pop_age_group
+            model.populations[age_group, Index_InfectionState(InfectionState.Recovered)] = random.uniform(
+                0.002, 0.08) * pop_age_group
+            model.populations[age_group, Index_InfectionState(InfectionState.Dead)] = random.uniform(
+                0, 0.0003) * pop_age_group
+            model.populations.set_difference_from_group_total_AgeGroup(
+                (age_group, Index_InfectionState(InfectionState.Susceptible)),
+                pop_age_group)
+
+        # Apply mathematical constraints to parameters
+        model.apply_constraints()
+
+        # set the sampled initial populations on the graph node
+        graph.get_node(node_indx).property.populations = model.populations
+
+    study = ParameterStudy(graph, 0, days, dt=0.5, num_runs=1)
+    start_time = time.time()
+    study_results = study.run()
+    print("Simulation took: ", time.time() - start_time)
+
+    graph_run = study_results[0]
+    results = interpolate_simulation_result(graph_run)
+
+    for result_indx in range(len(results)):
+        # omit the first row of the ndarray (the time points), as it is not of
+        # interest here, and merge the confirmed compartments
+        results[result_indx] = remove_confirmed_compartments(
+            np.transpose(results[result_indx].as_ndarray()[1:, :]), num_groups)
+
+    dataset_entry = copy.deepcopy(results)
+
+    return dataset_entry
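+
+
+# Usage sketch (illustration; assumes the ./data directory layout used above):
+#
+#     graph = get_graph(num_groups=6, data_dir="data")
+#     results = run_secir_groups_simulation(30, graph)
+#     # one entry per node, each with days + 1 time steps of
+#     # num_groups * 8 = 48 compartment values
+#     len(results[0])     # -> 31
+#     len(results[0][0])  # -> 48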
+
+
+def generate_data(
+        num_runs, data_dir, path, input_width, days, save_data=True):
+    """! Generates a dataset by calling run_secir_groups_simulation num_runs times.
+    @param num_runs Number of times the function run_secir_groups_simulation is called.
+    @param data_dir Directory with all data needed to initialize the models.
+    @param path Path where the datasets are stored.
+    @param input_width Number of time steps used for model input.
+    @param days Number of time steps (days) used as model output/label.
+    @param save_data Option to deactivate saving the dataset. Default: True.
+    """
+    set_log_level(mio.LogLevel.Error)
+    days_sum = days + input_width - 1
+
+    data = {"inputs": [],
+            "labels": [],
+            }
+
+    num_groups = 6
+    graph = get_graph(num_groups, data_dir)
+
+    # show progress in terminal for longer runs
+    # due to the random structure, there's currently no need to shuffle the data
+    bar = Bar('Number of Runs done', max=num_runs)
+
+    for _ in range(num_runs):
+
+        data_run = run_secir_groups_simulation(
+            days_sum, graph)
+
+        inputs = np.asarray(data_run).transpose(1, 2, 0)[: input_width]
+        data["inputs"].append(inputs)
+
+        data["labels"].append(np.asarray(
+            data_run).transpose(1, 2, 0)[input_width:])
+
+        bar.next()
+
+    bar.finish()
+
+    if save_data:
+
+        scaled_inputs, scaled_labels = scale_data(data)
+
+        all_data = {"inputs": scaled_inputs,
+                    "labels": scaled_labels,
+                    }
+
+        # check if the data directory exists; if necessary, create it
+        if not os.path.isdir(path):
+            os.mkdir(path)
+
+        # save the dict to a pickle file
+        with open(os.path.join(path, 'data_secir_age_groups.pickle'), 'wb') as f:
+            pickle.dump(all_data, f)
+
+    return data
+
+
+if __name__ == "__main__":
+
+    path = os.path.dirname(os.path.realpath(__file__))
+    path_data = os.path.join(
+        os.path.dirname(
+            os.path.realpath(os.path.dirname(os.path.realpath(path)))),
+        'data_GNN_nodamp_test')
+
+    data_dir = os.path.join(os.getcwd(), 'data')
+
+    input_width = 5
+    days = 30
+    num_runs = 1
+
+    generate_data(num_runs, data_dir, path_data, input_width,
+                  days, save_data=True)
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_withdamp.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_withdamp.py
new file mode 100644
index 0000000000..1bf7857461
--- /dev/null
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/GNN/data_generation_withdamp.py
@@ -0,0 +1,360 @@
+import copy
+import os
+import pickle
+import random
+import numpy as np
+import time
+import memilio.simulation as mio
+import memilio.simulation.osecir as osecir
+from datetime import date
+from enum import Enum
+
+from progress.bar import Bar
+from sklearn.preprocessing import FunctionTransformer
+
+from memilio.simulation import (AgeGroup, LogLevel, set_log_level, Damping)
+from memilio.simulation.osecir import (Index_InfectionState,
+                                       interpolate_simulation_result,
+                                       ParameterStudy, InfectionState, Model,
+                                       ModelGraph, set_edges)
+from memilio.epidata import geoModificationGermany as geoger
+
+from memilio.surrogatemodel.GNN.GNN_utils import (transform_mobility_directory,
+                                                  make_graph, scale_data)
+from memilio.surrogatemodel.utils_surrogatemodel import (
+    getBaselineMatrix, remove_confirmed_compartments, get_population,
+    getMinimumMatrix)
+
+
+class Location(Enum):
+    Home = 0
+    School = 1
+    Work = 2
+    Other = 3
+
+
+start_date = mio.Date(2019, 1, 1)
+end_date = mio.Date(2020, 12, 31)
+
+
+def set_covid_parameters(model, num_groups=6):
+    for i in range(num_groups):
+        # Compartment transition durations
+        model.parameters.TimeExposed[AgeGroup(i)] = 3.2
+        model.parameters.TimeInfectedNoSymptoms[AgeGroup(i)] = 2.
+        model.parameters.TimeInfectedSymptoms[AgeGroup(i)] = 6.
+        model.parameters.TimeInfectedSevere[AgeGroup(i)] = 12.
+        model.parameters.TimeInfectedCritical[AgeGroup(i)] = 8.
+
+        # Compartment transition probabilities
+        model.parameters.RelativeTransmissionNoSymptoms[AgeGroup(i)] = 0.5
+        model.parameters.TransmissionProbabilityOnContact[AgeGroup(
+            i)] = 0.1
+        model.parameters.RecoveredPerInfectedNoSymptoms[AgeGroup(i)] = 0.09
+        model.parameters.RiskOfInfectionFromSymptomatic[AgeGroup(i)] = 0.25
+        model.parameters.SeverePerInfectedSymptoms[AgeGroup(i)] = 0.2
+        model.parameters.CriticalPerSevere[AgeGroup(i)] = 0.25
+        model.parameters.DeathsPerCritical[AgeGroup(i)] = 0.3
+        model.parameters.MaxRiskOfInfectionFromSymptomatic[AgeGroup(
+            i)] = 0.5
+
+    # StartDay is the n-th day of the year
+    model.parameters.StartDay = start_date.day_in_year
+
+
+def set_contact_matrices(model, data_dir, num_groups=6):
+    contact_matrices = mio.ContactMatrixGroup(
+        len(list(Location)), num_groups)
+    locations = ["home", "school_pf_eig", "work", "other"]
+
+    for i, location in enumerate(locations):
+        baseline_file = os.path.join(
+            data_dir, "contacts", "baseline_" + location + ".txt")
+        minimum_file = os.path.join(
+            data_dir, "contacts", "minimum_" + location + ".txt")
+        contact_matrices[i] = mio.ContactMatrix(
+            mio.read_mobility_plain(baseline_file),
+            mio.read_mobility_plain(minimum_file)
+        )
+    model.parameters.ContactPatterns.cont_freq_mat = contact_matrices
+
+
+def get_graph(num_groups, data_dir):
+    model = Model(num_groups)
+    set_covid_parameters(model)
+    set_contact_matrices(model, data_dir)
+
+    graph = osecir.ModelGraph()
+
+    scaling_factor_infected = [2.5, 2.5, 2.5, 2.5, 2.5, 2.5]
+    scaling_factor_icu = 1.0
+    tnt_capacity_factor = 7.5 / 100000.
+
+    path_population_data = os.path.join(
+        data_dir, "pydata", "Germany",
+        "county_current_population.json")
+
+    mio.osecir.set_nodes(
+        model.parameters,
+        mio.Date(start_date.year,
+                 start_date.month, start_date.day),
+        mio.Date(end_date.year,
+                 end_date.month, end_date.day), data_dir,
+        path_population_data, True, graph, scaling_factor_infected,
+        scaling_factor_icu, tnt_capacity_factor, 0, False)
+
+    mio.osecir.set_edges(
+        data_dir, graph, len(Location))
+
+    return graph
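+
+
+# Damping sketch (illustration): a damping with factor d applied at day t
+# reduces the contact matrix towards its minimum matrix for all days >= t,
+# i.e. roughly (1 - d) * baseline when the minimum matrix is zero, e.g.
+#
+#     model.parameters.ContactPatterns.cont_freq_mat.add_damping(
+#         Damping(coeffs=0.3 * np.ones((6, 6)), t=10., level=0, type=0))
+#     model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at(11.)
+#     # -> approximately 0.7 * baseline contact matrix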
+ """ + for node_indx in range(graph.num_nodes): + model = graph.get_node(node_indx).property + + # Set parameters + # TODO: Put This in the draw_sample function in the ParameterStudy + for i in range(num_groups): + age_group = AgeGroup(i) + pop_age_group = model.populations.get_group_total_AgeGroup( + age_group) + + # Initial number of people in each compartment with random numbers + # Numbers are chosen heuristically based on experience + model.populations[age_group, Index_InfectionState(InfectionState.Exposed)] = random.uniform( + 0.00025, 0.005) * pop_age_group + model.populations[age_group, Index_InfectionState(InfectionState.InfectedNoSymptoms)] = random.uniform( + 0.0001, 0.0035) * pop_age_group + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedNoSymptomsConfirmed)] = 0 + model.populations[age_group, Index_InfectionState(InfectionState.InfectedSymptoms)] = random.uniform( + 0.00007, 0.001) * pop_age_group + model.populations[age_group, Index_InfectionState( + InfectionState.InfectedSymptomsConfirmed)] = 0 + model.populations[age_group, Index_InfectionState(InfectionState.InfectedSevere)] = random.uniform( + 0.00003, 0.0006) * pop_age_group + model.populations[age_group, Index_InfectionState(InfectionState.InfectedCritical)] = random.uniform( + 0.00001, 0.0002) * pop_age_group + model.populations[age_group, Index_InfectionState(InfectionState.Recovered)] = random.uniform( + 0.002, 0.08) * pop_age_group + model.populations[age_group, Index_InfectionState(InfectionState.Dead)] = random.uniform( + 0, 0.0003) * pop_age_group + model.populations.set_difference_from_group_total_AgeGroup( + (age_group, Index_InfectionState(InfectionState.Susceptible)), + pop_age_group) + + # Apply mathematical constraints to parameters + model.apply_constraints() + + # Generate a damping matrix and assign it to the model + # TODO: This can be done outside and is (currently) static for all models + damped_matrices = [] + damping_coeff = [] + for day in dampings: + + # generat a random damping factor + damping = np.ones((num_groups, num_groups) + ) * np.float16(random.uniform(0, 0.5)) + + # add damping to model + model.parameters.ContactPatterns.cont_freq_mat.add_damping(Damping( + coeffs=(damping), t=day, level=0, type=0)) + + damped_matrices.append(model.parameters.ContactPatterns.cont_freq_mat.get_matrix_at( + day+1)) + damping_coeff.append(damping[0][0]) + + # Apply mathematical constraints to parameters + model.apply_constraints() + # set model to graph + graph.get_node(node_indx).property.populations = model.populations + graph.get_node(node_indx).property.parameters = model.parameters + + study = ParameterStudy(graph, 0, days, dt=0.5, num_runs=1) + start_time = time.time() + study.run() + print("Simulation took: ", time.time() - start_time) + + graph_run = study.run()[0] + results = interpolate_simulation_result(graph_run) + + for result_indx in range(len(results)): + results[result_indx] = remove_confirmed_compartments( + np.transpose(results[result_indx].as_ndarray()[1:, :]), num_groups) + + # Omit first column, as the time points are not of interest here. + dataset_entry = copy.deepcopy(results) + + return dataset_entry, damped_matrices, dampings, damping_coeff + + +def generate_dampings_withshadowdamp(number_of_dampings, days, min_distance, min_damping_day, n_runs): + """! Draw damping days while keeping a minimum distance between the + damping days. This method aims to create a uniform ditribution of + drawn damping days. 
+
+    all_dampings = []
+    count_runs = 0
+    count_shadow = 0
+    while len(all_dampings) < n_runs:
+
+        days_list = list(range((min_damping_day), days))
+        dampings = []
+        if count_shadow < 2:
+            for i in range(number_of_dampings):
+
+                damp = random.choice(days_list)
+                days_before = list(range(damp-(min_distance), damp))
+                days_after = list(range(damp, damp+(min_distance+1)))
+                dampings.append(damp)
+                days_list = [ele for ele in days_list if ele not in (
+                    days_before+days_after)]
+        else:
+            # choose a forbidden ("shadow") damping outside the time window
+            damp = random.choice(
+                list(range((0-min_distance), 0)) + list(range(days+1, (days+min_distance+1))))
+
+            days_before = list(range(damp-(min_distance), damp))
+            days_after = list(range(damp, damp+(min_distance+1)))
+            days_list = [ele for ele in days_list if ele not in (
+                days_before+days_after)]
+            dampings.append(damp)
+            for i in range(number_of_dampings):
+
+                damp = random.choice(days_list)
+                days_before = list(range(damp-(min_distance), damp))
+                days_after = list(range(damp, damp+(min_distance+1)))
+                dampings.append(damp)
+                days_list = [ele for ele in days_list if ele not in (
+                    days_before+days_after)]
+            count_shadow = 0
+
+        # discard damping days that lie outside the simulated time window
+        forbidden_damping_values = list(
+            range((0-min_distance), 0)) + list(range(days+1, (days+min_distance+1)))
+        dampings = [
+            ele for ele in dampings if ele not in forbidden_damping_values]
+        count_runs += 1
+        count_shadow += 1
+        # only keep the run if enough valid damping days were drawn
+        if len(dampings) >= number_of_dampings:
+            all_dampings.append(sorted(dampings))
+
+    return np.asarray(all_dampings)
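+
+
+# A minimal property check (illustration): every drawn damping day lies in
+# [min_damping_day, days] and consecutive days are more than min_distance
+# days apart.
+def _demo_damping_days():
+    draws = generate_dampings_withshadowdamp(
+        number_of_dampings=3, days=90, min_distance=7, min_damping_day=5,
+        n_runs=10)
+    for run in draws:
+        assert all(later - earlier > 7
+                   for earlier, later in zip(run[:-1], run[1:]))
+        assert run.min() >= 5 and run.max() <= 90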
+ """ + set_log_level(mio.LogLevel.Error) + days_sum = label_width+input_width-1 + + num_groups = 6 + graph = get_graph(num_groups, data_dir) + + # generate dampings + damping_days = generate_dampings_withshadowdamp( + number_of_dampings=number_of_dampings, days=label_width, + min_distance=7, min_damping_day=input_width, n_runs=num_runs + ) + + # all data including damping information + all_data = {"inputs": [], + "labels": [], + "damping_coeff": [], + "damping_day": [], + "damped_matrix": []} + + # data that needs to be scaled + data = {"inputs": [], + "labels": [], + "damping_coeff": [], + "damping_day": [], + "damped_matrix": []} + + # show progess in terminal for longer runs + # Due to the random structure, theres currently no need to shuffle the data + bar = Bar('Number of Runs done', max=num_runs) + + model_params = copy.deepcopy(graph.get_node(0).property.parameters) + + for i in range(num_runs): + params_run = copy.deepcopy(model_params) + # reset contact matrix in each node + for node_indx in range(graph.num_nodes): + graph.get_node(node_indx).property.parameters = params_run + data_run, damped_contact_matrix, damping_days_s, damping_factor = run_secir_groups_simulation( + days_sum, graph, damping_days[i]) + + inputs = np.asarray(data_run).transpose(1, 2, 0)[:input_width] + data["inputs"].append(inputs) + data["labels"].append(np.asarray( + data_run).transpose(1, 2, 0)[input_width:]) + data["damping_coeff"].append(damping_factor) + data["damping_day"].append(damping_days_s) + data["damped_matrix"].append(damped_contact_matrix) + + bar.next() + + bar.finish() + + if save_data: + + scaled_inputs, scaled_labels = scale_data(data) + + all_data = {"inputs": scaled_inputs, + "labels": scaled_labels, + "damping_coeff": data['damping_coeff'], + "damping_day": data['damping_day'], + "damped_matrix": data['damped_matrix']} + + # check if data directory exists. If necessary create it. + if not os.path.isdir(path): + os.mkdir(path) + + # save dict to json file + with open(os.path.join(path, 'data_secir_age_groups.pickle'), 'wb') as f: + pickle.dump(all_data, f) + return data + + +if __name__ == "__main__": + + input_width = 5 + label_width = 95 + number_of_dampings = 3 + num_runs = 5 + path = os.path.dirname(os.path.realpath(__file__)) + path_data = os.path.join( + os.path.dirname( + os.path.realpath(os.path.dirname(os.path.realpath(path)))), + 'data_GNN_with_'+str(number_of_dampings)+'_dampings_test') + + data_dir = os.path.join(os.getcwd(), 'data') + + generate_data(num_runs, data_dir, path_data, input_width, + label_width, number_of_dampings, save_data=False) diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py index 1cb1dac82d..5e94ebf689 100644 --- a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py +++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_groups/data_generation.py @@ -34,6 +34,9 @@ InfectionState, Model, interpolate_simulation_result, simulate) +from memilio.surrogatemodel.utils_surrogatemodel import ( + getBaselineMatrix, getMinimumMatrix, get_population) + def interpolate_age_groups(data_entry): """! Interpolates the age groups from the population data into the age groups used in the simulation. 
@@ -214,7 +217,8 @@ def generate_data(
     days = input_width + label_width - 1
 
     # Load population data
-    population = get_population(path_population)
+    # population = get_population(path_population)
+    population = get_population()
 
     # show progess in terminal for longer runs
     # Due to the random structure, there's currently no need to shuffle the data
@@ -256,61 +260,6 @@ def generate_data(
     return data
 
 
-def getBaselineMatrix():
-    """! loads the baselinematrix
-    """
-
-    baseline_contact_matrix0 = os.path.join(
-        "./data/contacts/baseline_home.txt")
-    baseline_contact_matrix1 = os.path.join(
-        "./data/contacts/baseline_school_pf_eig.txt")
-    baseline_contact_matrix2 = os.path.join(
-        "./data/contacts/baseline_work.txt")
-    baseline_contact_matrix3 = os.path.join(
-        "./data/contacts/baseline_other.txt")
-
-    baseline = np.loadtxt(baseline_contact_matrix0) \
-        + np.loadtxt(baseline_contact_matrix1) + \
-        np.loadtxt(baseline_contact_matrix2) + \
-        np.loadtxt(baseline_contact_matrix3)
-
-    return baseline
-
-
-def getMinimumMatrix():
-    """! loads the minimum matrix
-    """
-
-    minimum_contact_matrix0 = os.path.join(
-        "./data/contacts/minimum_home.txt")
-    minimum_contact_matrix1 = os.path.join(
-        "./data/contacts/minimum_school_pf_eig.txt")
-    minimum_contact_matrix2 = os.path.join(
-        "./data/contacts/minimum_work.txt")
-    minimum_contact_matrix3 = os.path.join(
-        "./data/contacts/minimum_other.txt")
-
-    minimum = np.loadtxt(minimum_contact_matrix0) \
-        + np.loadtxt(minimum_contact_matrix1) + \
-        np.loadtxt(minimum_contact_matrix2) + \
-        np.loadtxt(minimum_contact_matrix3)
-
-    return minimum
-
-
-def get_population(path):
-    """! read population data in list from dataset
-    @param path Path to the dataset containing the population data
-    """
-
-    with open(path) as f:
-        data = json.load(f)
-    population = []
-    for data_entry in data:
-        population.append(interpolate_age_groups(data_entry))
-    return population
-
-
 if __name__ == "__main__":
     # Store data relative to current file two levels higher.
     path = os.path.dirname(os.path.realpath(__file__))
@@ -322,6 +271,6 @@ def get_population(path):
 
     input_width = 5
     label_width = 30
-    num_runs = 10000
+    num_runs = 10
     data = generate_data(num_runs, path_output, path_population, input_width,
-                         label_width)
+                         label_width, save_data=True)
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_simple/data_generation.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_simple/data_generation.py
index 96eff3ac94..9dee453c70 100644
--- a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_simple/data_generation.py
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/ode_secir_simple/data_generation.py
@@ -35,13 +35,7 @@
                                        InfectionState, Model, Simulation, interpolate_simulation_result, simulate)
 
-
-def remove_confirmed_compartments(result_array):
-    sum_inf_no_symp = np.sum(result_array[:, [2, 3]], axis=1)
-    sum_inf_symp = np.sum(result_array[:, [2, 3]], axis=1)
-    result_array[:, 2] = sum_inf_no_symp
-    result_array[:, 4] = sum_inf_symp
-    return np.delete(result_array, [3, 5], axis=1)
+from memilio.surrogatemodel.utils_surrogatemodel import remove_confirmed_compartments
 
 
 def run_secir_simple_simulation(days):
@@ -121,13 +115,11 @@ def run_secir_simple_simulation(days):
     result_array = result.as_ndarray()
 
     result_array = remove_confirmed_compartments(
-        result_array[1:, :].transpose())
-
-    dataset = []
+        result_array[1:, :].transpose(), 1)
 
     dataset_entries = copy.deepcopy(result_array)
 
-    return dataset_entries.tolist()
+    return dataset_entries
 
 
 def generate_data(
@@ -205,6 +197,6 @@ def generate_data(
 
     input_width = 5
     label_width = 30
-    num_runs = 1000
+    num_runs = 10000
    data = generate_data(num_runs, path_data, input_width,
-                         label_width)
+                         label_width, save_data=True)
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel/utils_surrogatemodel.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/utils_surrogatemodel.py
new file mode 100644
index 0000000000..0fc619092d
--- /dev/null
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel/utils_surrogatemodel.py
@@ -0,0 +1,87 @@
+import numpy as np
+import pandas as pd
+import os
+from memilio.epidata import modifyDataframeSeries as mdfs
+
+
+def remove_confirmed_compartments(dataset_entries, num_groups):
+    """! The compartments that contain confirmed cases are not needed and are
+    therefore removed by merging each confirmed compartment into the
+    corresponding unconfirmed compartment.
+    @param dataset_entries Array that contains the compartmental data with
+            confirmed compartments.
+    @param num_groups Number of age groups.
+    @return Array that contains the compartmental data without confirmed compartments.
+    """
+
+    new_dataset_entries = []
+    for i in dataset_entries:
+        dataset_entries_reshaped = i.reshape(
+            [num_groups, int(np.asarray(dataset_entries).shape[1]/num_groups)]
+        )
+        sum_inf_no_symp = np.sum(dataset_entries_reshaped[:, [2, 3]], axis=1)
+        sum_inf_symp = np.sum(dataset_entries_reshaped[:, [4, 5]], axis=1)
+        dataset_entries_reshaped[:, 2] = sum_inf_no_symp
+        dataset_entries_reshaped[:, 4] = sum_inf_symp
+        new_dataset_entries.append(
+            np.delete(dataset_entries_reshaped, [3, 5], axis=1).flatten()
+        )
+    return new_dataset_entries
+
+
+def getBaselineMatrix():
+    """! Loads the baseline contact matrix.
+    """
+
+    baseline_contact_matrix0 = os.path.join(
+        "./data/contacts/baseline_home.txt")
+    baseline_contact_matrix1 = os.path.join(
+        "./data/contacts/baseline_school_pf_eig.txt")
+    baseline_contact_matrix2 = os.path.join(
+        "./data/contacts/baseline_work.txt")
+    baseline_contact_matrix3 = os.path.join(
+        "./data/contacts/baseline_other.txt")
+
+    baseline = np.loadtxt(baseline_contact_matrix0) \
+        + np.loadtxt(baseline_contact_matrix1) + \
+        np.loadtxt(baseline_contact_matrix2) + \
+        np.loadtxt(baseline_contact_matrix3)
+
+    return baseline
+
+
+def getMinimumMatrix():
+    """! Loads the minimum contact matrix.
+    """
+
+    minimum_contact_matrix0 = os.path.join(
+        "./data/contacts/minimum_home.txt")
+    minimum_contact_matrix1 = os.path.join(
+        "./data/contacts/minimum_school_pf_eig.txt")
+    minimum_contact_matrix2 = os.path.join(
+        "./data/contacts/minimum_work.txt")
+    minimum_contact_matrix3 = os.path.join(
+        "./data/contacts/minimum_other.txt")
+
+    minimum = np.loadtxt(minimum_contact_matrix0) \
+        + np.loadtxt(minimum_contact_matrix1) + \
+        np.loadtxt(minimum_contact_matrix2) + \
+        np.loadtxt(minimum_contact_matrix3)
+
+    return minimum
+
+
+def get_population():
+    """! Reads the county population data and aggregates it into the age
+    groups used in the simulation.
+    @return List with one entry per county: [county_id, *population_per_age_group].
+    """
+    df_population = pd.read_json(
+        'data/pydata/Germany/county_population.json')
+    age_groups = ['0-4', '5-14', '15-34', '35-59', '60-79', '80-130']
+
+    df_population_agegroups = pd.DataFrame(
+        columns=[df_population.columns[0]] + age_groups)
+    for region_id in df_population.iloc[:, 0]:
+        df_population_agegroups.loc[len(df_population_agegroups.index), :] = [int(region_id)] + list(
+            mdfs.fit_age_group_intervals(df_population[df_population.iloc[:, 0] == int(region_id)].iloc[:, 2:], age_groups))
+
+    population = df_population_agegroups.values.tolist()
+
+    return population
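+
+
+# Usage sketch (illustration; assumes data/pydata/Germany/county_population.json
+# exists relative to the working directory):
+#
+#     population = get_population()
+#     # one entry per county: [county_id, population per age group...]
+#     len(population[0])  # -> 7 (id + 6 age groups)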
diff --git a/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py b/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py
new file mode 100644
index 0000000000..8371aed65a
--- /dev/null
+++ b/pycode/memilio-surrogatemodel/memilio/surrogatemodel_test/test_surrogatemodel_GNN.py
@@ -0,0 +1,329 @@
+#############################################################################
+# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC)
+#
+# Authors:
+#
+# Contact: Martin J. Kuehn
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#############################################################################
+from pyfakefs import fake_filesystem_unittest
+
+from memilio.surrogatemodel.GNN import (
+    data_generation_nodamp, data_generation_withdamp)
+import memilio.simulation as mio
+import memilio.simulation.osecir as osecir
+
+from unittest.mock import patch
+import os
+import unittest
+
+import numpy as np
+
+import tensorflow as tf
+tf.get_logger().setLevel('ERROR')
+
+
+class TestSurrogatemodelGNN(fake_filesystem_unittest.TestCase):
+
+    path = '/home/'
+
+    # build a small two-node graph that the mocked make_graph returns
+    num_groups = 6
+    model = osecir.Model(num_groups)
+    graph = osecir.ModelGraph()
+    graph.add_node(0, model)
+    graph.add_node(1, model)
+    mobility_coefficients = 0.01 * np.ones(model.populations.numel())
+    for i in range(num_groups):
+        flat_index = model.populations.get_flat_index(
+            osecir.MultiIndex_PopulationsArray(mio.AgeGroup(i), osecir.InfectionState.Dead))
+        mobility_coefficients[flat_index] = 0
+    graph.add_edge(0, 1, mobility_coefficients)
+    graph.add_edge(1, 0, mobility_coefficients)
+
+    def setUp(self):
+        self.setUpPyfakefs()
+
+#### test simulation no damp ####
+
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.make_graph',
+           return_value=graph)
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.transform_mobility_directory',
+           autospec=True)
+    def test_simulation_run_nodamp(self, mock_transform_mobility,
+                                   mock_update_mobility, mock_graph, mock_baseline):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        days_1 = 10
+        days_2 = 30
+        days_3 = 50
+
+        # Call the actual function being tested on the two-node test graph
+        simulation_1 = data_generation_nodamp.run_secir_groups_simulation(
+            days_1, self.graph)
+        simulation_2 = data_generation_nodamp.run_secir_groups_simulation(
+            days_2, self.graph)
+        simulation_3 = data_generation_nodamp.run_secir_groups_simulation(
+            days_3, self.graph)
+
+        # Ensure that the results are of the correct length.
+        self.assertEqual(len(simulation_1[0]), days_1 + 1)
+        self.assertEqual(len(simulation_2[0]), days_2 + 1)
+        self.assertEqual(len(simulation_3[0]), days_3 + 1)
+
+
+#### test data generation no damp ####
+
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.make_graph',
+           return_value=graph)
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.get_population',
+           return_value=np.random.randint(0, 700001, size=(400, 6)))
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.transform_mobility_directory',
+           autospec=True)
+    def test_data_generation_runs_nodamp(
+            self, mock_transform_mobility, mock_update_mobility,
+            mock_population, mock_graph, mock_baseline):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        input_width_1 = 1
+        input_width_2 = 5
+
+        label_width_1 = 1
+        label_width_2 = 10
+
+        num_runs_1 = 1
+        num_runs_2 = 2
+
+        # test data generation without dampings; self.path doubles as the
+        # data directory
+        data_1 = data_generation_nodamp.generate_data(
+            num_runs_1, self.path, self.path, input_width_1, label_width_1,
+            save_data=False)
+        self.assertEqual(len(data_1['inputs']), num_runs_1)
+        self.assertEqual(len(data_1['inputs'][0]), input_width_1)
+        self.assertEqual(len(data_1['inputs'][0][0]), 48)
+        self.assertEqual(len(data_1['labels']), num_runs_1)
+        self.assertEqual(len(data_1['labels'][0]), label_width_1)
+        self.assertEqual(len(data_1['labels'][0][0]), 48)
+
+        data_2 = data_generation_nodamp.generate_data(
+            num_runs_2, self.path, self.path, input_width_2, label_width_2,
+            save_data=False)
+        self.assertEqual(len(data_2['inputs']), num_runs_2)
+        self.assertEqual(len(data_2['inputs'][0]), input_width_2)
+        self.assertEqual(len(data_2['inputs'][0][0]), 48)
+        self.assertEqual(len(data_2['labels']), num_runs_2)
+        self.assertEqual(len(data_2['labels'][0]), label_width_2)
+        self.assertEqual(len(data_2['labels'][0][0]), 48)
+
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.get_population',
+           return_value=np.random.randint(0, 700001, size=(400, 6)))
+    # create mock graph
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.make_graph',
+           return_value=graph)
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.scale_data',
+           autospec=True)
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_nodamp.transform_mobility_directory',
+           autospec=True)
+    def test_data_generation_save_nodamp(
+            self, mock_transform_mobility, mock_update_mobility, mock_scale_data,
+            mock_graph, mock_population, mock_baseline):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        # Mock the return value of scale_data with dummy inputs and labels;
+        # scaled inputs and labels are 4D arrays, matching the original output
+        mock_scale_data.return_value = (np.random.rand(
+            10, 8, 10, 6), np.random.rand(10, 8, 10, 6))
+
+        input_width = 2
+        label_width = 3
+        num_runs = 1
+
+        # Call the function being tested
+        data_generation_nodamp.generate_data(num_runs, self.path, self.path,
+                                             input_width, label_width)
+
+        # Check the number of generated files
+        self.assertEqual(len(os.listdir(self.path)), 1)
+
+        # Check the contents of the directory
+        self.assertEqual(os.listdir(self.path),
+                         ['data_secir_age_groups.pickle'])
+
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getMinimumMatrix',
+           return_value=0 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    # create mock graph
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.make_graph',
+           return_value=graph)
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.transform_mobility_directory',
+           autospec=True)
+    def test_simulation_run_withdamp(self, mock_transform_mobility, mock_update_mobility,
+                                     mock_graph, mock_baseline, mock_minimum):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        days_1 = 10
+        days_2 = 30
+        days_3 = 50
+
+        dampings1 = [5]
+        dampings2 = [6, 15]
+        dampings3 = [8, 18, 35]
+
+        dataset_entry1, damped_matrices1, damping_days1, damping_coeff1 = data_generation_withdamp.run_secir_groups_simulation(
+            days_1, self.graph, dampings1)
+        dataset_entry2, damped_matrices2, damping_days2, damping_coeff2 = data_generation_withdamp.run_secir_groups_simulation(
+            days_2, self.graph, dampings2)
+        dataset_entry3, damped_matrices3, damping_days3, damping_coeff3 = data_generation_withdamp.run_secir_groups_simulation(
+            days_3, self.graph, dampings3)
+
+        # result length
+        self.assertEqual(len(dataset_entry1[0]), days_1+1)
+        self.assertEqual(len(dataset_entry2[0]), days_2+1)
+        self.assertEqual(len(dataset_entry3[0]), days_3+1)
+
+        baseline = data_generation_withdamp.getBaselineMatrix()
+
+        # damping factor
+        self.assertEqual(damped_matrices1[0].all(),
+                         (baseline * damping_coeff1[0]).all())
+        self.assertEqual(
+            damped_matrices2[1].all(),
+            (baseline * damping_coeff2[1]).all())
+        self.assertEqual(
+            damped_matrices3[2].all(),
+            (baseline * damping_coeff3[2]).all())
+
+        # number of dampings
+        self.assertEqual(len(damping_coeff1), len(dampings1))
+        self.assertEqual(len(damping_coeff2), len(dampings2))
+        self.assertEqual(len(damping_coeff3), len(dampings3))
+
+# test data generation with dampings
+
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getMinimumMatrix',
+           return_value=0 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    # create mock graph
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.make_graph',
+           return_value=graph)
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.get_population',
+           return_value=np.random.randint(0, 700001, size=(400, 6)))
+    # mock the mobility transformation
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.transform_mobility_directory', autospec=True)
+    def test_data_generation_runs_withdamp(
+            self, mock_transform_mobility, mock_update_mobility,
+            mock_population, mock_graph, mock_baseline, mock_minimum):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        input_width_1 = 1
+        input_width_2 = 5
+
+        label_width_1 = 10
+        label_width_2 = 30
+
+        num_runs_1 = 1
+        num_runs_2 = 2
+
+        damping1 = 1
+        damping2 = 2
+
+        data_1 = data_generation_withdamp.generate_data(
+            num_runs_1, self.path, self.path, input_width_1, label_width_1,
+            damping1, save_data=False)
+        self.assertEqual(len(data_1['inputs']), num_runs_1)
+        self.assertEqual(len(data_1['inputs'][0]), input_width_1)
+        self.assertEqual(len(data_1['inputs'][0][0]), 48)
+        self.assertEqual(len(data_1['labels']), num_runs_1)
+        self.assertEqual(len(data_1['labels'][0]), label_width_1)
+        self.assertEqual(len(data_1['labels'][0][0]), 48)
+
+        data_2 = data_generation_withdamp.generate_data(
+            num_runs_2, self.path, self.path, input_width_2, label_width_2,
+            damping2, save_data=False)
+        self.assertEqual(len(data_2['inputs']), num_runs_2)
+        self.assertEqual(len(data_2['inputs'][0]), input_width_2)
+        self.assertEqual(len(data_2['inputs'][0][0]), 48)
+        self.assertEqual(len(data_2['labels']), num_runs_2)
+        self.assertEqual(len(data_2['labels'][0]), label_width_2)
+        self.assertEqual(len(data_2['labels'][0][0]), 48)
+
+    # test saving for the model with dampings
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getMinimumMatrix',
+           return_value=0 * np.ones((6, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.getBaselineMatrix',
+           return_value=0.6 * np.ones((6, 6)))
+    # create mock graph
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.make_graph',
+           return_value=graph)
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.get_population',
+           return_value=np.random.randint(0, 700001, size=(400, 6)))
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.scale_data',
+           autospec=True)
+    @patch('memilio.epidata.transformMobilityData.updateMobility2022')
+    @patch('memilio.surrogatemodel.GNN.data_generation_withdamp.transform_mobility_directory',
+           autospec=True)
+    def test_data_generation_save_withdamp(
+            self, mock_transform_mobility, mock_update_mobility, mock_scale_data,
+            mock_population, mock_graph, mock_baseline, mock_minimum):
+
+        mock_transform_mobility.side_effect = lambda: None
+        mock_update_mobility.side_effect = lambda directory, mobility_file: None
+
+        mock_scale_data.return_value = (np.random.rand(
+            10, 8, 10, 6), np.random.rand(10, 8, 10, 6))
+
+        input_width = 5
+        label_width = 20
+        num_runs = 2
+        num_dampings = 2
+        data = data_generation_withdamp.generate_data(
+            num_runs, self.path, self.path, input_width, label_width,
+            num_dampings, save_data=True)
+
+        self.assertEqual(len(os.listdir(self.path)), 1)
+        self.assertEqual(os.listdir(self.path),
+                         ['data_secir_age_groups.pickle'])
+
+
+if __name__ == '__main__':
+    unittest.main()