Skip to content

Commit 0de0d7c

Browse files
Copilothann-wang
andauthored
fix: move u/v/sigma to correct device/dtype in DecomposedLinear.from_linear
Agent-Logs-Url: https://github.com/AMD-AGI/ALTO/sessions/745e0c22-278f-4a2d-8089-2fb2c27ee8a5 Co-authored-by: hann-wang <8476580+hann-wang@users.noreply.github.com>
1 parent 0a5ec6c commit 0de0d7c

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

alto/nn/decomposed_linear.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def from_linear(cls, linear: nn.Linear, lora_rank: int = 32):
3131
new_layer = cls(linear.in_features, linear.out_features, linear.bias is not None, lora_rank)
3232
new_layer.weight = linear.weight
3333
new_layer.bias = linear.bias
34+
device = linear.weight.device
35+
dtype = linear.weight.dtype
36+
new_layer.u.data = new_layer.u.data.to(device=device, dtype=dtype)
37+
new_layer.v.data = new_layer.v.data.to(device=device, dtype=dtype)
38+
new_layer.sigma.data = new_layer.sigma.data.to(device=device, dtype=dtype)
3439
return new_layer
3540

3641
def init_lora_weights(self, init_std: float = 0.02):

0 commit comments

Comments
 (0)