Skip to content

Commit

Permalink
fix(calculator): fix mis-named variable
Browse files Browse the repository at this point in the history
  • Loading branch information
bi-ran committed Feb 8, 2025
1 parent 2ab2c4d commit 8f1138e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/AIMD/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def init(argv=None):
torch.cuda.device_count(),
_args.chunk_size,
)
_args.mm_method = strategy_feedback['preprocess-method']
_args.mm_method = strategy_feedback['mm-method']

return _args

10 changes: 5 additions & 5 deletions src/Calculators/device_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_work_partitions(cls):


@classmethod
def initialize(cls, dev_strategy: str, work_strategy: str, preprocess_method: str, gpu_count: int, chunk_size: int):
def initialize(cls, dev_strategy: str, work_strategy: str, mm_method: str, gpu_count: int, chunk_size: int):
cls._gpu_count = gpu_count
cls._dev_strategy = dev_strategy
cls._work_strategy = work_strategy
Expand All @@ -160,7 +160,7 @@ def initialize(cls, dev_strategy: str, work_strategy: str, preprocess_method: st
default = "cpu" if gpu_count == 0 else "cuda:0"
optimiser = "cpu"

if preprocess_method == "tinker-GPU" and gpu_count > 0:
if mm_method == "tinker-GPU" and gpu_count > 0:
preprocess = f"cuda:{last_gpu}"
else:
preprocess = "cpu"
Expand Down Expand Up @@ -216,11 +216,11 @@ def initialize(cls, dev_strategy: str, work_strategy: str, preprocess_method: st
# run bonded/non-bonded calculations concurrently
cls._fragment_strategy = True

if preprocess_method == "tinker-GPU":
if mm_method == "tinker-GPU":
if len(solvent) == 0:
logging.error("tinker-GPU is specified, but there's no GPU. Reverting back to CPU.")
solvent = ["cpu"]
preprocess_method = "tinker"
mm_method = "tinker"
else:
solvent = ["cpu"]

Expand Down Expand Up @@ -262,4 +262,4 @@ def initialize(cls, dev_strategy: str, work_strategy: str, preprocess_method: st
torch_threads = max(1, total_threads // total_models)
torch.set_num_threads(torch_threads)

return { 'preprocess-method': preprocess_method }
return { 'mm-method': mm_method }

0 comments on commit 8f1138e

Please sign in to comment.