import os
import keras as ks
import numpy as np
import argparse
import time
import kgcnn.training.scheduler # noqa
import kgcnn.training.schedule # noqa
from datetime import timedelta
import kgcnn.losses.losses
import kgcnn.metrics.metrics
from kgcnn.metrics.metrics import ScaledMeanAbsoluteError, ScaledRootMeanSquaredError
from kgcnn.training.history import save_history_score, load_history_list, load_time_list
from kgcnn.data.transform.scaler.serial import deserialize as deserialize_scaler
from kgcnn.utils.plots import plot_train_test_loss, plot_predict_true
from kgcnn.models.serial import deserialize as deserialize_model
from kgcnn.data.serial import deserialize as deserialize_dataset
from kgcnn.training.hyper import HyperParameter
from kgcnn.utils.devices import check_device, set_cuda_device
from kgcnn.data.utils import save_pickle_file
# Input arguments from the command line, with default values taken from an example.
# From the command line one can specify the model, the dataset and the hyperparameter file,
# which contains the full configuration for training and model setup.
parser = argparse.ArgumentParser(description='Train a GNN on a graph regression or classification task.')
parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).",
default="hyper/hyper_esol.py")
parser.add_argument("--category", required=False, help="Graph model to train.", default="DGIN")
parser.add_argument("--model", required=False, help="Graph model to train.", default=None)
parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None)
parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)
parser.add_argument("--gpu", required=False, help="GPU index used for training.", default=None, nargs="+", type=int)
parser.add_argument("--fold", required=False, help="Split or fold indices to run.", default=None, nargs="+", type=int)
parser.add_argument("--seed", required=False, help="Set random seed.", default=42, type=int)
args = vars(parser.parse_args())
print("Input of argparse:", args)
# Check and set device
if args["gpu"] is not None:
set_cuda_device(args["gpu"])
print(check_device())
# Set seed.
np.random.seed(args["seed"])
ks.utils.set_random_seed(args["seed"])
# A `HyperParameter` class is used to expose and verify the hyperparameter.
# The hyperparameter is stored as a dictionary with sections 'model', 'dataset' and 'training'.
hyper = HyperParameter(
hyper_info=args["hyper"], hyper_category=args["category"],
model_name=args["model"], model_class=args["make"], dataset_class=args["dataset"])
hyper.verify()
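# A minimal sketch of the config layout this script consumes (keys taken from their usage below;
# values are placeholders, not an authoritative schema):
# {
#     "model": {"config": {"inputs": [...], ...}, ...},
#     "dataset": {...},  # passed to `deserialize_dataset`
#     "training": {"cross_validation": {...}, "scaler": {...}, "multi_target_indices": [...], ...},
#     "data": {"data_unit": ...},
#     "info": {"postfix_file": ...}
# }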
# Load a specific pre-defined dataset from a module in kgcnn.data.datasets.
# These subclasses are named after the dataset, e.g. `ESOLDataset`.
dataset = deserialize_dataset(hyper["dataset"])
# Check if the dataset has the required properties for the model input. This includes a quick shape comparison.
# The name of each keras `Input` layer of the model is directly connected to a property of the dataset,
# e.g. 'edge_indices' or 'node_attributes'. This couples the keras model to the dataset.
dataset.assert_valid_model_input(hyper["model"]["config"]["inputs"])
# Filter the dataset for invalid graphs. At the moment, invalid graphs are those that are missing a property
# required by the model's input layers, or whose tensor-like property has zero length.
dataset.clean(hyper["model"]["config"]["inputs"])
data_length = len(dataset) # Length of the cleaned dataset.
# Make the output directory. The path can be further adapted via the hyperparameter.
filepath = hyper.results_file_path()
postfix_file = hyper["info"]["postfix_file"]
# Always train on `graph_labels`.
# Just making sure that the target is of shape `(N, #labels)`, i.e. the output embedding is on graph level.
label_names, label_units = dataset.set_multi_target_labels(
"graph_labels",
hyper["training"]["multi_target_indices"] if "multi_target_indices" in hyper["training"] else None,
data_unit=hyper["data"]["data_unit"] if "data_unit" in hyper["data"] else None
)
# Iterate over the cross-validation splits.
# Indices for train-test splits are stored in 'test_indices_list'.
if "cross_validation" in hyper["training"]:
from sklearn.model_selection import KFold
splitter = KFold(**hyper["training"]["cross_validation"]["config"])
train_test_indices = [
(train_index, test_index) for train_index, test_index in splitter.split(X=np.zeros((data_length, 1)))]
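    # Illustrative 'cross_validation' entry as consumed above. Only the 'config' sub-dict is unpacked
    # into sklearn's `KFold`, so its keys must be valid `KFold` kwargs (values here are assumptions):
    #   {"cross_validation": {"class_name": "KFold",
    #                         "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}}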
else:
train_test_indices = dataset.get_train_test_indices(train="train", test="test")
train_indices_all, test_indices_all = [], []
# Run splits.
execute_folds = args["fold"] if "execute_folds" not in hyper["training"] else hyper["training"]["execute_folds"]
model, current_split, scaled_predictions = None, None, False
for current_split, (train_index, test_index) in enumerate(train_test_indices):
# Keep list of train/test indices.
test_indices_all.append(test_index)
train_indices_all.append(train_index)
    # Only run the folds listed in `execute_folds` out of the k folds of cross-validation.
if execute_folds is not None:
if current_split not in execute_folds:
continue
print("Running training on split: '%s'." % current_split)
dataset_train, dataset_test = dataset[train_index], dataset[test_index]
    # Make the model for the current split using the model kwargs from the hyperparameter.
model = deserialize_model(hyper["model"])
    # Adapt the output scale via a transform.
    # A scaler is applied to the target if 'scaler' appears in the hyperparameter. Only used for regression.
scaled_metrics = None
if "scaler" in hyper["training"]:
print("Using Scaler to adjust output scale.")
scaler = deserialize_scaler(hyper["training"]["scaler"])
scaler.fit_dataset(dataset_train)
if hasattr(model, "set_scale"):
print("Setting scale at model.")
model.set_scale(scaler)
else:
print("Transforming dataset.")
dataset_train = scaler.transform_dataset(dataset_train, copy_dataset=True, copy=True)
dataset_test = scaler.transform_dataset(dataset_test, copy_dataset=True, copy=True)
        # If a scaler was used, we add rescaled standard metrics to compile, since otherwise the keras history
        # would not log metrics on the original target values, but on the scaled ones.
        scaler_scale = scaler.get_scaling()
        if scaler_scale is not None:
            mae_metric = ScaledMeanAbsoluteError(scaler_scale.shape, name="scaled_mean_absolute_error")
            rms_metric = ScaledRootMeanSquaredError(scaler_scale.shape, name="scaled_root_mean_squared_error")
            mae_metric.set_scale(scaler_scale)
            rms_metric.set_scale(scaler_scale)
            scaled_metrics = [mae_metric, rms_metric]
scaled_predictions = True
# Save scaler to file
scaler.save(os.path.join(filepath, f"scaler{postfix_file}_fold_{current_split}"))
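        # The saved scaler can be restored later to map scaled predictions back to the original
        # target units; the exact reload API depends on the scaler class and is not used here.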
# Pick train/test data.
x_train = dataset_train.tensor(hyper["model"]["config"]["inputs"])
y_train = np.array(dataset_train.get("graph_labels"))
x_test = dataset_test.tensor(hyper["model"]["config"]["inputs"])
y_test = np.array(dataset_test.get("graph_labels"))
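    # Here `x_*` follows the order of the model 'inputs' config from above, and `y_*` has shape
    # `(N, #labels)` as ensured by `set_multi_target_labels`.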
    # Compile the model with optimizer and loss from the hyperparameter.
    # The metrics from this script are added to the hyperparameter entry for metrics.
model.compile(**hyper.compile(metrics=scaled_metrics))
    # Build the model and its weights by running a forward pass on a small amount of test data.
model.predict(x_test, batch_size=2, steps=2)
model._compile_metrics.build(y_test, y_test)
model._compile_loss.build(y_test, y_test)
# Model summary
model.summary()
print(" Compiled with jit: %s" % model._jit_compile) # noqa
print(" Model is built: %s, with unbuilt: %s" % (
all([layer.built for layer in model._flatten_layers()]), # noqa
[layer.name for layer in model._flatten_layers() if not layer.built]
))
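    # `hyper.fit()` expands the fit settings from the 'training' section into keras `fit()` kwargs,
    # typically epochs, batch_size and callbacks (an assumption about common entries, not a fixed schema).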
# Run keras model-fit and take time for training.
start = time.time()
hist = model.fit(
x_train, y_train,
validation_data=(x_test, y_test),
**hyper.fit()
)
stop = time.time()
print("Print Time for training: '%s'." % str(timedelta(seconds=stop - start)))
# Save history for this fold.
save_pickle_file(hist.history, os.path.join(filepath, f"history{postfix_file}_fold_{current_split}.pickle"))
save_pickle_file(str(timedelta(seconds=stop - start)),
os.path.join(filepath, f"time{postfix_file}_fold_{current_split}.pickle"))
# Plot prediction for the last split.
# Note that predicted values will not be rescaled.
predicted_y = model.predict(x_test)
true_y = y_test
    # Plot the predicted vs. true test targets for the last split. Note that for classification this is also
    # done, but can be ignored.
plot_predict_true(predicted_y, true_y,
filepath=filepath, data_unit=label_units,
model_name=hyper.model_name, dataset_name=hyper.dataset_class, target_names=label_names,
file_name=f"predict{postfix_file}_fold_{current_split}.png", show_fig=False,
scaled_predictions=scaled_predictions)
    # Save the last keras-model to the output-folder.
    model.save(os.path.join(filepath, f"model{postfix_file}_fold_{current_split}.keras"))
    # Save the weights of the last keras-model to the output-folder.
    model.save_weights(os.path.join(filepath, f"model{postfix_file}_fold_{current_split}.weights.h5"))
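    # To reload the saved model for inference, something like the following should work, assuming
    # kgcnn's custom objects are registered with keras (illustrative, not executed here):
    #   model = ks.saving.load_model(os.path.join(filepath, f"model{postfix_file}_fold_{current_split}.keras"))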
# Plot training- and test-loss vs epochs for all splits.
history_list = load_history_list(os.path.join(filepath, f"history{postfix_file}_fold_(i).pickle"), current_split + 1)
plot_train_test_loss(history_list, loss_name=None, val_loss_name=None,
model_name=hyper.model_name, data_unit=label_units, dataset_name=hyper.dataset_class,
filepath=filepath, file_name=f"loss{postfix_file}.png")
# Save original data indices of the splits.
np.savez(os.path.join(filepath, f"{hyper.model_name}_test_indices_{postfix_file}.npz"), *test_indices_all)
np.savez(os.path.join(filepath, f"{hyper.model_name}_train_indices_{postfix_file}.npz"), *train_indices_all)
# Save the hyperparameter that was used for this fit, in '.json' format.
# If non-serializable parameters were in the hyperparameter config file, this operation may fail.
hyper.save(os.path.join(filepath, f"{hyper.model_name}_hyper{postfix_file}.json"))
# Save the scores of the fit results as a '.yaml' text file.
time_list = load_time_list(os.path.join(filepath, f"time{postfix_file}_fold_(i).pickle"), current_split + 1)
save_history_score(
history_list, loss_name=None, val_loss_name=None,
model_name=hyper.model_name, data_unit=label_units, dataset_name=hyper.dataset_class,
model_class=hyper.model_class,
multi_target_indices=hyper["training"]["multi_target_indices"] if "multi_target_indices" in hyper[
"training"] else None,
execute_folds=execute_folds,
model_version=model.__kgcnn_model_version__ if hasattr(model, "__kgcnn_model_version__") else "",
filepath=filepath, file_name=f"score{postfix_file}.yaml", time_list=time_list,
seed=args["seed"]
)