diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index 89c6e431..a6474a24 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -219,7 +219,7 @@ def test_fit(self) -> None: iters = 5 out_span = 2 scores = train.fit(dataset, dataset, features, iters, weights_file_path, - log_file_path, out_span) + log_file_path, out_span, None) with open(weights_file_path) as f: weights = [ line.split('\t') for line in f.read().splitlines() if line.strip() diff --git a/scripts/train.py b/scripts/train.py index 4f6a148d..51cf2450 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -22,6 +22,7 @@ import jax import jax.numpy as jnp +from torch.utils.tensorboard import SummaryWriter EPS = float(jnp.finfo(float).eps) DEFAULT_OUTPUT_NAME = 'weights.txt' @@ -233,7 +234,8 @@ def update(w: jax.Array, scores: jax.Array, rows: jax.Array, cols: jax.Array, def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset], features: typing.List[str], iters: int, weights_filename: str, - log_filename: str, out_span: int) -> jax.Array: + log_filename: str, out_span: int, + tensorboard_log_dir: typing.Optional[str]) -> jax.Array: """Trains an AdaBoost binary classifier. Args: @@ -244,10 +246,14 @@ def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset], weights_filename (str): A file path to write the learned weights. log_filename (str): A file path to log the accuracy along with training. out_span (int): Iteration span to output metics and weights. + tensorboard_log_dir (Optional[str]): A file path to log data for + TensorBoard. Returns: scores (jax.Array): The contribution scores. """ + writer = SummaryWriter( + log_dir=tensorboard_log_dir) if tensorboard_log_dir else None with open(weights_filename, 'w') as f: f.write('') with open(log_filename, 'w') as f: @@ -290,6 +296,11 @@ def output_progress(t: int) -> None: metrics_train.recall, metrics_train.fscore, )) + if writer: + writer.add_scalar('accuracy/train', metrics_train.accuracy, t) + writer.add_scalar('precision/train', metrics_train.precision, t) + writer.add_scalar('recall/train', metrics_train.recall, t) + writer.add_scalar('fscore/train', metrics_train.fscore, t) if dataset_val: pred_test = pred(scores, dataset_val.X_rows, dataset_val.X_cols, N_test) @@ -306,8 +317,14 @@ def output_progress(t: int) -> None: metrics_test.recall, metrics_test.fscore, )) - + if writer: + writer.add_scalar('accuracy/test', metrics_test.accuracy, t) + writer.add_scalar('precision/test', metrics_test.precision, t) + writer.add_scalar('recall/test', metrics_test.recall, t) + writer.add_scalar('fscore/test', metrics_test.fscore, t) f.write('\n') + if writer: + writer.add_histogram('weight', w, t) for t in range(iters): w, scores, best_feature_index, score = update(w, scores, @@ -320,6 +337,8 @@ def output_progress(t: int) -> None: output_progress(t + 1) if len(feature_score_buffer) > 0: output_progress(t + 1) + if writer: + writer.close() return scores @@ -364,6 +383,11 @@ def parse_args(test: ArgList = None) -> argparse.Namespace: default=DEFAULT_OUT_SPAN) parser.add_argument( '--val-data', help='File path for the encoded validation data.', type=str) + parser.add_argument( + '--tensorboard', + help='Log directory for TensorBoard.', + type=str, + default=None) if test is None: return parser.parse_args() else: @@ -379,11 +403,12 @@ def main() -> None: iterations = int(args.iter) out_span = int(args.out_span) val_data: typing.Optional[str] = args.val_data + tensorboard_log_dir: typing.Optional[str] = args.tensorboard dataset_train, features, dataset_val = preprocess(data_filename, feature_thres, val_data) fit(dataset_train, dataset_val, features, iterations, weights_filename, - log_filename, out_span) + log_filename, out_span, tensorboard_log_dir) print('Training done. Export the model by passing %s to build_model.py' % (weights_filename)) diff --git a/setup.cfg b/setup.cfg index f926a3e5..56eaac4d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,9 @@ dev = mypy==1.18.2 pytest regex + tensorboard toml + torch twine types-regex types-setuptools