Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions miles/backends/training_utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,38 @@

logger = logging.getLogger(__name__)

# Maps bare metric names to their W&B top-level section(s).
# Keys appearing in multiple sections (e.g. pg_loss) are emitted under each.
_TRAIN_METRIC_GROUPS: dict[str, list[str]] = {
"ppo_kl": ["policy_shift"],
"ois": ["policy_shift"],
"pg_clipfrac": ["policy_shift"],
"pg_loss": ["policy_shift", "optimization"],
"log_probs": ["policy_shift"], # current policy (training forward pass)
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
"ref_kl": ["policy_shift"],
"train_rollout_logprob_abs_diff": ["train_inference_mismatch"],
"train_rollout_logprob_diff": ["train_inference_mismatch"],
"tis": ["train_inference_mismatch"],
"tis_abs": ["train_inference_mismatch"],
"tis_clipfrac": ["train_inference_mismatch"],
"loss": ["optimization"],
"entropy_loss": ["optimization"],
"kl_loss": ["optimization"],
"grad_norm": ["optimization"],
}

# Maps rollout batch field names to their W&B top-level section.
_ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = {
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
"rollout_log_probs": "train_inference_mismatch", # inference engine log probs
"ref_log_probs": "policy_shift", # reference model log probs
"rewards": "reward",
"raw_reward": "reward",
"advantages": "reward",
"returns": "reward",
}


def gather_log_data(
metric_name: str,
Expand Down Expand Up @@ -185,6 +217,17 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
if "rollout/entropy" in reduced_log_dict:
assert 0 < reduced_log_dict["rollout/entropy"] < 0.7

# Emit top-level grouped keys from reduced values (only on DP source rank)
if reduced_log_dict is not None:
top_level = {}
for key, group in _ROLLOUT_DATA_METRIC_GROUPS.items():
rollout_key = f"rollout/{key}"
if rollout_key in reduced_log_dict:
top_level[f"{group}/{key}"] = reduced_log_dict[rollout_key]
if top_level:
step = compute_rollout_step(args, rollout_id)
top_level["rollout/step"] = step
tracking_utils.log(args, top_level, step_key="rollout/step")
if args.ci_test and args.true_on_policy_mode:
assert log_dict["log_probs"] == log_dict["rollout_log_probs"], (
f"CI check failed: true_on_policy_mode is enabled, but log_probs "
Expand Down Expand Up @@ -436,6 +479,20 @@ def log_train_step(

log_dict_out["train/step"] = accumulated_step_id

# Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged)
grouped_additions = {}
prefix = f"train/{role_tag}"
for full_key, val in log_dict_out.items():
if not full_key.startswith(prefix):
continue
bare_key = full_key[len(prefix):]
if bare_key in _TRAIN_METRIC_GROUPS:
for group in _TRAIN_METRIC_GROUPS[bare_key]:
grouped_additions[f"{group}/{bare_key}"] = val
elif bare_key.startswith("lr-pg_"):
grouped_additions[f"optimization/{bare_key}"] = val
log_dict_out.update(grouped_additions)

if should_log is None:
should_log = dist.get_rank() == 0

Expand Down
27 changes: 24 additions & 3 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,21 +682,42 @@ def policy_loss_function(
if log_probs.numel() == 0:
loss += 0 * logits.sum()

# Current and old policy log probs for policy_shift panel
log_probs_metric = sum_of_sample_mean(log_probs).clone().detach()
old_log_probs_metric = sum_of_sample_mean(old_log_probs).clone().detach()

# Train-inference mismatch: compare inference engine vs FSDP at rollout time
train_rollout_logprob_abs_diff = None
train_rollout_logprob_diff = None
if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0)
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs).abs())
rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
# signed: log π(inf) − log π(fsdp rollout)
train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()

# KL vs reference model — always log when ref present, regardless of use_kl_loss
ref_kl_metric = None
if "ref_log_probs" in batch and batch["ref_log_probs"]:
ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0)
ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach()

reported_loss = {
"loss": loss.clone().detach(),
"pg_loss": pg_loss.clone().detach(),
"entropy_loss": entropy_loss.clone().detach(),
"pg_clipfrac": pg_clipfrac.clone().detach(),
"ppo_kl": ppo_kl.clone().detach(),
"log_probs": log_probs_metric,
"old_log_probs": old_log_probs_metric,
}

if train_rollout_logprob_abs_diff is not None:
reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach()
reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff
reported_loss["train_rollout_logprob_diff"] = train_rollout_logprob_diff

if ref_kl_metric is not None:
reported_loss["ref_kl"] = ref_kl_metric

if args.use_kl_loss:
reported_loss["kl_loss"] = kl_loss.clone().detach()
Expand Down
85 changes: 85 additions & 0 deletions miles/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,8 +1186,10 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_

def compute_metrics_from_samples(args, samples):
response_lengths = [sample.effective_response_length for sample in samples]
n = len(samples)

log_dict = {}
# existing keys (unchanged)
log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
log_dict |= _compute_zero_std_metrics(args, samples)
log_dict |= _compute_spec_metrics(args, samples)
Expand All @@ -1196,6 +1198,35 @@ def compute_metrics_from_samples(args, samples):
log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item()
log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item()

# new top-level grouped keys: global
log_dict |= _compute_grouped_reward_metrics(args, samples, "reward", n, include_count_frac=False)
log_dict |= _compute_grouped_response_metrics(args, samples, "response_stats")
log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward")

# per-correctness (no count_frac: for binary rewards = mean reward = already in reward/raw_reward)
correct = [s for s in samples if s.get_reward_value(args) > 0]
incorrect = [s for s in samples if s.get_reward_value(args) <= 0]
for label, grp in [("correct", correct), ("incorrect", incorrect)]:
if grp:
log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False)
log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{label}")

# per-category and combined (only if category data present)
cat_key = _get_problem_category_key(args, samples)
if cat_key is not None:
for cat, cat_grp in group_by(samples, lambda s: s.metadata.get(cat_key)).items():
if cat is None or not cat_grp:
continue
log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n)
log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}")
log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}")
for label, grp in [
("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]),
("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]),
]:
if grp:
log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n)
log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}")
tito_vals = [s.metadata.get("tito_session_mismatch") for s in samples]
tito_vals = [v for v in tito_vals if v is not None]
if tito_vals:
Expand Down Expand Up @@ -1297,3 +1328,57 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]):
samples_of_reward_cat = group_by(all_samples, lambda s: s.reward[reward_cat_key])

return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}


# Candidate metadata keys to auto-detect problem category (checked in order)
_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"]


def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
"""Return the metadata key to use for problem category grouping, or None if not available."""
explicit = getattr(args, "log_problem_category", None)
if explicit:
return explicit
for sample in all_samples:
if sample.metadata:
for key in _CANDIDATE_CATEGORY_KEYS:
if key in sample.metadata:
return key
return None


def _compute_grouped_reward_metrics(
args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True
) -> dict:
"""Reward/outcome metrics for a split — emitted under reward/ sections."""
result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()}
if include_count_frac:
result[f"{prefix}/count_frac"] = len(group) / n_total
return result


def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) -> dict:
    """Response shape metrics for a split — emitted under response_stats/ sections."""
    effective_lengths = [sample.effective_response_length for sample in group]
    truncated_flags = [int(sample.status == Sample.Status.TRUNCATED) for sample in group]
    repetition_flags = [int(has_repetition(sample.response)) for sample in group]
    return {
        f"{prefix}/response_len": np.mean(effective_lengths).item(),
        f"{prefix}/truncated_frac": np.mean(truncated_flags).item(),
        f"{prefix}/repetition_frac": np.mean(repetition_flags).item(),
    }


def _compute_group_outcome_metrics(
    args, all_samples: list[Sample], prefix: str = "reward"
) -> dict:
    """Fraction of prompt groups that are unanimously correct or incorrect. GRPO only."""
    # PPO has no prompt-group structure, so these metrics are skipped there.
    if args.advantage_estimator == "ppo":
        return {}
    grouped = group_by(all_samples, lambda s: s.group_index)
    if not grouped:
        return {}
    unanimous_correct = 0
    unanimous_incorrect = 0
    for members in grouped.values():
        positive_flags = [sample.get_reward_value(args) > 0 for sample in members]
        if all(positive_flags):
            unanimous_correct += 1
        if not any(positive_flags):
            unanimous_incorrect += 1
    total_groups = len(grouped)
    return {
        f"{prefix}/all_correct_group_frac": unanimous_correct / total_groups,
        f"{prefix}/all_incorrect_group_frac": unanimous_incorrect / total_groups,
    }
Loading