-
Notifications
You must be signed in to change notification settings - Fork 12.2k
/
Copy patheval.py
65 lines (52 loc) · 1.81 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Evaluation job for the heavy ranker of the push notification service.
"""
from datetime import datetime
from tensorflow.compat.v1 import logging
import twml
from ..libs.metric_fn_utils import get_metric_fn
from ..libs.model_utils import read_config
from .features import get_feature_config
from .model_pools import ALL_MODELS
from .params import load_graph_params
from .run_args import get_eval_arg_parser
def main():
    """Evaluate the heavy ranker once and log how long the evaluation took.

    Steps, in order:

    1. Parse the command-line options with the parser returned by
       ``get_eval_arg_parser``; the leftover arguments returned by
       ``parse_known_args`` are discarded.
    2. Load the graph params and build the feature config from
       ``args.data_spec`` and the entries read from ``args.feature_list``.
    3. Build a ``twml.trainers.DataRecordTrainer`` named ``"magic_recs"``.
    4. Call ``evaluate`` on the trainer's estimator for the checkpoint given
       by ``args.eval_checkpoint`` and log the elapsed wall-clock time.
    """
    args, _ = get_eval_arg_parser().parse_known_args()
    logging.info(f"Parsed args: {args}")

    params = load_graph_params(args)
    logging.info(f"Loaded graph params: {params}")

    logging.info(f"Get Feature Config: {args.feature_list}")
    feature_list = read_config(args.feature_list).items()
    feature_config = get_feature_config(
        data_spec_path=args.data_spec,
        params=params,
        feature_list_provided=feature_list,
    )

    logging.info("Build DataRecordTrainer.")
    # With exactly two tasks the "OONC_Engagement" metric is requested,
    # otherwise the plain "OONC" one.
    if len(params.tasks) == 2:
        metric_name = "OONC_Engagement"
    else:
        metric_name = "OONC"
    metric_fn = get_metric_fn(metric_name, False)

    def build_graph(*graph_args):
        # The model is looked up in ALL_MODELS and instantiated on every call,
        # then invoked with whatever positional arguments the trainer passes.
        return ALL_MODELS[params.model.name](params=params)(*graph_args)

    trainer = twml.trainers.DataRecordTrainer(
        name="magic_recs",
        params=args,
        build_graph_fn=build_graph,
        save_dir=args.save_dir,
        run_config=None,
        feature_config=feature_config,
        metric_fn=metric_fn,
    )

    logging.info("Run the evaluation.")
    start = datetime.now()
    evaluate = trainer._estimator.evaluate
    eval_input_fn = trainer.get_eval_input_fn(repeat=False, shuffle=False)
    # A negative step count is forwarded as None; any other value (including
    # None itself) is passed through unchanged.
    # NOTE(review): presumably None means "run until the input is exhausted",
    # as in tf.estimator -- confirm against the twml estimator in use.
    eval_steps = args.eval_steps
    if eval_steps is not None and eval_steps < 0:
        eval_steps = None
    evaluate(
        input_fn=eval_input_fn,
        steps=eval_steps,
        checkpoint_path=args.eval_checkpoint,
    )
    logging.info(f"Evaluating time: {datetime.now() - start}.")
if __name__ == "__main__":
    # Script entry point: run the evaluation job, then log that it finished.
    main()
    logging.info("Job done.")