feat: new backend pytorch exportable. #5194

wanghan-iapcm wants to merge 27 commits into deepmodeling:master from …
Conversation
Summary of Changes

Hello @wanghan-iapcm, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the DeepMD framework by integrating a new PyTorch exportable backend. This integration allows DeepMD models, starting with the 'se_e2_a' descriptor, to leverage PyTorch's capabilities for model export and deployment. The changes involve creating PyTorch-specific implementations of core components like descriptors, network layers, and utility functions, ensuring seamless operation within a PyTorch environment. The addition of this backend broadens the interoperability of DeepMD models with other machine learning tools and workflows.
Code Review
This pull request introduces a new pytorch-exportable backend, a significant feature for model deployment. The implementation is comprehensive, including the backend definition, a new se_e2_a descriptor, and thorough unit tests. The code is generally well-structured and follows existing patterns. I've identified a couple of potential issues concerning PyTorch buffer registration within `__setattr__` methods, which could affect model state management. Additionally, I've pointed out an unused variable and a minor inconsistency in the test suite that could be improved. Overall, this is a valuable contribution.
```python
DescrptSeADP.__init__(self, *args, **kwargs)
self._convert_state()

def __setattr__(self, name: str, value: Any) -> None:
```

Check notice (Code scanning / CodeQL, Note): Explicit returns mixed with implicit (fall through) returns
```python
log.debug("Skipping fork start method on Windows (not supported).")

SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"
```

Check notice (Code scanning / CodeQL, Note): Unused global variable
```python
delattr(self, name)
self.register_buffer(name, val)

def __setattr__(self, name: str, value: Any) -> None:
```

Check notice (Code scanning / CodeQL, Note): Explicit returns mixed with implicit (fall through) returns
💡 Codex Review

Here are some automated review suggestions for this pull request.
Reviewed commit: 1cc001f7f2
📝 Walkthrough

Adds a new PyTorch Exportable backend ("pt-expt") and a PyTorch-based exportable descriptor/network package, makes device-aware fixes in DP descriptors and env/thread handling, and adds tests and test-harness wiring for the experimental PT backend.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Backend
    participant Descriptor
    participant Network
    participant Env
    Client->>Backend: register/select "pt-expt"
    Backend->>Descriptor: expose descriptor class & hooks
    Client->>Descriptor: instantiate / load state (serialize/deserialize)
    Descriptor->>Network: construct networks & params
    Network->>Env: allocate Parameters/Buffers on DEVICE
    Client->>Descriptor: call(coords, atype, nlist)
    Descriptor->>Network: run embedding & layers
    Network->>Env: ensure intermediates on DEVICE
    Descriptor-->>Client: return (descrpt, rot_mat, g2, h2, sw)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@deepmd/backend/pt_expt.py`:
- Line 44: The PT_EXPT backend defines suffixes: ClassVar[list[str]] = [".pth",
".pt"] which collides with the standard PyTorch backend and makes
Backend.detect_backend_by_model() choose backends non‑deterministically; update
the PT_EXPT backend’s suffixes (the suffixes attribute in the PT_EXPT class in
deepmd/backend/pt_expt.py) to use distinct extensions (for example [".pte"] or
another unique token) or add an explicit detection mechanism in
detect_backend_by_model to prefer PT_EXPT for files with your chosen marker;
ensure any references to suffixes in methods like load_model/export_model are
adjusted to the new extension.
In `@deepmd/pt_expt/descriptor/se_e2_a.py`:
- Around line 80-102: The forward signature currently has parameters in the
wrong order, causing torch.export.export to pass (coord, atype, nlist) into
forward as (nlist, extended_coord, extended_atype); change forward(...) to match
call(...)'s order: (extended_coord: torch.Tensor, extended_atype: torch.Tensor,
nlist: torch.Tensor, extended_atype_embd: torch.Tensor | None = None, mapping:
torch.Tensor | None = None, type_embedding: torch.Tensor | None = None) and
update the internal call to self.call(extended_coord, extended_atype, nlist,
mapping=mapping) (keep the del of unused vars if needed) so exported models
receive inputs in the same order as call(); a sketch follows this list.
In `@deepmd/pt_expt/utils/env.py`:
- Around line 95-99: The variables from get_default_nthreads() are unpacked in
reverse: get_default_nthreads() returns (intra_value, inter_value) but the code
assigns them to inter_nthreads, intra_nthreads; swap the unpacking so it reads
intra_nthreads, inter_nthreads = get_default_nthreads(), leaving the subsequent
conditionals and calls to torch.set_num_interop_threads(inter_nthreads) and
torch.set_num_threads(intra_nthreads) intact.
In `@source/tests/pt_expt/model/test_se_e2_a.py`:
- Around line 66-70: The test_exportable helper passes inputs in the wrong order
to torch.export.export() causing forward() to receive (nlist, extended_coord,
extended_atype) incorrectly; update the inputs tuple in test_exportable so it
matches forward(nlist, extended_coord, extended_atype) — i.e., construct inputs
as (torch.tensor(self.nlist, ...), torch.tensor(self.coord_ext, ...),
torch.tensor(self.atype_ext, ...)) so torch.export.export()/forward() and
NativeOP.__call__/dd0 receive arguments in the correct order.
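The second and fourth items above are two sides of the same ordering bug; fixing either direction suffices. Below is a minimal sketch of the forward-side fix, with the dpmodel `call()` stubbed out since the real implementation lives in the mixed-in base class; the annotations follow the comment text, not necessarily the PR:

```python
from __future__ import annotations

import torch


class DescrptSeA(torch.nn.Module):
    """Sketch only: the real class also mixes in the dpmodel descriptor."""

    def call(self, coord_ext, atype_ext, nlist, mapping=None):
        # Stand-in for the dpmodel implementation.
        raise NotImplementedError

    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        extended_atype_embd: torch.Tensor | None = None,
        mapping: torch.Tensor | None = None,
        type_embedding: torch.Tensor | None = None,
    ):
        # Unused in se_e2_a; accepted to keep a uniform descriptor interface.
        del extended_atype_embd, type_embedding
        # Same argument order as call(), so
        # torch.export.export(dd, (coord_ext, atype_ext, nlist)) lines up.
        return self.call(extended_coord, extended_atype, nlist, mapping=mapping)
```

With this order, the exported program and the consistency tests can pass `(coord_ext, atype_ext, nlist)` without any reshuffling.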
🧹 Nitpick comments (7)
source/tests/consistent/descriptor/test_se_e2_a.py (1)

37-40: `eval_pt_expt` in `TestSeAStat` depends on the PT backend's `env` and `torch` imports. Lines 546–566 use `torch` and `env.DEVICE`, which are imported at Lines 29–34 only under `INSTALLED_PT`. If `INSTALLED_PT_EXPT` is `True` but `INSTALLED_PT` is `False`, this would raise `NameError`. In practice both check for `torch` availability so they should always align, but consider importing `torch` and `DEVICE` from `deepmd.pt_expt.utils.env` within the `INSTALLED_PT_EXPT` block to decouple from the PT backend.

Also applies to: 546-566

source/tests/consistent/descriptor/common.py (1)

28-35: Guard condition correctly expanded, but imports still couple to the PT backend. The guard now correctly includes `INSTALLED_PT_EXPT`, but `PT_DEVICE` (Line 31) is imported from `deepmd.pt.utils.env`. Since the PR aims to decouple pt_expt from pt, consider importing `DEVICE` from `deepmd.pt_expt.utils.env` when only `INSTALLED_PT_EXPT` is true. For test code this is acceptable, but it may become fragile if the backends diverge.

deepmd/pt_expt/descriptor/se_e2_a.py (1)

56-78: `_convert_state` performs redundant re-serialization of `embeddings` and `emask`. During `__init__`, `DescrptSeADP.__init__` sets `self.embeddings`, which triggers `__setattr__` (Lines 41–47) and already converts it to a `NetworkCollection` module. Then `_convert_state` (Lines 73–74) serializes and deserializes it again. The same applies to `emask`. This is harmless but wasteful.

deepmd/pt_expt/utils/env.py (2)

32-32: `SAMPLER_RECORD` has mixed-type semantics. `os.environ.get("SAMPLER_RECORD", False)` returns `False` (bool) when unset, but a string when set. Compare with `DP_DTYPE_PROMOTION_STRICT` on Line 33, which consistently evaluates to bool. If this mirrors the PT backend's env.py, it may be intentional for backward compatibility; otherwise consider `os.environ.get("SAMPLER_RECORD", "0") == "1"`.

1-99: This file is largely duplicated from `deepmd/pt/utils/env.py`. The PR aim is to decouple pt_expt from pt, which necessitates this copy. However, maintaining two near-identical env configuration modules carries a long-term maintenance burden. Consider extracting shared configuration logic into a common utility that both backends can import.

deepmd/backend/pt_expt.py (1)

57-126: All hook properties delegate to `deepmd.pt.*` — pt_expt has no independent runtime. Every property (`entry_point_hook`, `deep_eval`, `neighbor_stat`, `serialize_hook`, `deserialize_hook`) lazily imports from `deepmd.pt.*`. This means the "exportable" backend is functionally identical to the standard PT backend at inference/training/serialization time, differing only in descriptor construction. This appears intentional for bootstrapping, but should be documented to avoid confusion about what the new backend actually provides.

deepmd/pt_expt/utils/network.py (1)

125-128: Stale `_module_networks` entry if a Module is replaced by a non-Module. `__setitem__` adds the new value to `_module_networks` when it's a `torch.nn.Module`, but doesn't remove the old entry when the replacement is not a Module. This could leave stale sub-modules in the parameter graph. In practice this is unlikely since all networks in this backend should be `torch.nn.Module` instances, but adding a defensive removal would be cleaner:

Proposed fix

```diff
 def __setitem__(self, key: int | tuple, value: Any) -> None:
     super().__setitem__(key, value)
+    str_key = str(self._convert_key(key))
     if isinstance(value, torch.nn.Module):
-        self._module_networks[str(self._convert_key(key))] = value
+        self._module_networks[str_key] = value
+    elif str_key in self._module_networks:
+        del self._module_networks[str_key]
```
```python
@overload
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...
```

Check notice (Code scanning / CodeQL, Note): Statement has no effect
```python
@overload
def to_torch_array(array: None) -> None: ...
```

Check notice (Code scanning / CodeQL, Note): Statement has no effect
```python
@overload
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...
```

Check notice (Code scanning / CodeQL, Note): Statement has no effect
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/common.py`:
- Around line 29-35: The function to_torch_array currently returns torch.Tensor
inputs as-is which can leave them on the wrong device; update to_torch_array to
ensure any existing torch.Tensor is moved to the target device (env.DEVICE)
before returning (use tensor.to or tensor.to(device=...)) and preserve dtype and
non-blocking semantics where appropriate; ensure None handling remains and that
non-tensor inputs are still converted with torch.as_tensor(...,
device=env.DEVICE).
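A minimal sketch of a device-normalizing `to_torch_array` along the lines suggested, assuming `env.DEVICE` is the backend's target device; the overload stubs mirror the ones flagged by CodeQL above:

```python
from __future__ import annotations

from typing import Any, overload

import numpy as np
import torch

from deepmd.pt_expt.utils import env  # assumed: module exposing DEVICE


@overload
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...
@overload
def to_torch_array(array: None) -> None: ...
@overload
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...


def to_torch_array(array: Any) -> torch.Tensor | None:
    if array is None:
        return None
    if isinstance(array, torch.Tensor):
        # Move existing tensors instead of returning them as-is; .to() is a
        # no-op when the tensor already lives on env.DEVICE.
        return array.to(device=env.DEVICE, non_blocking=True)
    # Convert ndarrays (and other array-likes) directly onto the target device.
    return torch.as_tensor(array, device=env.DEVICE)
```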
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/utils/network.py`:
- Around line 45-64: The __setattr__ implementation currently falls back to
super().__setattr__ for first-time non-trainable tensors (names "w","b","idt"),
which stores plain tensors as attributes instead of registering them as buffers;
update the non-trainable branch in __setattr__ (the block using to_torch_array
and checking getattr(self, "trainable", False)) so that when trainable is False
and val is not None you call self.register_buffer(name, val) (or
register_buffer(name, None) if val is None) instead of super().__setattr__,
keeping the existing handling when name already in self._buffers; keep
to_torch_array conversion and TorchArrayParam logic for the trainable path
unchanged.
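A sketch of the buffer-registering branch, assuming `to_torch_array` and `TorchArrayParam` behave as described above; the surrounding class in `deepmd/pt_expt/utils/network.py` is only approximated here:

```python
from typing import Any

import torch


class NativeLayer(torch.nn.Module):
    # Sketch only: the real class also mixes in the dpmodel NativeLayer.

    def __setattr__(self, name: str, value: Any) -> None:
        if name in {"w", "b", "idt"}:
            val = to_torch_array(value)  # assumed helper from this PR
            if getattr(self, "trainable", False):
                # Trainable path unchanged: wrap as a parameter.
                super().__setattr__(
                    name, None if val is None else TorchArrayParam(val)
                )
            elif name in self._buffers:
                # Re-assignment of an already-registered buffer.
                self._buffers[name] = val
            else:
                # First-time assignment: register as a buffer so the tensor
                # shows up in state_dict() and moves with .to()/.cuda().
                self.register_buffer(name, val)
            return
        super().__setattr__(name, value)
```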
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/utils/network.py`:
- Around line 41-44: NativeOP.__call__ is shadowing torch.nn.Module.__call__ so
calling layer(x) never reaches the defined forward() implementations; update the
class chain so Module.__call__ (and thus forward()) is used: either override
NativeLayer.__call__ (and NativeNet / NetworkCollection if needed) to delegate
to torch.nn.Module.__call__(self, *args, **kwargs) or change the call site to
invoke forward() directly; specifically locate NativeOP.__call__ (the one
delegating to self.call()) and replace/override it in NativeLayer (and
NativeNet) so it calls torch.nn.Module.__call__ or forward(*args, **kwargs)
instead of self.call(), ensuring export/jit use the standard Module → forward()
pathway.
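A minimal sketch of the suggested override, assuming `NativeLayer` multiply-inherits from the dpmodel class (whose `NativeOP.__call__` delegates to `self.call()`) and `torch.nn.Module`; the import path is an assumption:

```python
import torch

from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP  # assumed path


class NativeLayer(NativeLayerDP, torch.nn.Module):
    def __call__(self, *args, **kwargs):
        # NativeOP.__call__ (inherited via NativeLayerDP) wins in the MRO and
        # would bypass forward(); route through torch.nn.Module.__call__ so
        # hooks fire and export/jit trace the standard forward() pathway.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)
```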
🧹 Nitpick comments (1)
deepmd/pt_expt/utils/network.py (1)
29-29: Remove unused `noqa` directive. Ruff reports that `PYI034` is not enabled, making this suppression comment unnecessary.

```diff
-    def __new__(  # noqa: PYI034
+    def __new__(
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/descriptor/se_e2_a.py`:
- Around line 31-53: In __setattr__ when handling name == "emask" ensure you
check that the incoming value actually exposes get_exclude_types before calling
it: if value is not None and not isinstance(value, torch.nn.Module) then use
hasattr(value, "get_exclude_types") (or isinstance checks) and only call
value.get_exclude_types() to construct the PairExcludeMask; otherwise leave
value unchanged and pass it to super().__setattr__; update the emask branch to
mirror the defensive pattern used in the embeddings branch to avoid
AttributeError.
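A sketch of the defensive `emask` branch mirroring the embeddings pattern the comment refers to; the base-class import path and the `PairExcludeMask(ntypes, exclude_types)` signature are assumptions:

```python
from typing import Any

import torch

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP  # assumed path
from deepmd.pt_expt.utils.exclude_mask import PairExcludeMask  # path from this PR


class DescrptSeA(DescrptSeADP, torch.nn.Module):
    def __setattr__(self, name: str, value: Any) -> None:
        if name == "emask":
            if (
                value is not None
                and not isinstance(value, torch.nn.Module)
                and hasattr(value, "get_exclude_types")
            ):
                # Rebuild only dpmodel masks that expose get_exclude_types();
                # None, already-converted modules, and other values pass
                # through unchanged instead of raising AttributeError.
                value = PairExcludeMask(self.ntypes, value.get_exclude_types())
            super().__setattr__(name, value)
            return
        super().__setattr__(name, value)
```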
🧹 Nitpick comments (1)
source/tests/consistent/common.py (1)
351-355: Inline exception message could use a constant or exception subclass. Static analysis (Ruff TRY003) flags the long string in `raise ValueError("PT exportable class is not set")`. This is a very minor style nit — the existing code at Line 338 (`raise ValueError("No available reference")`) has the same pattern, so this is consistent with the codebase style.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@source/tests/pt_expt/utils/test_common.py`:
- Around line 1-13: The test module currently does
importlib.import_module("torch") which raises ModuleNotFoundError during
collection; update the test to skip when PyTorch is missing by replacing that
import with pytest.importorskip("torch") at module scope (or alternatively add a
source/tests/pt_expt/conftest.py that calls pytest.importorskip("torch") in
module import time), keeping the rest of the file (references to to_torch_array
and env) unchanged so tests are skipped instead of failing on import.
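A minimal sketch of the conftest-based variant suggested above; the file path comes from the comment, and skipping happens at collection time for everything under `source/tests/pt_expt/`:

```python
# source/tests/pt_expt/conftest.py
import pytest

# Skip collection of every test in this subtree instead of failing with
# ModuleNotFoundError when PyTorch is not installed.
torch = pytest.importorskip("torch")
```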
🧹 Nitpick comments (1)
source/tests/pt_expt/utils/test_common.py (1)
16-25: Consider asserting value, shape, and dtype preservation. The test only verifies type and device placement. A faulty `to_torch_array` that returns an empty or random tensor on the correct device would still pass. Adding a few more assertions would strengthen the test:

Suggested additions

```diff
 tensor = to_torch_array(arr)
 assert torch.is_tensor(tensor)
 assert tensor.device == env.DEVICE
+assert tensor.shape == (2, 3)
+assert tensor.dtype == torch.float32
+np.testing.assert_array_equal(tensor.cpu().numpy(), arr)
```
Codecov Report

❌ Patch coverage is …

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5194      +/-   ##
==========================================
+ Coverage   81.95%   81.97%   +0.02%
==========================================
  Files         714      724      +10
  Lines       73441    73810     +369
  Branches     3616     3615       -1
==========================================
+ Hits        60187    60507     +320
- Misses      12091    12140      +49
  Partials     1163     1163
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/descriptor/se_e2_a.py`:
- Around line 36-43: The current __setattr__ treats first-time assignments to
"davg"/"dstd" as normal attributes, so those tensors are not registered and
won't appear in state_dict or move with .to(); in __setattr__ of se_e2_a.py,
when name is in {"davg","dstd"} convert value to a tensor (or None) as you
already do and if name not in self._buffers call self.register_buffer(name,
tensor) (preserving persistence as needed) instead of falling through to
super().__setattr__, then return; if name is already in self._buffers keep the
existing branch that assigns into self._buffers[name] and returns.
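A short sketch of the suggested `davg`/`dstd` handling; the enclosing class and `env.DEVICE` usage are assumed from the PR's se_e2_a module:

```python
from typing import Any

import torch

from deepmd.pt_expt.utils import env  # assumed: provides DEVICE


class DescrptSeA(torch.nn.Module):  # sketch; real class also mixes in DescrptSeADP
    def __setattr__(self, name: str, value: Any) -> None:
        if name in {"davg", "dstd"}:
            val = None if value is None else torch.as_tensor(value, device=env.DEVICE)
            if name in self._buffers:
                # Existing branch: update the registered buffer in place.
                self._buffers[name] = val
            else:
                # First-time assignment: register so the statistics appear in
                # state_dict() and follow .to()/.cuda() with the module.
                self.register_buffer(name, val)
            return
        super().__setattr__(name, value)
```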
🧹 Nitpick comments (1)
deepmd/pt_expt/utils/network.py (1)
29-29: Remove the unused `noqa` directive. Ruff (RUF100) reports this `# noqa: PYI034` is unnecessary since the rule isn't enabled or triggered here.

```diff
-    def __new__(  # noqa: PYI034
+    def __new__(
```
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@deepmd/backend/pt_expt.py`:
- Around line 57-126: The PT backend fails on .pte files because
serialize_from_file and deserialize_to_file only accept .pth/.pt; update
deepmd.pt.utils.serialization.serialize_from_file and deserialize_to_file to
recognize and correctly handle the .pte extension (either by supporting direct
.pte serialization/deserialization or converting .pte ↔ .pt), and ensure the PT
inference loader (DeepEval in deepmd.pt.infer.deep_eval) can load .pte models
(or add a .pte-specific loader) so the backend's serialize_hook/deserialize_hook
and entry_point_hook work with .pte files.
In `@deepmd/pt_expt/utils/env.py`:
- Around line 93-99: The module currently calls set_default_nthreads(), then
unconditionally invokes torch.set_num_interop_threads(inter_nthreads) and
torch.set_num_threads(intra_nthreads) which can raise RuntimeError if another
env module already set these; update the logic in the module (around
set_default_nthreads(), get_default_nthreads(), and the torch calls) to first
check existing values via torch.get_num_interop_threads() and
torch.get_num_threads() and only call
torch.set_num_interop_threads(inter_nthreads) or
torch.set_num_threads(intra_nthreads) when the desired value differs and the
current value is 0 or not already set, or centralize the thread-setting into a
single initializer to avoid multiple calls; ensure you reference and update the
calls to torch.set_num_interop_threads and torch.set_num_threads and add guard
checks using torch.get_num_interop_threads/get_num_threads to prevent the
RuntimeError when both backends are imported.
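A sketch combining both env.py fixes: correct unpacking order plus guards against re-setting threads when another backend's env module already did. It assumes the helpers come from `deepmd.env` as in the PT backend, and that `get_default_nthreads()` returns `(intra, inter)` as stated above:

```python
import torch

from deepmd.env import (  # assumed: same helpers the PT backend's env.py uses
    get_default_nthreads,
    set_default_nthreads,
)

set_default_nthreads()
# get_default_nthreads() returns (intra, inter); unpack in that order.
intra_nthreads, inter_nthreads = get_default_nthreads()

# torch raises RuntimeError if interop threads are configured twice (e.g. when
# both deepmd.pt and deepmd.pt_expt env modules are imported), so only set
# values that differ from the current configuration.
if inter_nthreads > 0 and torch.get_num_interop_threads() != inter_nthreads:
    torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0 and torch.get_num_threads() != intra_nthreads:
    torch.set_num_threads(intra_nthreads)
```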
🧹 Nitpick comments (3)
deepmd/pt_expt/utils/env.py (2)
1-99: Near-complete duplication of `deepmd/pt/utils/env.py` — consider sharing. This file is nearly line-for-line identical to `deepmd/pt/utils/env.py` (same constants, precision dicts, device selection, worker config, threading setup, `__all__`). The only material difference is `importlib.import_module("torch")` on line 21 instead of `import torch`. Maintaining two copies means every future change (new precision entry, new env var, bug fix) must be applied in both places. Consider extracting the shared logic into a common module (e.g., `deepmd/_pt_common/env.py`) that both `pt` and `pt_expt` import, or having `pt_expt` simply re-export from `deepmd.pt.utils.env`.

```bash
#!/bin/bash
# Show the diff between the two env files to confirm duplication extent
diff <(sed -n '1,120p' deepmd/pt/utils/env.py) <(sed -n '1,120p' deepmd/pt_expt/utils/env.py)
```

21-21: `importlib.import_module("torch")` provides no benefit here. Since `torch` is used unconditionally at module level (line 52: `torch.cuda.is_available()`, lines 97–99: `torch.set_num_interop_threads`/`set_num_threads`), any missing `torch` will raise `ModuleNotFoundError` regardless. A plain `import torch` is clearer and lets static analysis tools (type checkers, linters) resolve the module.

Suggested change

```diff
-import importlib
 import logging
 import multiprocessing
 import os
 import sys

 import numpy as np
+import torch
 ...
-torch = importlib.import_module("torch")
```

source/tests/consistent/descriptor/common.py (1)

147-177: Consider delegating to `eval_pt_descriptor` to reduce duplication. `eval_pt_expt_descriptor` is body-identical to `eval_pt_descriptor` — only the parameter name differs. You could eliminate the duplication by delegating:

```python
def eval_pt_expt_descriptor(self, pt_expt_obj, natoms, coords, atype, box, mixed_types=False):
    return self.eval_pt_descriptor(pt_expt_obj, natoms, coords, atype, box, mixed_types)
```

If the two paths are expected to diverge later, the current form is fine — but the helpful comment on lines 156–159 already appears in `eval_pt_descriptor`'s behavior implicitly, so keeping both in sync manually is a maintenance risk.
Actionable comments posted: 0

Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)

source/tests/consistent/descriptor/common.py (1)

28-35: ⚠️ Potential issue | 🟡 Minor: Split the import guard: PT_EXPT does not need the PT-specific nlist functions. The guard on line 28 imports `build_neighbor_list_pt` and `extend_coord_with_ghosts_pt` (lines 32–35) unconditionally whenever `INSTALLED_PT_EXPT` is `True`, but `eval_pt_expt_descriptor` uses only the dpmodel versions (`build_neighbor_list`, `extend_coord_with_ghosts` from lines 15–16). If PT_EXPT is available without the PT backend, these imports would fail.

Suggested fix: split the guard

```diff
 if INSTALLED_PT or INSTALLED_PT_EXPT:
     import torch
     from deepmd.pt.utils.env import DEVICE as PT_DEVICE
+
+if INSTALLED_PT:
     from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
     from deepmd.pt.utils.nlist import (
         extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
     )
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@deepmd/backend/pt_expt.py`:
- Around line 37-45: The class variable features in pt_expt.py incorrectly
advertises Backend.Feature.DEEP_EVAL, .NEIGHBOR_STAT, and .IO while only
entry_point_hook is implemented; update the features tuple to only include
Backend.Feature.ENTRY_POINT (remove DEEP_EVAL, NEIGHBOR_STAT, IO) so
Backend.get_backends_by_feature() won't return this backend for unimplemented
properties like neighbor_stat, deep_eval, or io; when those properties
(neighbor_stat, deep_eval, io) are implemented later, re-add their flags to the
features tuple.
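A sketch of the narrowed declaration; the class name here is hypothetical, and whether `features` is a flag union or a tuple depends on the `Backend` base class (shown as a flag union, which matches other deepmd backends):

```python
from typing import ClassVar

from deepmd.backend.backend import Backend


class PTExptBackend(Backend):  # hypothetical name; the PR's class may differ
    # Advertise only what is implemented today; re-add DEEP_EVAL,
    # NEIGHBOR_STAT, and IO once the corresponding hooks exist.
    features: ClassVar[Backend.Feature] = Backend.Feature.ENTRY_POINT
```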
🧹 Nitpick comments (1)
deepmd/backend/pt_expt.py (1)
57-68: Entry point reuses the standard PT backend's `main`. This returns the same `deepmd.pt.entrypoints.main` as the PyTorch backend. If that's intentional for this phase (shared training path), consider adding a brief comment explaining the reuse so future readers don't mistake it for a copy-paste oversight.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@deepmd/pt_expt/descriptor/se_r.py`:
- Around line 52-57: The emask object (PairExcludeMask) holds a type_mask tensor
but is not a torch.nn.Module so model.to(device) won't move that tensor; fix by
making PairExcludeMask a proper submodule or registering its tensor as a buffer:
update the PairExcludeMask class to inherit from torch.nn.Module (call
super().__init__() in its constructor) so instances become submodules and their
tensors move with model.to(), and/or in the parent module's __setattr__ where
name == "emask" detect non-Module PairExcludeMask values and register their
internal type_mask via self.register_buffer("emask_type_mask", value.type_mask)
(and adjust build_type_exclude_mask/call() to use the registered buffer),
ensuring uniqueness of buffer name and removing the current manual placement to
env.DEVICE.
In `@deepmd/pt_expt/utils/network.py`:
- Around line 47-51: The __setattr__ in network.py currently calls
super().__setattr__(name, None) when val is None, which can shadow previously
registered items; update __setattr__ so that when val is None you explicitly
remove the name from self._parameters and self._buffers (e.g.,
self._parameters.pop(name, None); self._buffers.pop(name, None)) before
delegating to super().__setattr__(name, None); apply this logic for the
parameter names set {"w","b","idt"} and any other paths in __setattr__ that
accept None to ensure registered entries are cleared rather than shadowed.
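A sketch of the None-clearing behavior for the `{"w", "b", "idt"}` branch; the enclosing class in `deepmd/pt_expt/utils/network.py` is only approximated:

```python
from typing import Any

import torch


class NativeLayer(torch.nn.Module):  # sketch; real class also mixes in the dpmodel layer
    def __setattr__(self, name: str, value: Any) -> None:
        if name in {"w", "b", "idt"} and value is None:
            # Drop any previously registered entry first; otherwise the plain
            # None attribute would shadow a stale parameter or buffer.
            self._parameters.pop(name, None)
            self._buffers.pop(name, None)
            super().__setattr__(name, None)
            return
        super().__setattr__(name, value)  # existing tensor/parameter handling
```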
🧹 Nitpick comments (3)
deepmd/pt_expt/utils/env.py (1)
19-20: Minor: `import torch` placed after logger initialization. The `import torch` on Line 20 sits between the logger assignment and the rest of the module logic, separated from the other imports at the top. This is functional but unconventional — consider moving it up with the other imports for consistency.

deepmd/pt_expt/utils/exclude_mask.py (1)

15-26: Identical `__setattr__` in both classes — consider a shared mixin. `AtomExcludeMask` and `PairExcludeMask` have the exact same `__setattr__` override. A small mixin would eliminate the duplication:

♻️ Optional DRY refactor

```diff
+class _DeviceTypeMaskMixin:
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "type_mask":
+            value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
+        return super().__setattr__(name, value)
+
+
-class AtomExcludeMask(AtomExcludeMaskDP):
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "type_mask":
-            value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
-        return super().__setattr__(name, value)
+class AtomExcludeMask(_DeviceTypeMaskMixin, AtomExcludeMaskDP):
+    pass


-class PairExcludeMask(PairExcludeMaskDP):
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "type_mask":
-            value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
-        return super().__setattr__(name, value)
+class PairExcludeMask(_DeviceTypeMaskMixin, PairExcludeMaskDP):
+    pass
```

deepmd/pt_expt/utils/network.py (1)

27-27: Remove unused `noqa` directive. Static analysis (Ruff RUF100) flags `# noqa: PYI034` as unnecessary here. The `PYI034` rule is not enabled, so this suppression is a no-op.

```diff
-    def __new__(  # noqa: PYI034
+    def __new__(
```
```python
torch.nn.Module.__init__(self)
AtomExcludeMaskDP.__init__(self, *args, **kwargs)

def __setattr__(self, name: str, value: Any) -> None:
```

Check notice (Code scanning / CodeQL, Note): Explicit returns mixed with implicit (fall through) returns
```python
torch.nn.Module.__init__(self)
PairExcludeMaskDP.__init__(self, *args, **kwargs)

def __setattr__(self, name: str, value: Any) -> None:
```

Check notice (Code scanning / CodeQL, Note): Explicit returns mixed with implicit (fall through) returns
Summary by CodeRabbit

- New Features
- Bug Fixes
- Tests
- Documentation