
Register feature_weights as buffer to ensure correct device placement#241

Closed
Rohit7824567 wants to merge 1 commit into mllam:main from Rohit7824567:patch-2

Conversation

@Rohit7824567

feature_weights was a regular tensor, which could cause device mismatch when moving the model to GPU or other accelerators. Since it’s not learnable but part of the model state, it should be registered as a buffer.

Using self.register_buffer("feature_weights", ...) ensures:

  • Automatic device movement with the model
  • Inclusion in state_dict for checkpointing
  • Consistency in distributed training/inference
  • Avoidance of runtime device errors
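The difference can be sketched with a minimal, hypothetical module; the class and attribute names below are illustrative only and not taken from the mllam codebase:

```python
import torch
import torch.nn as nn

class WeightedModel(nn.Module):
    def __init__(self, num_features: int):
        super().__init__()
        # Plain tensor attribute: NOT moved by .to()/.cuda() and NOT
        # included in state_dict.
        self.plain_weights = torch.ones(num_features)
        # Registered buffer: follows the module across devices/dtypes
        # and is saved/loaded with state_dict.
        self.register_buffer("feature_weights", torch.ones(num_features))

model = WeightedModel(4)
model.to(torch.float64)          # buffers are converted, plain tensors are not
assert model.feature_weights.dtype == torch.float64
assert model.plain_weights.dtype == torch.float32
assert "feature_weights" in model.state_dict()
assert "plain_weights" not in model.state_dict()
```

The same mechanism applies to device moves (`model.cuda()`, `model.to("mps")`, etc.): the buffer travels with the module, while the plain attribute is left on its original device.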

Describe your changes

< Summary of the changes.>

< Please also include relevant motivation and context. >

< List any dependencies that are required for this change. >

Issue Link

< Link to the relevant issue or task, if applicable > (e.g. closes #00 or solves #00)

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not, update your fork with the changes from the target branch (use pull with the --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@joeloskarsson
Collaborator

Hi! feature_weights is not itself involved in any GPU computations, or really any computations that are run during training or inference. It is only used to set per_var_std, which is correctly a buffer. I don't think we should make tensors buffers unless we need to.

What does look suboptimal here is that feature_weights should not need to be a member variable, as it is now only used as an intermediate value.

@sadamov
Collaborator

sadamov commented Mar 31, 2026

This PR is stale, and the code path may be moved or removed by the ongoing #208/#507 refactors.

@sadamov sadamov closed this Mar 31, 2026