-
Notifications
You must be signed in to change notification settings - Fork 12
refactor: update sampling evaluation logic #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -59,10 +59,14 @@ def build_baseline( | |||||||||||||||||||||||||||||||||
| ) # convert logits to probs for sampling | ||||||||||||||||||||||||||||||||||
| inputs.append(inp) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| freq_dist = _compute_frequency_distribution( | ||||||||||||||||||||||||||||||||||
| ref_runnable, inp, device, defn, num_trials=50000 | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| outputs.append({"frequency_distribution": freq_dist}) | ||||||||||||||||||||||||||||||||||
| thresholding_method = _detect_thresholding_method(defn) | ||||||||||||||||||||||||||||||||||
| params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp} | ||||||||||||||||||||||||||||||||||
| valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| masked_probs = inp["probs"] * valid_mask.float() | ||||||||||||||||||||||||||||||||||
| expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| outputs.append({"expected_probs": expected_probs}) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| latencies: List[float] = [] | ||||||||||||||||||||||||||||||||||
| for inp in inputs: | ||||||||||||||||||||||||||||||||||
|
|
@@ -94,15 +98,20 @@ def check_correctness( | |||||||||||||||||||||||||||||||||
| log_path: str, | ||||||||||||||||||||||||||||||||||
| device: str, | ||||||||||||||||||||||||||||||||||
| ) -> Tuple[Optional[Correctness], Optional[Evaluation]]: | ||||||||||||||||||||||||||||||||||
| ref_freq = ref_outputs[0]["frequency_distribution"] | ||||||||||||||||||||||||||||||||||
| vocab_size = ref_freq.shape[0] | ||||||||||||||||||||||||||||||||||
| expected_probs = ref_outputs[0]["expected_probs"] | ||||||||||||||||||||||||||||||||||
| vocab_size = expected_probs.shape[-1] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| inp = inputs[0] | ||||||||||||||||||||||||||||||||||
| params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp} | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| output_names = list(defn.outputs.keys()) | ||||||||||||||||||||||||||||||||||
| output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()} | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Compute valid sampling mask based on thresholding | ||||||||||||||||||||||||||||||||||
| thresholding_method = _detect_thresholding_method(defn) | ||||||||||||||||||||||||||||||||||
| probs = inp["probs"] | ||||||||||||||||||||||||||||||||||
| valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Validate correct sampling token set | ||||||||||||||||||||||||||||||||||
| for _ in range(cfg.sampling_validation_trials): | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
|
|
@@ -137,27 +146,34 @@ def check_correctness( | |||||||||||||||||||||||||||||||||
| correctness=correctness, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Validate thresholding | ||||||||||||||||||||||||||||||||||
| thresholding_method = _detect_thresholding_method(defn) | ||||||||||||||||||||||||||||||||||
| probs = inp["probs"] | ||||||||||||||||||||||||||||||||||
| if not _check_thresholding(samples, probs, thresholding_method, params): | ||||||||||||||||||||||||||||||||||
| correctness = Correctness( | ||||||||||||||||||||||||||||||||||
| max_relative_error=float("inf"), max_absolute_error=float("inf") | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| message = ( | ||||||||||||||||||||||||||||||||||
| f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding" | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| print(message, file=sys.stderr) | ||||||||||||||||||||||||||||||||||
| return correctness, make_eval( | ||||||||||||||||||||||||||||||||||
| status=EvaluationStatus.INCORRECT_NUMERICAL, | ||||||||||||||||||||||||||||||||||
| device=device, | ||||||||||||||||||||||||||||||||||
| log_path=log_path, | ||||||||||||||||||||||||||||||||||
| correctness=correctness, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| # Validate thresholding - check samples are within valid mask | ||||||||||||||||||||||||||||||||||
| if samples.dim() == 0: | ||||||||||||||||||||||||||||||||||
| samples_flat = samples.unsqueeze(0) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| samples_flat = samples.flatten() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| batch_size = valid_mask.shape[0] | ||||||||||||||||||||||||||||||||||
| for i in range(len(samples_flat)): | ||||||||||||||||||||||||||||||||||
| batch_idx = i % batch_size | ||||||||||||||||||||||||||||||||||
| sample_idx = samples_flat[i].item() | ||||||||||||||||||||||||||||||||||
| if not valid_mask[batch_idx, sample_idx]: | ||||||||||||||||||||||||||||||||||
| correctness = Correctness( | ||||||||||||||||||||||||||||||||||
| max_relative_error=float("inf"), max_absolute_error=float("inf") | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| message = ( | ||||||||||||||||||||||||||||||||||
| f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}" | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| print(message, file=sys.stderr) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
155
to
160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid truthiness on 0‑dim Torch bool Tensor. Using if not valid_mask[...] can raise “Boolean value of Tensor is ambiguous.” Convert to Python bool. - if not valid_mask[batch_idx, sample_idx]:
+ if not bool(valid_mask[batch_idx, sample_idx].item()):📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| return correctness, make_eval( | ||||||||||||||||||||||||||||||||||
| status=EvaluationStatus.INCORRECT_NUMERICAL, | ||||||||||||||||||||||||||||||||||
| device=device, | ||||||||||||||||||||||||||||||||||
| log_path=log_path, | ||||||||||||||||||||||||||||||||||
| correctness=correctness, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| sol_freq = _compute_frequency_distribution( | ||||||||||||||||||||||||||||||||||
| sol_runnable, inp, device, defn, num_trials=50000 | ||||||||||||||||||||||||||||||||||
| sol_freqs = _sample_token_distributions( | ||||||||||||||||||||||||||||||||||
| sol_runnable, inp, device, defn, num_trials=500000 | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize(device) | ||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+169
to
173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Second cuda synchronize has same issue; apply the same guard. Mirror the earlier fix after sampling distributions. - torch.cuda.synchronize(device)
+ _dev = torch.device(device)
+ if _dev.type == "cuda":
+ torch.cuda.synchronize(_dev)🧰 Tools🪛 Ruff (0.14.1)179-179: Do not catch blind exception: (BLE001) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
|
|
@@ -166,13 +182,29 @@ def check_correctness( | |||||||||||||||||||||||||||||||||
| status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # total variation distance | ||||||||||||||||||||||||||||||||||
| tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item() | ||||||||||||||||||||||||||||||||||
| max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| numerical_incorrect = tvd > cfg.sampling_tvd_threshold | ||||||||||||||||||||||||||||||||||
| batch_size = expected_probs.shape[0] | ||||||||||||||||||||||||||||||||||
| tvds = [] | ||||||||||||||||||||||||||||||||||
| max_abs_errors = [] | ||||||||||||||||||||||||||||||||||
| max_rel_errors = [] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| for i in range(batch_size): | ||||||||||||||||||||||||||||||||||
| tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item() | ||||||||||||||||||||||||||||||||||
| tvds.append(tvd_i) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg) | ||||||||||||||||||||||||||||||||||
| max_abs_errors.append(max_abs_i) | ||||||||||||||||||||||||||||||||||
| max_rel_errors.append(max_rel_i) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Use the worst (max) TVD and errors across all batch elements | ||||||||||||||||||||||||||||||||||
| max_tvd = max(tvds) | ||||||||||||||||||||||||||||||||||
| max_abs = max(max_abs_errors) | ||||||||||||||||||||||||||||||||||
| max_rel = max(max_rel_errors) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold | ||||||||||||||||||||||||||||||||||
| correctness = Correctness( | ||||||||||||||||||||||||||||||||||
| max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd} | ||||||||||||||||||||||||||||||||||
| max_relative_error=max_rel, | ||||||||||||||||||||||||||||||||||
| max_absolute_error=max_abs, | ||||||||||||||||||||||||||||||||||
| extra={"tvd": max_tvd, "tvds_per_batch": tvds} | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
197
to
202
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Protect against missing cfg.sampling_tvd_threshold. If cfg.sampling_tvd_threshold is None/absent, comparison raises. Default or fail fast with message. - numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+ tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+ if tvd_thresh is None:
+ raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+ numerical_incorrect = max_tvd > tvd_thresh📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| if numerical_incorrect: | ||||||||||||||||||||||||||||||||||
| return correctness, make_eval( | ||||||||||||||||||||||||||||||||||
|
|
@@ -201,23 +233,125 @@ def _detect_thresholding_method(defn: Definition) -> str: | |||||||||||||||||||||||||||||||||
| return "none" # no thresholding | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _compute_frequency_distribution( | ||||||||||||||||||||||||||||||||||
| def _compute_valid_sampling_mask( | ||||||||||||||||||||||||||||||||||
| probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5 | ||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| For tie-breaking in top_k (allows any token with prob >= k-th largest) | ||||||||||||||||||||||||||||||||||
| and numerical precision in top_p (allows tokens within eps of nucleus boundary). | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| if probs.dim() == 1: | ||||||||||||||||||||||||||||||||||
| probs = probs.unsqueeze(0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| batch_size, vocab_size = probs.shape | ||||||||||||||||||||||||||||||||||
| device = probs.device | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if method == "none": | ||||||||||||||||||||||||||||||||||
| return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if method in ["top_k", "top_k_top_p"]: | ||||||||||||||||||||||||||||||||||
| if "top_k" not in params: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"top_k parameter required for {method} but not found") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| top_k_param = params["top_k"] | ||||||||||||||||||||||||||||||||||
| for i in range(batch_size): | ||||||||||||||||||||||||||||||||||
| k = ( | ||||||||||||||||||||||||||||||||||
| int(top_k_param[i].item()) | ||||||||||||||||||||||||||||||||||
| if top_k_param.dim() > 0 | ||||||||||||||||||||||||||||||||||
| else int(top_k_param.item()) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if 0 < k < vocab_size: | ||||||||||||||||||||||||||||||||||
| sorted_probs, _ = torch.sort(probs[i], descending=True) | ||||||||||||||||||||||||||||||||||
| # k-th largest value (0-indexed, so k-1) | ||||||||||||||||||||||||||||||||||
| pivot = sorted_probs[k - 1] | ||||||||||||||||||||||||||||||||||
| mask[i] = probs[i] >= pivot # tie-breaking handling | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Apply top_p mask with epsilon tolerance | ||||||||||||||||||||||||||||||||||
| if method in ["top_p", "top_k_top_p"]: | ||||||||||||||||||||||||||||||||||
| if "top_p" not in params: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"top_p parameter required for {method} but not found") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| top_p_param = params["top_p"] | ||||||||||||||||||||||||||||||||||
| for i in range(batch_size): | ||||||||||||||||||||||||||||||||||
| p = ( | ||||||||||||||||||||||||||||||||||
| float(top_p_param[i].item()) | ||||||||||||||||||||||||||||||||||
| if top_p_param.dim() > 0 | ||||||||||||||||||||||||||||||||||
| else float(top_p_param.item()) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if 0 < p < 1: | ||||||||||||||||||||||||||||||||||
| sorted_probs, sorted_indices = torch.sort(probs[i], descending=True) | ||||||||||||||||||||||||||||||||||
| cumsum = torch.cumsum(sorted_probs, dim=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance) | ||||||||||||||||||||||||||||||||||
| nucleus_mask = cumsum <= (p + eps) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if not nucleus_mask.any(): | ||||||||||||||||||||||||||||||||||
| nucleus_mask[0] = True | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Map back to original indices | ||||||||||||||||||||||||||||||||||
| p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device) | ||||||||||||||||||||||||||||||||||
| p_mask[sorted_indices[nucleus_mask]] = True | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| mask[i] = mask[i] & p_mask | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return mask | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _sample_token_distributions( | ||||||||||||||||||||||||||||||||||
| runnable: Runnable, | ||||||||||||||||||||||||||||||||||
| inputs: Dict[str, Any], | ||||||||||||||||||||||||||||||||||
| device: str, | ||||||||||||||||||||||||||||||||||
| defn: Definition, | ||||||||||||||||||||||||||||||||||
| num_trials: int = 10000, | ||||||||||||||||||||||||||||||||||
| num_trials: int = 500000, | ||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||
| batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1 | ||||||||||||||||||||||||||||||||||
| original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1 | ||||||||||||||||||||||||||||||||||
| vocab_size = inputs["probs"].shape[-1] | ||||||||||||||||||||||||||||||||||
| counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device)) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| trials_needed = (num_trials + batch_size - 1) // batch_size | ||||||||||||||||||||||||||||||||||
| total_samples_collected = 0 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Repeat entire input batch to fill up to target_batch_size for efficient sampling | ||||||||||||||||||||||||||||||||||
| target_batch_size = 10000 | ||||||||||||||||||||||||||||||||||
| repeat_count = target_batch_size // original_batch_size | ||||||||||||||||||||||||||||||||||
| actual_batch_size = repeat_count * original_batch_size | ||||||||||||||||||||||||||||||||||
|
Comment on lines
297
to
303
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a potential Additionally, if I suggest handling the original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]
if original_batch_size == 0:
return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))
# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = max(1, target_batch_size // original_batch_size)
actual_batch_size = repeat_count * original_batch_size |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
| target_batch_size = 10000 | |
| repeat_count = target_batch_size // original_batch_size | |
| actual_batch_size = repeat_count * original_batch_size | |
| import math | |
| target_batch_size = 10000 | |
| repeat_count = max(1, math.ceil(target_batch_size / original_batch_size)) | |
| actual_batch_size = repeat_count * original_batch_size |
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 315-318, the
current computation uses integer division target_batch_size //
original_batch_size which can yield 0 when original_batch_size > 10000 (leading
to empty batches) and underutilization; change to use a ceiling division and
ensure at least one repeat: compute repeat_count = max(1, ceil(target_batch_size
/ original_batch_size)) (or equivalent integer math), then set actual_batch_size
= repeat_count * original_batch_size, and expose target_batch_size as a
configurable parameter (with validation to be a positive int) so it can be tuned
instead of hardcoding 10000.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When samples.dim() == 0, the code assumes a batch size of 1 and assigns the sample to the first batch element's counter. This is incorrect if original_batch_size > 1, as it would misattribute samples and lead to an incorrect frequency distribution for all batch items. The runnable should be expected to return a batch of samples matching actual_batch_size. If it returns a scalar when a batch is expected, it's a contract violation that should be flagged with an error.
if samples.dim() == 0:
if actual_batch_size != 1:
raise ValueError(
f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
)
# Single sample - assign to first batch element
sample_idx = samples.item()
counters[0, sample_idx] += 1
total_samples_per_batch += 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard against zero-sum masked probs to avoid NaNs.
If the valid_mask zeros out all tokens (edge params), masked_probs.sum can be 0 leading to NaNs in expected_probs. Clamp the denominator.
🤖 Prompt for AI Agents