Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/benchmarks/TRA/src/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
self.logger = get_module_logger("TRA")
self.logger.info("TRA Model...")

self.model = eval(model_type)(**model_config).to(device)
self.model = ast.literal_eval(model_type)(**model_config).to(device)
if model_init_state:
self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"])
if freeze_model:
Expand Down