-
Notifications
You must be signed in to change notification settings - Fork 589
refactor: unify learning rate schedulers with array API #5154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
3c650bb
823061c
5d574de
ec0d270
cb9a11b
0ea7387
c42a5e4
69aaa81
c31d001
65a44dd
519440f
f2e3888
d97855b
cb2e4d1
b687e36
75b81fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
|
|
||
| from .env import ( | ||
| DEVICE, | ||
| GLOBAL_NP_FLOAT_PRECISION, | ||
| ) | ||
| from .env import PRECISION_DICT as PT_PRECISION_DICT | ||
|
|
||
|
|
@@ -227,18 +228,22 @@ def to_numpy_array(xx: None) -> None: ... | |
|
|
||
|
|
||
| def to_numpy_array( | ||
| xx: torch.Tensor | None, | ||
| xx: torch.Tensor | np.ndarray | float | None, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Need to update the overload method in line 223. |
||
| ) -> np.ndarray | None: | ||
| if xx is None: | ||
| return None | ||
| assert xx is not None | ||
| if isinstance(xx, (float, int)): | ||
| return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION) | ||
| if isinstance(xx, np.ndarray): | ||
| return xx.astype(GLOBAL_NP_FLOAT_PRECISION) | ||
OutisLi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Create a reverse mapping of PT_PRECISION_DICT | ||
| reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()} | ||
| # Use the reverse mapping to find keys with the desired value | ||
| prec = reverse_precision_dict.get(xx.dtype, None) | ||
| prec = NP_PRECISION_DICT.get(prec, None) | ||
| if prec is None: | ||
| raise ValueError(f"unknown precision {xx.dtype}") | ||
| assert isinstance(xx, torch.Tensor) | ||
| if xx.dtype == torch.bfloat16: | ||
| # https://github.com/pytorch/pytorch/issues/109873 | ||
| xx = xx.float() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
| import os | ||
| import shutil | ||
| import time | ||
| from typing import ( | ||
| Any, | ||
| ) | ||
|
|
||
| import google.protobuf.message | ||
| import numpy as np | ||
|
|
@@ -52,7 +55,7 @@ | |
| load_graph_def, | ||
| ) | ||
| from deepmd.tf.utils.learning_rate import ( | ||
| LearningRateExp, | ||
| LearningRateSchedule, | ||
| ) | ||
| from deepmd.tf.utils.sess import ( | ||
| run_sess, | ||
|
|
@@ -100,21 +103,18 @@ def _init_param(self, jdata) -> None: | |
| self.model = Model(**model_param) | ||
| self.fitting = self.model.get_fitting() | ||
|
|
||
| def get_lr_and_coef(lr_param): | ||
| def get_lr_and_coef( | ||
| lr_param: dict[str, Any], | ||
| ) -> tuple[LearningRateSchedule, float]: | ||
| scale_by_worker = lr_param.get("scale_by_worker", "linear") | ||
| if scale_by_worker == "linear": | ||
| scale_lr_coef = float(self.run_opt.world_size) | ||
| elif scale_by_worker == "sqrt": | ||
| scale_lr_coef = np.sqrt(self.run_opt.world_size).real | ||
| else: | ||
| scale_lr_coef = 1.0 | ||
| lr_type = lr_param.get("type", "exp") | ||
| if lr_type == "exp": | ||
| lr = LearningRateExp( | ||
| lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"] | ||
| ) | ||
| else: | ||
| raise RuntimeError("unknown learning_rate type " + lr_type) | ||
| lr_params = {k: v for k, v in lr_param.items() if k != "scale_by_worker"} | ||
| lr = LearningRateSchedule(lr_params) | ||
| return lr, scale_lr_coef | ||
OutisLi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # learning rate | ||
|
|
@@ -242,8 +242,13 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> Non | |
| def _build_lr(self) -> None: | ||
| self._extra_train_ops = [] | ||
| self.global_step = tf.train.get_or_create_global_step() | ||
| self.learning_rate = self.lr.build(self.global_step, self.stop_batch) | ||
| log.info("built lr") | ||
| if self.stop_batch == 0: | ||
| # Use constant start_lr when stop_batch is zero (no training) | ||
| self.learning_rate = tf.cast(self.lr.start_lr(), GLOBAL_TF_FLOAT_PRECISION) | ||
| log.info("built lr (constant start_lr for stop_batch=0)") | ||
OutisLi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else: | ||
| self.learning_rate = self.lr.build(self.global_step, self.stop_batch) | ||
| log.info("built lr") | ||
|
Comment on lines
+245
to
+251
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why does it need an if...else block? |
||
|
|
||
| def _build_loss(self): | ||
| if self.stop_batch == 0: | ||
|
|
@@ -426,14 +431,21 @@ def train(self, train_data=None, valid_data=None) -> None: | |
| elapsed_batch = stop_batch - start_batch | ||
| is_first_step = True | ||
| self.cur_batch = cur_batch | ||
| log.info( | ||
| "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e", | ||
| run_sess(self.sess, self.learning_rate), | ||
| self.lr.value(cur_batch), | ||
| self.lr.decay_steps_, | ||
| self.lr.decay_rate_, | ||
| self.lr.value(stop_batch), | ||
| ) | ||
| if stop_batch == 0: | ||
| lr0 = self.lr.start_lr() | ||
| log.info( | ||
| "start training at lr %.2e (== %.2e), final lr will be %.2e", | ||
| run_sess(self.sess, self.learning_rate), | ||
| lr0, | ||
| lr0, | ||
| ) | ||
| else: | ||
| log.info( | ||
| "start training at lr %.2e (== %.2e), final lr will be %.2e", | ||
| run_sess(self.sess, self.learning_rate), | ||
| self.lr.value(cur_batch), | ||
| self.lr.value(stop_batch), | ||
| ) | ||
|
Comment on lines
+434
to
+448
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why does it need an if...else block? |
||
|
|
||
| prf_options = None | ||
| prf_run_metadata = None | ||
|
|
@@ -797,7 +809,7 @@ def _get_place_holders(self, data_dict) -> None: | |
| prec = GLOBAL_ENER_FLOAT_PRECISION | ||
| self.place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk) | ||
| self.place_holders["find_" + kk] = tf.placeholder( | ||
| tf.float32, name="t_find_" + kk | ||
| GLOBAL_TF_FLOAT_PRECISION, name="t_find_" + kk | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the reason to conduct this change in this PR? |
||
| ) | ||
|
|
||
| def _init_from_frz_model(self) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 128
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 1144
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 1759
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 153
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 2883
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 49
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 98
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 2464
🏁 Script executed:
python3 -c "import numpy as np; scalar = np.float64(1.0); print('has dtype:', hasattr(scalar, 'dtype')); print('has numpy():', hasattr(scalar, 'numpy')); print('is np.generic:', isinstance(scalar, np.generic))"Repository: deepmodeling/deepmd-kit
Length of output: 185
🏁 Script executed:
cat -n source/tests/pd/test_utils.py | head -100Repository: deepmodeling/deepmd-kit
Length of output: 1323
🏁 Script executed:
rg -A 10 -B 2 "to_numpy_array" source/tests/pd/test_utils.pyRepository: deepmodeling/deepmd-kit
Length of output: 686
🏁 Script executed:
rg -A 15 "test_to_numpy" source/tests/pd/test_utils.pyRepository: deepmodeling/deepmd-kit
Length of output: 684
🏁 Script executed:
cat -n source/tests/pd/test_utils.py | tail -20Repository: deepmodeling/deepmd-kit
Length of output: 992
🏁 Script executed:
rg "to_numpy_array.*np\." deepmd/pd --type py -A 2Repository: deepmodeling/deepmd-kit
Length of output: 200
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 274
🏁 Script executed:
rg "np\.generic" deepmd --type pyRepository: deepmodeling/deepmd-kit
Length of output: 49
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 717
🏁 Script executed:
Repository: deepmodeling/deepmd-kit
Length of output: 1390
🏁 Script executed:
rg "GLOBAL_NP_FLOAT_PRECISION" deepmd/pd/utils/utils.py -B 5 -A 5Repository: deepmodeling/deepmd-kit
Length of output: 664
Handle NumPy scalar inputs in the scalar fast-path.
Line 261 currently excludes NumPy scalar types (e.g.,
np.float32, np.float64), which will fall through and hit xx.numpy() at line 274, raising AttributeError. Add np.generic to catch all NumPy scalar types. 🛠️ Suggested fix
📝 Committable suggestion
🤖 Prompt for AI Agents