# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import multiprocessing
import os
import sys

import numpy as np
import torch

from deepmd.common import (
    VALID_PRECISION,
)
from deepmd.env import (
    GLOBAL_ENER_FLOAT_PRECISION,
    GLOBAL_NP_FLOAT_PRECISION,
    get_default_nthreads,
    set_default_nthreads,
)

log = logging.getLogger(__name__)

if sys.platform != "win32":
    try:
        multiprocessing.set_start_method("fork", force=True)
        log.debug("Successfully set multiprocessing start method to 'fork'.")
    except (RuntimeError, ValueError) as err:
        log.warning(f"Could not set multiprocessing start method: {err}")
else:
    log.debug("Skipping fork start method on Windows (not supported).")

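# Runtime switches read from the environment: SAMPLER_RECORD and
# DP_DTYPE_PROMOTION_STRICT are "0"/"1" string flags, and NUM_WORKERS sets the
# number of DataLoader worker processes requested below.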
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", "0") == "1"
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"
try:
    # os.sched_getaffinity is only available on Linux
    ncpus = len(os.sched_getaffinity(0))
except AttributeError:
    ncpus = os.cpu_count() or 1
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
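# NUM_WORKERS is intended for torch.utils.data.DataLoader(num_workers=...); with
# a non-fork start method the loader has to fall back to single-process loading.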
if NUM_WORKERS > 0 and multiprocessing.get_start_method() != "fork":
    # the spawn and forkserver start methods do not support NUM_WORKERS > 0
    # for the DataLoader
    log.warning(
        "NUM_WORKERS > 0 is not supported with the spawn or forkserver start "
        "method. Setting NUM_WORKERS to 0."
    )
    NUM_WORKERS = 0

# Make sure DDP uses the correct device if applicable
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0"))
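# Distributed launchers such as torchrun set LOCAL_RANK per process, so each
# rank selects its own GPU below; single-process runs fall back to rank 0.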

if os.environ.get("DEVICE") == "cpu" or not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
else:
    DEVICE = torch.device(f"cuda:{LOCAL_RANK}")

JIT = False
CACHE_PER_SYS = 5  # keep at most this many sets per system in memory
ENERGY_BIAS_TRAINABLE = True
CUSTOM_OP_USE_JIT = False

PRECISION_DICT = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "half": torch.float16,
    "single": torch.float32,
    "double": torch.float64,
    "int32": torch.int32,
    "int64": torch.int64,
    "bfloat16": torch.bfloat16,
    "bool": torch.bool,
}
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
    np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
]
PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
# The reverse mapping cannot be generated automatically because PRECISION_DICT
# contains aliases ("half", "single", "double", "default").
RESERVED_PRECISION_DICT = {
    torch.float16: "float16",
    torch.float32: "float32",
    torch.float64: "float64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.bfloat16: "bfloat16",
    torch.bool: "bool",
}
assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys())
DEFAULT_PRECISION = "float64"
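# Round-trip example: PRECISION_DICT["float32"] is torch.float32 and
# RESERVED_PRECISION_DICT[torch.float32] recovers the canonical name "float32";
# aliases such as "single" or "half" resolve to the same canonical names
# through this reverse lookup.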

# warn if the number of threads is not set
set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
if inter_nthreads > 0:  # the behavior of 0 is not documented
    torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
    torch.set_num_threads(intra_nthreads)
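# torch.set_num_threads controls intra-op parallelism (threads used inside a
# single operator), while torch.set_num_interop_threads controls inter-op
# parallelism across independent operators.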

__all__ = [
    "CACHE_PER_SYS",
    "CUSTOM_OP_USE_JIT",
    "DEFAULT_PRECISION",
    "DEVICE",
    "ENERGY_BIAS_TRAINABLE",
    "GLOBAL_ENER_FLOAT_PRECISION",
    "GLOBAL_NP_FLOAT_PRECISION",
    "GLOBAL_PT_ENER_FLOAT_PRECISION",
    "GLOBAL_PT_FLOAT_PRECISION",
    "JIT",
    "LOCAL_RANK",
    "NUM_WORKERS",
    "PRECISION_DICT",
    "RESERVED_PRECISION_DICT",
    "SAMPLER_RECORD",
]