Commit

Update PR with more edge cases where tensor may not be contiguous after being placed on CPU
gioannides authored and dacorvo committed Nov 14, 2024
1 parent ece4f98 commit 15f87cc
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions optimum/neuron/distributed/checkpointing.py
@@ -145,17 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
 # This might not be the case anymore when `ParameterMetadata` uses slices.
 sharded_metadata = sharded_metadatas[name]
 if sharded_metadata.is_tied:
-    consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
+    consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
 else:
     # Ensure that all tensors are contiguous before concatenating or further processing
     weights = [state_dict[name].contiguous() for state_dict in state_dicts]
     tp_size = len(weights)

-    full_weight = torch.cat(
-        weights,
-        dim=sharded_metadata.partition_dim,
-    ).contiguous()  # Ensure the result is also contiguous
+    full_weight = (
+        torch.cat(
+            weights,
+            dim=sharded_metadata.partition_dim,
+        )
+        .to("cpu")
+        .contiguous()
+    )  # Ensure the result is also contiguous

     if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
         full_weight = (
             torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
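Why the change matters, shown as a minimal standalone sketch that is not part of the commit: transposed or sliced views share storage with non-standard strides and report is_contiguous() as False, and downstream consumers such as safetensors serialization typically reject non-contiguous tensors, so each shard is forced to be contiguous before torch.cat and the concatenated result is made contiguous as well. The shard shapes below are illustrative only.

import torch

# A transposed view is a common way to end up with a non-contiguous tensor;
# shards moved back to CPU from an accelerator can land in a similar state.
shard_a = torch.arange(12, dtype=torch.float32).reshape(3, 4).t()  # shape (4, 3), non-contiguous view
shard_b = torch.arange(12, 24, dtype=torch.float32).reshape(3, 4).t()

print(shard_a.is_contiguous())  # False: same storage as the original, with swapped strides

# Make each shard contiguous before concatenating along the partition dimension,
# then force contiguity on the result too. The trailing call is a no-op when
# torch.cat already returns a contiguous tensor, so it is cheap insurance.
full_weight = torch.cat(
    [shard.contiguous() for shard in (shard_a, shard_b)],
    dim=0,
).contiguous()

print(full_weight.shape, full_weight.is_contiguous())  # torch.Size([8, 3]) True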
