calculate mean token accuracy metric while training #337

Open · wants to merge 1 commit into base: main

130 changes: 114 additions & 16 deletions src/nanotron/models/llama.py
@@ -979,30 +979,114 @@ def masked_mean(loss, label_mask, dtype):
     return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
 
 
+@torch.no_grad()
+def calculate_sharded_accuracy(
+    sharded_logits: torch.Tensor,  # [seq, batch, sharded_vocab]
+    label_ids: torch.Tensor,  # [batch, seq]
+    label_mask: torch.Tensor,  # [batch, seq]
+    tp_group: dist.ProcessGroup,
+    ignore_index: int = -100,
+) -> torch.Tensor:
+    """Calculate the mean token accuracy over vocabulary-sharded logits."""
+    # Transpose logits to [batch, seq, sharded_vocab] to align with labels/mask
+    sharded_logits = sharded_logits.transpose(0, 1).contiguous()
+    sharded_vocab_size = sharded_logits.shape[-1]
+
+    # Flatten everything
+    flat_logits = sharded_logits.view(-1, sharded_vocab_size)
+    flat_labels = label_ids.view(-1)
+    flat_mask = label_mask.view(-1)
+
+    # --- Find the global argmax across sharded logits ---
+    rank = dist.get_rank(tp_group)
+    world_size = dist.get_world_size(tp_group)
+
+    # 1. Find the max logit value and its index *within the current shard*
+    local_max_val, local_argmax_idx = torch.max(flat_logits, dim=-1)
+
+    # 2. Convert the local argmax to its *global* vocabulary index
+    local_argmax_global_idx = local_argmax_idx + rank * sharded_vocab_size
+
+    # 3. Pack the local max value and its global index into one float32 tensor.
+    #    float32 represents integers exactly up to 2**24, which covers any
+    #    realistic vocabulary size; bf16/fp16 logit dtypes would not.
+    val_idx_pair = torch.stack([local_max_val.float(), local_argmax_global_idx.float()], dim=-1).contiguous()
+
+    # 4. Gather these pairs from all ranks. all_gather_into_tensor concatenates
+    #    along dim 0, so the gathered layout is rank-major:
+    #    [world_size, num_tokens, 2], not [num_tokens, world_size, 2].
+    gathered_val_idx_pairs = torch.empty(
+        world_size, flat_logits.shape[0], 2, dtype=val_idx_pair.dtype, device=val_idx_pair.device
+    )
+    dist.all_gather_into_tensor(gathered_val_idx_pairs.view(-1, 2), val_idx_pair, group=tp_group)
+
+    # 5. For each token, find the rank (dim 0) holding the overall max value
+    global_max_indices_in_gathered = torch.argmax(gathered_val_idx_pairs[..., 0], dim=0)
+
+    # 6. Use these indices to select the corresponding global prediction index
+    global_predictions = (
+        torch.gather(gathered_val_idx_pairs[..., 1], 0, global_max_indices_in_gathered.unsqueeze(0))
+        .squeeze(0)
+        .long()
+    )
+
+    # --- Calculate accuracy ---
+    # Create a mask for valid (non-padding) tokens based on the labels
+    mask = flat_labels != ignore_index
+    # Ensure only tokens covered by the original label_mask are considered
+    valid_mask = mask & flat_mask
+
+    # Compare predictions to labels only where the mask is valid
+    correct_predictions = (global_predictions == flat_labels) & valid_mask
+    local_correct_tokens = correct_predictions.sum()
+    local_total_tokens = valid_mask.sum()
+
+    # All-reduce the counts across the TP group. Labels and (post-gather)
+    # predictions are identical on every TP rank, so this scales the numerator
+    # and denominator by the same factor and leaves the ratio unchanged.
+    global_correct_tokens_tensor = local_correct_tokens.clone()
+    dist.all_reduce(global_correct_tokens_tensor, op=dist.ReduceOp.SUM, group=tp_group)
+
+    global_total_tokens_tensor = local_total_tokens.clone()
+    dist.all_reduce(global_total_tokens_tensor, op=dist.ReduceOp.SUM, group=tp_group)
+
+    # Compute the accuracy
+    total_sum = global_total_tokens_tensor.item()
+    accuracy = (global_correct_tokens_tensor.item() / total_sum) if total_sum > 0 else 0.0
+
+    # Return as a tensor on the correct device
+    return torch.tensor(accuracy, device=sharded_logits.device, dtype=torch.float)
+
+
 class Loss(nn.Module):
-    def __init__(self, tp_pg: dist.ProcessGroup):
+    def __init__(self, tp_pg: dist.ProcessGroup, ignore_index: int = -100):
         super().__init__()
         self.tp_pg = tp_pg
+        self.ignore_index = ignore_index
 
     def forward(
         self,
         sharded_logits: torch.Tensor,  # [seq_length, batch_size, logits]
         label_ids: torch.Tensor,  # [batch_size, seq_length]
         label_mask: torch.Tensor,  # [batch_size, seq_length]
     ) -> Dict[str, torch.Tensor]:
-        loss = sharded_cross_entropy(
+        # Calculate loss
+        loss_val = sharded_cross_entropy(
             sharded_logits,
             label_ids.transpose(0, 1).contiguous(),
             group=self.tp_pg,
             dtype=torch.float,
         ).transpose(0, 1)
-        loss = masked_mean(loss, label_mask, dtype=torch.float)
-        return {"loss": loss}
+        loss = masked_mean(loss_val, label_mask, dtype=torch.float)
+
+        # Calculate accuracy
+        accuracy = calculate_sharded_accuracy(
+            sharded_logits=sharded_logits,
+            label_ids=label_ids,
+            label_mask=label_mask,
+            tp_group=self.tp_pg,
+            ignore_index=self.ignore_index,
+        )
+
+        return {"loss": loss, "mean_token_accuracy": accuracy}
 
 
 class LossWithZLoss(Loss):
-    def __init__(self, tp_pg: dist.ProcessGroup, z_loss_coefficient: float):
-        super().__init__(tp_pg)
+    def __init__(self, tp_pg: dist.ProcessGroup, z_loss_coefficient: float, ignore_index: int = -100):
+        super().__init__(tp_pg, ignore_index)
         self.z_loss_coef = z_loss_coefficient
 
     def forward(
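For intuition, the shard-argmax reduction in `calculate_sharded_accuracy` can be checked without any process group. A single-process sketch with toy shapes, where the loop over ranks stands in for the `all_gather`:

```python
import torch

# Hypothetical toy dimensions: 8 tokens, vocab of 12 split across 4 "ranks"
num_tokens, vocab, world_size = 8, 12, 4
shard = vocab // world_size
logits = torch.randn(num_tokens, vocab)

vals, idxs = [], []
for rank in range(world_size):
    local = logits[:, rank * shard : (rank + 1) * shard]
    v, i = torch.max(local, dim=-1)    # step 1: per-shard max and argmax
    vals.append(v)
    idxs.append(i + rank * shard)      # step 2: shift to global vocab indices

# Rank-major stacking mirrors the all_gather layout: [world_size, num_tokens]
vals = torch.stack(vals, dim=0)
idxs = torch.stack(idxs, dim=0)

best_rank = torch.argmax(vals, dim=0)                      # step 5
pred = idxs.gather(0, best_rank.unsqueeze(0)).squeeze(0)   # step 6

# The sharded reduction must agree with an unsharded argmax
assert torch.equal(pred, logits.argmax(dim=-1))
```

The packing in step 3 is the one lossy spot: indices stay exact only while the vocabulary size fits in the mantissa of the packing dtype, hence the explicit `float()` cast in the diff above.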
@@ -1011,16 +1095,27 @@ def forward(
         label_ids: torch.Tensor,  # [batch_size, seq_length]
         label_mask: torch.Tensor,  # [batch_size, seq_length]
     ) -> Dict[str, torch.Tensor]:
-        loss, z_loss = sharded_cross_entropy(
+        # Calculate loss and z_loss
+        loss_val, z_loss_val = sharded_cross_entropy(
             sharded_logits,
             label_ids.transpose(0, 1).contiguous(),
             group=self.tp_pg,
             dtype=torch.float,
             z_loss_coef=self.z_loss_coef,
         )
-        loss = masked_mean(loss.transpose(0, 1), label_mask, dtype=torch.float)
-        z_loss = masked_mean(z_loss.detach().transpose(0, 1), label_mask, dtype=torch.float)
-        return {"loss": loss, "z_loss": z_loss}
+        loss = masked_mean(loss_val.transpose(0, 1), label_mask, dtype=torch.float)
+        z_loss = masked_mean(z_loss_val.detach().transpose(0, 1), label_mask, dtype=torch.float)
+
+        # Calculate accuracy
+        accuracy = calculate_sharded_accuracy(
+            sharded_logits=sharded_logits,
+            label_ids=label_ids,
+            label_mask=label_mask,
+            tp_group=self.tp_pg,
+            ignore_index=self.ignore_index,
+        )
+
+        return {"loss": loss, "z_loss": z_loss, "mean_token_accuracy": accuracy}
 
 
 class LlamaForTraining(NanotronModel):
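Both loss variants now expose the metric through their output dict. A minimal consumption sketch, assuming a generic training step (the wiring below is illustrative, not nanotron's trainer API):

```python
from typing import Dict

import torch

def training_step(
    loss_module: torch.nn.Module,
    batch: Dict[str, torch.Tensor],
    optimizer: torch.optim.Optimizer,
) -> Dict[str, float]:
    # `loss_module` is a Loss/LossWithZLoss instance; `batch` supplies the
    # tensors named in module_input_keys
    outputs = loss_module(
        sharded_logits=batch["sharded_logits"],
        label_ids=batch["label_ids"],
        label_mask=batch["label_mask"],
    )
    outputs["loss"].backward()
    optimizer.step()
    optimizer.zero_grad()

    # mean_token_accuracy is produced under @torch.no_grad(), so it carries no
    # autograd graph; z_loss is already detached. .item() syncs to the host.
    return {name: value.item() for name, value in outputs.items()}
```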
@@ -1037,6 +1132,7 @@ def __init__(
         # Choose the appropriate loss class based on config
         loss_kwargs = {
             "tp_pg": parallel_context.tp_pg,
+            # Use pad_token_id as the ignore index, falling back to the
+            # default (-100) when the config defines no pad token
+            "ignore_index": config.pad_token_id if config.pad_token_id is not None else -100,
         }
         if config.z_loss_enabled:
             loss_kwargs["z_loss_coefficient"] = config.z_loss_coefficient
@@ -1050,7 +1146,10 @@
"label_ids",
"label_mask",
},
module_output_keys={"loss", "z_loss"} if config.z_loss_enabled else {"loss"},
# Update output keys to include accuracy
module_output_keys={"loss", "z_loss", "mean_token_accuracy"}
if config.z_loss_enabled
else {"loss", "mean_token_accuracy"},
)

self.parallel_context = parallel_context
@@ -1068,15 +1167,14 @@ def forward(
             input_ids=input_ids,
             input_mask=input_mask,
         )
-        loss = self.loss(
+        # Call the loss block, which now computes both loss and accuracy
+        loss_output = self.loss(
             sharded_logits=sharded_logits,
             label_ids=label_ids,
             label_mask=label_mask,
         )
-        if self.config.z_loss_enabled:
-            return {"loss": loss["loss"], "z_loss": loss["z_loss"]}
-        else:
-            return {"loss": loss["loss"]}
+        # Return the full dictionary: loss, accuracy, and z_loss when enabled
+        return loss_output
 
     @torch.no_grad()
     def init_model_randomly(self, config: Config):
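Because the PipelineBlock declares `module_output_keys` statically while the loss modules assemble their return dict at runtime, the two can silently drift apart. A cheap invariant check (hypothetical helper, not part of this PR):

```python
from typing import Dict

import torch

def check_loss_output_keys(loss_output: Dict[str, torch.Tensor], z_loss_enabled: bool) -> None:
    """Assert the loss block returned exactly the keys the pipeline declared."""
    expected = {"loss", "mean_token_accuracy"}
    if z_loss_enabled:
        expected.add("z_loss")
    assert set(loss_output) == expected, (
        f"loss block returned {sorted(loss_output)}, expected {sorted(expected)}"
    )
```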