-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
executable file
·69 lines (50 loc) · 2.44 KB
/
evaluate.py
File metadata and controls
executable file
·69 lines (50 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
"""
DSLR - Model Evaluation
Train a one-vs-all logistic regression model on the training set,
predict houses for the test set, and score those predictions against
the ground-truth labels in dataset_truth.csv.
Usage: python evaluate.py [options]
"""
import argparse
import os
import numpy as np
import pandas as pd
from logreg_predict import predict
from logreg_train import train
def accuracy_score(y_true, y_pred) -> float:
    """Calculate the accuracy of predictions.

    Labels are compared position by position (not by pandas index label), so
    plain lists, NumPy arrays and pandas Series are all accepted.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels; must have the same length as ``y_true``.

    Returns:
        Fraction of positions where ``y_pred`` equals ``y_true``, in [0, 1].

    Raises:
        ValueError: If the inputs have different lengths or are empty.
    """
    # asarray drops any pandas index so two Series with different indexes
    # are compared by position instead of raising an alignment error, and
    # lists get element-wise comparison instead of a single bool.
    true_arr = np.asarray(y_true)
    pred_arr = np.asarray(y_pred)
    if len(true_arr) != len(pred_arr):
        raise ValueError(
            f"length mismatch: {len(true_arr)} true labels vs "
            f"{len(pred_arr)} predictions"
        )
    if len(true_arr) == 0:
        raise ValueError("cannot compute accuracy of empty label arrays")
    # np.mean over the boolean mask runs in C, unlike the builtin sum().
    return float(np.mean(true_arr == pred_arr))
def evaluate(train_path, test_path, truth_path, weights_path, output_folder, config_path, v=False):
    """Evaluate the model by training on the training set and predicting on the test set.

    Steps:
      1. ``train`` is called with the training CSV, weights path, config path and ``v``.
      2. ``predict`` is called with the test CSV, weights path, output folder and config path.
      3. ``houses.csv`` is read from ``output_folder``, and its "Hogwarts House" column
         is compared row by row with the same column of ``truth_path``.

    Args:
        train_path: Path to the training CSV.
        test_path: Path to the test CSV.
        truth_path: Path to the ground-truth CSV (needs a "Hogwarts House" column).
        weights_path: Where ``train`` writes / ``predict`` reads the weights.
        output_folder: Folder in which ``predict`` is expected to write ``houses.csv``.
        config_path: Path to the config file forwarded to ``train`` and ``predict``.
        v: Forwarded to ``train`` (set by the ``-v`` "visualize training" CLI flag).

    Returns:
        The accuracy as a float in [0, 1], or ``None`` if any step failed. Errors
        are printed rather than raised (best-effort behavior kept from the original).
    """
    try:
        print("Training:")
        train(train_path, weights_path, config_path, v)
        print("Predicting:")
        predict(test_path, weights_path, output_folder, config_path)

        pred = pd.read_csv(os.path.join(output_folder, "houses.csv"))
        true = pd.read_csv(truth_path)
        # Compare positionally. This assumes both files list rows in the same
        # order -- NOTE(review): confirm against logreg_predict's output format.
        y_pred = pred["Hogwarts House"].to_numpy()
        y_true = true["Hogwarts House"].to_numpy()
        if len(y_pred) != len(y_true):
            # Without this check pandas fails with a cryptic
            # "identically-labeled Series" error.
            raise ValueError(
                f"{len(y_pred)} predictions but {len(y_true)} ground-truth labels"
            )

        print("Wrong predictions:", np.sum(y_true != y_pred))
        accuracy = accuracy_score(y_true, y_pred)
        print("Accuracy:", np.round(accuracy, 4))
        return accuracy
    except Exception as e:
        print(f"Error: {e}")
        return None
def main():
    """Read command-line options and run the train -> predict -> score pipeline.

    All path options default to locations relative to the current working
    directory. ``-v`` is a boolean switch forwarded to ``evaluate``. Ctrl-C
    and any other error are reported with a message instead of a traceback.
    """
    # (flag, default value, help text) for each string-valued path option,
    # listed in the order they appear in --help.
    path_options = (
        ("--train_path", "datasets/dataset_train.csv", 'Path to "dataset_train.csv" file'),
        ("--test_path", "datasets/dataset_test.csv", 'Path to "dataset_test.csv" file'),
        ("--truth_path", "datasets/dataset_truth.csv", 'Path to "dataset_truth.csv" file'),
        ("--weights_path", "datasets/weight.json", "Path to save weights file"),
        ("--output_folder", "data", "Path to folder where to save houses.csv"),
        ("--config_path", "config.yaml", "Path to .yaml file"),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value, description in path_options:
        cli.add_argument(flag, type=str, default=default_value, help=description)
    cli.add_argument("-v", action="store_true", help="visualize training")
    opts = cli.parse_args()

    try:
        evaluate(
            opts.train_path,
            opts.test_path,
            opts.truth_path,
            opts.weights_path,
            opts.output_folder,
            opts.config_path,
            opts.v,
        )
    except KeyboardInterrupt:
        print("\nProcess interrupted by user.")
    except Exception as err:
        print(f"Error: {err}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()