Skip to content

Commit ca72aa0

Browse files
authored
Fix pageable H2D copies in Gated DeltaNet PyTorch fallback (#45665)
Fix pageable H2D copy in Gated DeltaNet PyTorch fallback
1 parent 4287660 commit ca72aa0

5 files changed

Lines changed: 25 additions & 15 deletions

File tree

src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ def torch_chunk_gated_delta_rule(
551551
value = attn @ v_beta
552552
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
553553
last_recurrent_state = (
554-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
554+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
555555
if initial_state is None
556556
else initial_state.to(value)
557557
)
@@ -595,9 +595,11 @@ def torch_recurrent_gated_delta_rule(
595595
scale = 1 / (query.shape[-1] ** 0.5)
596596
query = query * scale
597597

598-
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
598+
core_attn_out = torch.zeros(
599+
batch_size, num_heads, sequence_length, v_head_dim, dtype=value.dtype, device=value.device
600+
)
599601
last_recurrent_state = (
600-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
602+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
601603
if initial_state is None
602604
else initial_state.to(value)
603605
)

src/transformers/models/qwen3_5/modeling_qwen3_5.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def torch_chunk_gated_delta_rule(
283283
value = attn @ v_beta
284284
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
285285
last_recurrent_state = (
286-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
286+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
287287
if initial_state is None
288288
else initial_state.to(value)
289289
)
@@ -327,9 +327,11 @@ def torch_recurrent_gated_delta_rule(
327327
scale = 1 / (query.shape[-1] ** 0.5)
328328
query = query * scale
329329

330-
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
330+
core_attn_out = torch.zeros(
331+
batch_size, num_heads, sequence_length, v_head_dim, dtype=value.dtype, device=value.device
332+
)
331333
last_recurrent_state = (
332-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
334+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
333335
if initial_state is None
334336
else initial_state.to(value)
335337
)

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def torch_chunk_gated_delta_rule(
284284
value = attn @ v_beta
285285
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
286286
last_recurrent_state = (
287-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
287+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
288288
if initial_state is None
289289
else initial_state.to(value)
290290
)
@@ -328,9 +328,11 @@ def torch_recurrent_gated_delta_rule(
328328
scale = 1 / (query.shape[-1] ** 0.5)
329329
query = query * scale
330330

331-
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
331+
core_attn_out = torch.zeros(
332+
batch_size, num_heads, sequence_length, v_head_dim, dtype=value.dtype, device=value.device
333+
)
332334
last_recurrent_state = (
333-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
335+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
334336
if initial_state is None
335337
else initial_state.to(value)
336338
)

src/transformers/models/qwen3_next/modeling_qwen3_next.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def torch_chunk_gated_delta_rule(
423423
value = attn @ v_beta
424424
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
425425
last_recurrent_state = (
426-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
426+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
427427
if initial_state is None
428428
else initial_state.to(value)
429429
)
@@ -467,9 +467,11 @@ def torch_recurrent_gated_delta_rule(
467467
scale = 1 / (query.shape[-1] ** 0.5)
468468
query = query * scale
469469

470-
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
470+
core_attn_out = torch.zeros(
471+
batch_size, num_heads, sequence_length, v_head_dim, dtype=value.dtype, device=value.device
472+
)
471473
last_recurrent_state = (
472-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
474+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
473475
if initial_state is None
474476
else initial_state.to(value)
475477
)

src/transformers/models/qwen3_next/modular_qwen3_next.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def torch_chunk_gated_delta_rule(
262262
value = attn @ v_beta
263263
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
264264
last_recurrent_state = (
265-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
265+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
266266
if initial_state is None
267267
else initial_state.to(value)
268268
)
@@ -306,9 +306,11 @@ def torch_recurrent_gated_delta_rule(
306306
scale = 1 / (query.shape[-1] ** 0.5)
307307
query = query * scale
308308

309-
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
309+
core_attn_out = torch.zeros(
310+
batch_size, num_heads, sequence_length, v_head_dim, dtype=value.dtype, device=value.device
311+
)
310312
last_recurrent_state = (
311-
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
313+
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim, dtype=value.dtype, device=value.device)
312314
if initial_state is None
313315
else initial_state.to(value)
314316
)

0 commit comments

Comments
 (0)