Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_fit(self) -> None:
iters = 5
out_span = 2
scores = train.fit(dataset, dataset, features, iters, weights_file_path,
log_file_path, out_span)
log_file_path, out_span, None)
with open(weights_file_path) as f:
weights = [
line.split('\t') for line in f.read().splitlines() if line.strip()
Expand Down
31 changes: 28 additions & 3 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import jax
import jax.numpy as jnp
from torch.utils.tensorboard import SummaryWriter

EPS = float(jnp.finfo(float).eps)
DEFAULT_OUTPUT_NAME = 'weights.txt'
Expand Down Expand Up @@ -233,7 +234,8 @@ def update(w: jax.Array, scores: jax.Array, rows: jax.Array, cols: jax.Array,

def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset],
features: typing.List[str], iters: int, weights_filename: str,
log_filename: str, out_span: int) -> jax.Array:
log_filename: str, out_span: int,
tensorboard_log_dir: typing.Optional[str]) -> jax.Array:
"""Trains an AdaBoost binary classifier.

Args:
Expand All @@ -244,10 +246,14 @@ def fit(dataset_train: Dataset, dataset_val: typing.Optional[Dataset],
weights_filename (str): A file path to write the learned weights.
log_filename (str): A file path to log the accuracy along with training.
out_span (int): Iteration span to output metics and weights.
tensorboard_log_dir (Optional[str]): A file path to log data for
TensorBoard.

Returns:
scores (jax.Array): The contribution scores.
"""
writer = SummaryWriter(
log_dir=tensorboard_log_dir) if tensorboard_log_dir else None
with open(weights_filename, 'w') as f:
f.write('')
with open(log_filename, 'w') as f:
Expand Down Expand Up @@ -290,6 +296,11 @@ def output_progress(t: int) -> None:
metrics_train.recall,
metrics_train.fscore,
))
if writer:
writer.add_scalar('accuracy/train', metrics_train.accuracy, t)
writer.add_scalar('precision/train', metrics_train.precision, t)
writer.add_scalar('recall/train', metrics_train.recall, t)
writer.add_scalar('fscore/train', metrics_train.fscore, t)

if dataset_val:
pred_test = pred(scores, dataset_val.X_rows, dataset_val.X_cols, N_test)
Expand All @@ -306,8 +317,14 @@ def output_progress(t: int) -> None:
metrics_test.recall,
metrics_test.fscore,
))

if writer:
writer.add_scalar('accuracy/test', metrics_test.accuracy, t)
writer.add_scalar('precision/test', metrics_test.precision, t)
writer.add_scalar('recall/test', metrics_test.recall, t)
writer.add_scalar('fscore/test', metrics_test.fscore, t)
f.write('\n')
if writer:
writer.add_histogram('weight', w, t)

for t in range(iters):
w, scores, best_feature_index, score = update(w, scores,
Expand All @@ -320,6 +337,8 @@ def output_progress(t: int) -> None:
output_progress(t + 1)
if len(feature_score_buffer) > 0:
output_progress(t + 1)
if writer:
writer.close()
return scores


Expand Down Expand Up @@ -364,6 +383,11 @@ def parse_args(test: ArgList = None) -> argparse.Namespace:
default=DEFAULT_OUT_SPAN)
parser.add_argument(
'--val-data', help='File path for the encoded validation data.', type=str)
parser.add_argument(
'--tensorboard',
help='Log directory for TensorBoard.',
type=str,
default=None)
if test is None:
return parser.parse_args()
else:
Expand All @@ -379,11 +403,12 @@ def main() -> None:
iterations = int(args.iter)
out_span = int(args.out_span)
val_data: typing.Optional[str] = args.val_data
tensorboard_log_dir: typing.Optional[str] = args.tensorboard

dataset_train, features, dataset_val = preprocess(data_filename,
feature_thres, val_data)
fit(dataset_train, dataset_val, features, iterations, weights_filename,
log_filename, out_span)
log_filename, out_span, tensorboard_log_dir)
print('Training done. Export the model by passing %s to build_model.py' %
(weights_filename))

Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ dev =
mypy==1.18.2
pytest
regex
tensorboard
toml
torch
twine
types-regex
types-setuptools
Expand Down