-
Notifications
You must be signed in to change notification settings - Fork 12
refactor: update sampling evaluation logic #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,16 +53,16 @@ def build_baseline( | |
| outputs: List[Dict[str, torch.Tensor]] = [] | ||
|
|
||
| inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors) | ||
| if "probs" in inp: | ||
| inp["probs"] = torch.softmax( | ||
| inp["probs"], dim=-1 | ||
| ) # convert logits to probs for sampling | ||
| inputs.append(inp) | ||
|
|
||
| freq_dist = _compute_frequency_distribution( | ||
| ref_runnable, inp, device, defn, num_trials=50000 | ||
| ) | ||
| outputs.append({"frequency_distribution": freq_dist}) | ||
| thresholding_method = _detect_thresholding_method(defn) | ||
| params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp} | ||
| valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params) | ||
|
|
||
| masked_probs = inp["probs"] * valid_mask.float() | ||
| expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True) | ||
|
|
||
| outputs.append({"expected_probs": expected_probs}) | ||
|
|
||
| latencies: List[float] = [] | ||
| for inp in inputs: | ||
|
|
@@ -94,15 +94,20 @@ def check_correctness( | |
| log_path: str, | ||
| device: str, | ||
| ) -> Tuple[Optional[Correctness], Optional[Evaluation]]: | ||
| ref_freq = ref_outputs[0]["frequency_distribution"] | ||
| vocab_size = ref_freq.shape[0] | ||
| expected_probs = ref_outputs[0]["expected_probs"] | ||
| vocab_size = expected_probs.shape[-1] | ||
|
|
||
| inp = inputs[0] | ||
| params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp} | ||
|
|
||
| output_names = list(defn.outputs.keys()) | ||
| output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()} | ||
|
|
||
| # Compute valid sampling mask based on thresholding | ||
| thresholding_method = _detect_thresholding_method(defn) | ||
| probs = inp["probs"] | ||
| valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params) | ||
|
|
||
| # Validate correct sampling token set | ||
| for _ in range(cfg.sampling_validation_trials): | ||
| try: | ||
|
|
@@ -137,27 +142,32 @@ def check_correctness( | |
| correctness=correctness, | ||
| ) | ||
|
|
||
| # Validate thresholding | ||
| thresholding_method = _detect_thresholding_method(defn) | ||
| probs = inp["probs"] | ||
| if not _check_thresholding(samples, probs, thresholding_method, params): | ||
| correctness = Correctness( | ||
| max_relative_error=float("inf"), max_absolute_error=float("inf") | ||
| ) | ||
| message = ( | ||
| f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding" | ||
| ) | ||
| print(message, file=sys.stderr) | ||
| return correctness, make_eval( | ||
| status=EvaluationStatus.INCORRECT_NUMERICAL, | ||
| device=device, | ||
| log_path=log_path, | ||
| correctness=correctness, | ||
| ) | ||
| # Validate thresholding - check samples are within valid mask | ||
| if samples.dim() == 0: | ||
| samples_flat = samples.unsqueeze(0) | ||
| else: | ||
| samples_flat = samples.flatten() | ||
|
|
||
| batch_size = valid_mask.shape[0] | ||
| for i in range(len(samples_flat)): | ||
| batch_idx = i % batch_size | ||
| sample_idx = samples_flat[i].item() | ||
| if not valid_mask[batch_idx, sample_idx]: | ||
| correctness = Correctness( | ||
| max_relative_error=float("inf"), max_absolute_error=float("inf") | ||
| ) | ||
| message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}" | ||
| print(message, file=sys.stderr) | ||
| return correctness, make_eval( | ||
| status=EvaluationStatus.INCORRECT_NUMERICAL, | ||
| device=device, | ||
| log_path=log_path, | ||
| correctness=correctness, | ||
| ) | ||
|
|
||
| try: | ||
| sol_freq = _compute_frequency_distribution( | ||
| sol_runnable, inp, device, defn, num_trials=50000 | ||
| sol_freqs = _sample_token_distributions( | ||
| sol_runnable, inp, device, defn, num_trials=500000 | ||
| ) | ||
| torch.cuda.synchronize(device) | ||
| except Exception: | ||
|
|
@@ -166,13 +176,29 @@ def check_correctness( | |
| status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path | ||
| ) | ||
|
|
||
| # total variation distance | ||
| tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item() | ||
| max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg) | ||
| batch_size = expected_probs.shape[0] | ||
| tvds = [] | ||
| max_abs_errors = [] | ||
| max_rel_errors = [] | ||
|
|
||
| for i in range(batch_size): | ||
| tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item() | ||
| tvds.append(tvd_i) | ||
|
|
||
| max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg) | ||
| max_abs_errors.append(max_abs_i) | ||
| max_rel_errors.append(max_rel_i) | ||
|
|
||
| numerical_incorrect = tvd > cfg.sampling_tvd_threshold | ||
| # Use the worst (max) TVD and errors across all batch elements | ||
| max_tvd = max(tvds) | ||
| max_abs = max(max_abs_errors) | ||
| max_rel = max(max_rel_errors) | ||
|
|
||
| numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold | ||
| correctness = Correctness( | ||
| max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd} | ||
| max_relative_error=max_rel, | ||
| max_absolute_error=max_abs, | ||
| extra={"tvd": max_tvd, "tvds_per_batch": tvds}, | ||
| ) | ||
| if numerical_incorrect: | ||
| return correctness, make_eval( | ||
|
|
@@ -201,23 +227,117 @@ def _detect_thresholding_method(defn: Definition) -> str: | |
| return "none" # no thresholding | ||
|
|
||
|
|
||
| def _compute_frequency_distribution( | ||
| def _compute_valid_sampling_mask( | ||
| probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2 | ||
| ) -> torch.Tensor: | ||
| """ | ||
| For tie-breaking in top_k (allows any token with prob >= k-th largest) | ||
| and numerical precision in top_p (allows tokens within eps of nucleus boundary). | ||
| """ | ||
| if probs.dim() == 1: | ||
| probs = probs.unsqueeze(0) | ||
|
|
||
| batch_size, vocab_size = probs.shape | ||
| device = probs.device | ||
|
|
||
| if method == "none": | ||
| return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device) | ||
|
|
||
| mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device) | ||
|
|
||
| if method in ["top_k", "top_k_top_p"]: | ||
| if "top_k" not in params: | ||
| raise ValueError(f"top_k parameter required for {method} but not found") | ||
|
|
||
| top_k_param = params["top_k"] | ||
| for i in range(batch_size): | ||
| k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item()) | ||
|
|
||
| if 0 < k < vocab_size: | ||
| sorted_probs, _ = torch.sort(probs[i], descending=True) | ||
| # k-th largest value (0-indexed, so k-1) | ||
| pivot = sorted_probs[k - 1] | ||
| mask[i] = probs[i] >= pivot # tie-breaking handling | ||
|
|
||
| # Apply top_p mask with epsilon tolerance | ||
| if method in ["top_p", "top_k_top_p"]: | ||
| if "top_p" not in params: | ||
| raise ValueError(f"top_p parameter required for {method} but not found") | ||
|
|
||
| top_p_param = params["top_p"] | ||
| for i in range(batch_size): | ||
| p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item()) | ||
|
|
||
| if 0 < p < 1: | ||
| sorted_probs, sorted_indices = torch.sort(probs[i], descending=True) | ||
| cumsum = torch.cumsum(sorted_probs, dim=0) | ||
|
|
||
| # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance) | ||
| nucleus_mask = cumsum <= (p + eps) | ||
|
|
||
| if not nucleus_mask.any(): | ||
| nucleus_mask[0] = True | ||
|
|
||
| # Map back to original indices | ||
| p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device) | ||
| p_mask[sorted_indices[nucleus_mask]] = True | ||
|
|
||
| mask[i] = mask[i] & p_mask | ||
|
|
||
| return mask | ||
|
|
||
|
|
||
| def _sample_token_distributions( | ||
| runnable: Runnable, | ||
| inputs: Dict[str, Any], | ||
| device: str, | ||
| defn: Definition, | ||
| num_trials: int = 10000, | ||
| num_trials: int = 500000, | ||
| ) -> torch.Tensor: | ||
| batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1 | ||
| original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1 | ||
| vocab_size = inputs["probs"].shape[-1] | ||
| counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device)) | ||
|
|
||
| trials_needed = (num_trials + batch_size - 1) // batch_size | ||
| total_samples_collected = 0 | ||
| # Repeat entire input batch to fill up to target_batch_size for efficient sampling | ||
| target_batch_size = 10000 | ||
| repeat_count = target_batch_size // original_batch_size | ||
| actual_batch_size = repeat_count * original_batch_size | ||
|
|
||
| padded_inputs = {} | ||
| for key, value in inputs.items(): | ||
| if isinstance(value, torch.Tensor) and value.dim() > 0: | ||
| if key == "probs": | ||
| # For probs, repeat the entire batch | ||
| if value.dim() == 1: | ||
| value = value.unsqueeze(0) | ||
| # Repeat the entire batch repeat_count times | ||
| padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1))) | ||
| elif key in ["top_k", "top_p"]: | ||
| # For sampling parameters, repeat the entire batch | ||
| if value.dim() == 0: | ||
| padded_value = value.unsqueeze(0).repeat(actual_batch_size) | ||
| else: | ||
| padded_value = value.repeat(repeat_count) | ||
| else: | ||
| # For other tensors, repeat entire batch along batch dimension | ||
| if value.dim() == 0: | ||
| padded_value = value.unsqueeze(0).repeat(actual_batch_size) | ||
| else: | ||
| padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1))) | ||
| padded_inputs[key] = padded_value | ||
| else: | ||
| # For non-tensor inputs, keep as is | ||
| padded_inputs[key] = value | ||
|
|
||
| counters = torch.zeros( | ||
| (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device) | ||
| ) | ||
|
|
||
| trials_needed = (num_trials + repeat_count - 1) // repeat_count | ||
| total_samples_per_batch = 0 | ||
|
|
||
| for _ in range(trials_needed): | ||
| with torch.no_grad(): | ||
| out = runnable(**inputs) | ||
| out = runnable(**padded_inputs) | ||
|
|
||
| output_names = list(defn.outputs.keys()) | ||
| output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()} | ||
|
|
@@ -229,118 +349,19 @@ def _compute_frequency_distribution( | |
| samples = out_normalized["samples"] | ||
|
|
||
| if samples.dim() == 0: | ||
| # Single sample - assign to first batch element | ||
| sample_idx = samples.item() | ||
| counter[sample_idx] += 1 | ||
| total_samples_collected += 1 | ||
| else: # Batch of samples | ||
| for i in range(samples.numel()): | ||
| sample_idx = samples.flatten()[i].item() | ||
| counter[sample_idx] += 1 | ||
| total_samples_collected += 1 | ||
|
|
||
| frequency = counter.float() / total_samples_collected | ||
| return frequency | ||
|
|
||
|
|
||
| def _check_thresholding( | ||
| samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any] | ||
| ) -> bool: | ||
| """Check if samples conform to the specified thresholding method. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| samples : torch.Tensor | ||
| Sampled token indices. | ||
| probs : torch.Tensor | ||
| Probability distribution used for sampling. | ||
| method : str | ||
| Thresholding method: "top_k", "top_p", "top_k_top_p", or "none". | ||
| params : Dict[str, Any] | ||
| Sampling parameters (top_k, top_p values). | ||
|
|
||
| Returns | ||
| ------- | ||
| bool | ||
| True if samples are valid, False otherwise. | ||
| """ | ||
| batch_size, vocab_size = probs.shape | ||
| device = probs.device | ||
|
|
||
| for i in range(batch_size): | ||
| prob_row = probs[i] | ||
| sample = samples[i].item() | ||
|
|
||
| if method == "top_k": | ||
| if "top_k" not in params: | ||
| raise ValueError("top_k parameter is required for top_k thresholding but not found") | ||
| k = ( | ||
| int(params["top_k"][i].item()) | ||
| if params["top_k"].dim() > 0 | ||
| else int(params["top_k"].item()) | ||
| ) | ||
|
|
||
| if 0 < k < vocab_size: | ||
| sorted_prob_desc, _ = torch.sort(prob_row, descending=True) | ||
| pivot = sorted_prob_desc[k - 1] | ||
| mask_top_k = (prob_row >= pivot).int() | ||
| if mask_top_k[sample] != 1: | ||
| return False | ||
|
|
||
| elif method == "top_p": | ||
| if "top_p" not in params: | ||
| raise ValueError("top_p parameter is required for top_p thresholding but not found") | ||
| p = ( | ||
| float(params["top_p"][i].item()) | ||
| if params["top_p"].dim() > 0 | ||
| else float(params["top_p"].item()) | ||
| ) | ||
|
|
||
| if 0 < p < 1: | ||
| eps = 1e-4 # numerical stability | ||
| sorted_probs, indices = torch.sort(prob_row, descending=False) | ||
| cdf = torch.cumsum(sorted_probs, dim=0) | ||
| valid_mask = cdf > (1 - p) - eps | ||
| valid_indices = indices[valid_mask] | ||
|
|
||
| if sample not in valid_indices: | ||
| return False | ||
|
|
||
| elif method == "top_k_top_p": | ||
| if "top_k" not in params or "top_p" not in params: | ||
| raise ValueError( | ||
| "top_k and top_p parameters are both required for top_k_top_p thresholding but not found" | ||
| ) | ||
| k = ( | ||
| int(params["top_k"][i].item()) | ||
| if params["top_k"].dim() > 0 | ||
| else int(params["top_k"].item()) | ||
| ) | ||
| p = ( | ||
| float(params["top_p"][i].item()) | ||
| if params["top_p"].dim() > 0 | ||
| else float(params["top_p"].item()) | ||
| ) | ||
|
|
||
| if 0 < k < vocab_size: | ||
| sorted_prob_desc, _ = torch.sort(prob_row, descending=True) | ||
| pivot = sorted_prob_desc[k - 1] | ||
| mask_top_k = (prob_row >= pivot).int() | ||
| else: | ||
| mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device) | ||
|
|
||
| if 0 < p < 1: | ||
| eps = 1e-4 | ||
| sorted_probs_asc, indices = torch.sort(prob_row, descending=False) | ||
| cdf = torch.cumsum(sorted_probs_asc, dim=0) | ||
| mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device) | ||
| valid_p_mask = cdf > (1 - p) - eps | ||
| mask_top_p[indices[valid_p_mask]] = 1 | ||
| else: | ||
| mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device) | ||
|
|
||
| joint_mask = torch.minimum(mask_top_k, mask_top_p) | ||
|
|
||
| if joint_mask[sample] != 1: | ||
| return False | ||
|
|
||
| return True | ||
| counters[0, sample_idx] += 1 | ||
| total_samples_per_batch += 1 | ||
|
Comment on lines
231
to
+355
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When if samples.dim() == 0:
if actual_batch_size != 1:
raise ValueError(
f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
)
# Single sample - assign to first batch element
sample_idx = samples.item()
counters[0, sample_idx] += 1
total_samples_per_batch += 1 |
||
| else: | ||
| # slice and accumulate per original batch element | ||
| samples_flat = samples.flatten() | ||
| for i in range(samples_flat.numel()): | ||
| batch_idx = i % original_batch_size | ||
| sample_idx = samples_flat[i].item() | ||
| counters[batch_idx, sample_idx] += 1 | ||
| total_samples_per_batch += repeat_count | ||
|
|
||
| # [batch_size, vocab_size] | ||
| frequencies = counters.float() / total_samples_per_batch | ||
| return frequencies | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Second cuda synchronize has same issue; apply the same guard.
Mirror the earlier fix after sampling distributions.
🧰 Tools
🪛 Ruff (0.14.1)
179-179: Do not catch blind exception:
Exception(BLE001)
🤖 Prompt for AI Agents