-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
60 lines (43 loc) · 1.76 KB
/
test.py
File metadata and controls
60 lines (43 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import os
from utils import (
collect_env_info,
evaluate,
get_logger,
get_rank,
get_world_size,
init_logger_with_buffer,
instantiate,
load_package_from_file,
load_state_dict,
setup_project_logger,
)
# Module-level logger for this script, resolved by module name via the
# project's logging utilities (see utils.get_logger).
logger = get_logger(__name__)
def parse_args():
    """Parse command-line arguments for the detector test script.

    Returns:
        argparse.Namespace with:
            config_file: path to the config module; defaults to
                "configs/train_config.py" when omitted.
            checkpoint: optional path to a model checkpoint (None if unset).
    """
    parser = argparse.ArgumentParser(description="Test a detector")
    # nargs="?" makes the positional optional so `default` actually applies;
    # without it argparse requires the argument and silently ignores the default.
    parser.add_argument("config_file", nargs="?", default="configs/train_config.py")
    parser.add_argument("--checkpoint", default=None)
    args = parser.parse_args()
    return args
def test():
    """Evaluate a detector built from a command-line-specified config.

    Loads the config module, sets up distributed-aware logging and the
    accelerator, instantiates the model (optionally restoring weights from
    a checkpoint), and runs evaluation on the test dataloader.
    """
    opts = parse_args()

    # The output directory is only known after the config is loaded, so
    # buffer early log records until the file handler can be attached.
    init_logger_with_buffer()

    local_rank = get_rank()
    logger.info("Rank of current process: {}, World size: {}".format(local_rank, get_world_size()))
    logger.info("Environment info: \n" + collect_env_info())

    cfg = load_package_from_file(opts.config_file)

    # Build the accelerator from the config.
    # cfg.trainer.accelerator.mixed_precision = "no" # disable mixed precision for testing
    # cfg.trainer.accelerator.use_fsdp = False # disable fsdp for testing
    accel = instantiate(cfg.trainer.accelerator)

    # Each rank writes its own log file inside the project directory.
    run_dir = accel.project_configuration.project_dir
    log_path = os.path.join(run_dir, "test.log" if local_rank == 0 else f"test_rank{local_rank}.log")
    setup_project_logger(log_path)

    detector = instantiate(cfg.model)
    if opts.checkpoint is not None:
        detector = load_state_dict(detector, opts.checkpoint, map_location="cpu")

    detector, loader = accel.prepare(detector, instantiate(cfg.dataloader.test))
    evaluate(detector, loader, instantiate(cfg.dataloader.evaluator), accel, instantiate(cfg.processor))
# Script entry point: run evaluation when invoked directly.
if __name__ == "__main__":
    test()