diff --git a/flashinfer_bench/bench/evaluators/sampling.py b/flashinfer_bench/bench/evaluators/sampling.py
index 2dba7c90..efd4187c 100644
--- a/flashinfer_bench/bench/evaluators/sampling.py
+++ b/flashinfer_bench/bench/evaluators/sampling.py
@@ -53,16 +53,16 @@ def build_baseline(
     outputs: List[Dict[str, torch.Tensor]] = []
 
     inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
-    if "probs" in inp:
-        inp["probs"] = torch.softmax(
-            inp["probs"], dim=-1
-        )  # convert logits to probs for sampling
     inputs.append(inp)
 
-    freq_dist = _compute_frequency_distribution(
-        ref_runnable, inp, device, defn, num_trials=50000
-    )
-    outputs.append({"frequency_distribution": freq_dist})
+    thresholding_method = _detect_thresholding_method(defn)
+    params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
+    valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)
+
+    masked_probs = inp["probs"] * valid_mask.float()
+    expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+
+    outputs.append({"expected_probs": expected_probs})
 
     latencies: List[float] = []
     for inp in inputs:
@@ -94,8 +94,8 @@ def check_correctness(
     log_path: str,
     device: str,
 ) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
-    ref_freq = ref_outputs[0]["frequency_distribution"]
-    vocab_size = ref_freq.shape[0]
+    expected_probs = ref_outputs[0]["expected_probs"]
+    vocab_size = expected_probs.shape[-1]
 
     inp = inputs[0]
     params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
@@ -103,6 +103,11 @@
     output_names = list(defn.outputs.keys())
     output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
 
+    # Compute valid sampling mask based on thresholding
+    thresholding_method = _detect_thresholding_method(defn)
+    probs = inp["probs"]
+    valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)
+
     # Validate correct sampling token set
     for _ in range(cfg.sampling_validation_trials):
         try:
@@ -137,27 +142,32 @@
                 correctness=correctness,
             )
 
-        # Validate thresholding
-        thresholding_method = _detect_thresholding_method(defn)
-        probs = inp["probs"]
-        if not _check_thresholding(samples, probs, thresholding_method, params):
-            correctness = Correctness(
-                max_relative_error=float("inf"), max_absolute_error=float("inf")
-            )
-            message = (
-                f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
-            )
-            print(message, file=sys.stderr)
-            return correctness, make_eval(
-                status=EvaluationStatus.INCORRECT_NUMERICAL,
-                device=device,
-                log_path=log_path,
-                correctness=correctness,
-            )
+        # Validate thresholding - check samples are within valid mask
+        if samples.dim() == 0:
+            samples_flat = samples.unsqueeze(0)
+        else:
+            samples_flat = samples.flatten()
+
+        batch_size = valid_mask.shape[0]
+        for i in range(len(samples_flat)):
+            batch_idx = i % batch_size
+            sample_idx = samples_flat[i].item()
+            if not valid_mask[batch_idx, sample_idx]:
+                correctness = Correctness(
+                    max_relative_error=float("inf"), max_absolute_error=float("inf")
+                )
+                message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
+                print(message, file=sys.stderr)
+                return correctness, make_eval(
+                    status=EvaluationStatus.INCORRECT_NUMERICAL,
+                    device=device,
+                    log_path=log_path,
+                    correctness=correctness,
+                )
 
     try:
-        sol_freq = _compute_frequency_distribution(
-            sol_runnable, inp, device, defn, num_trials=50000
+        sol_freqs = _sample_token_distributions(
+            sol_runnable, inp, device, defn, num_trials=500000
         )
         torch.cuda.synchronize(device)
     except Exception:
@@ -166,13 +176,29 @@
             status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
         )
 
-    # total variation distance
-    tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
-    max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
+    batch_size = expected_probs.shape[0]
+    tvds = []
+    max_abs_errors = []
+    max_rel_errors = []
+
+    for i in range(batch_size):
+        tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
+        tvds.append(tvd_i)
+
+        max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
+        max_abs_errors.append(max_abs_i)
+        max_rel_errors.append(max_rel_i)
 
-    numerical_incorrect = tvd > cfg.sampling_tvd_threshold
+    # Use the worst (max) TVD and errors across all batch elements
+    max_tvd = max(tvds)
+    max_abs = max(max_abs_errors)
+    max_rel = max(max_rel_errors)
+
+    numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
     correctness = Correctness(
-        max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
+        max_relative_error=max_rel,
+        max_absolute_error=max_abs,
+        extra={"tvd": max_tvd, "tvds_per_batch": tvds},
     )
     if numerical_incorrect:
         return correctness, make_eval(
@@ -201,23 +227,117 @@ def _detect_thresholding_method(defn: Definition) -> str:
     return "none"  # no thresholding
 
 
-def _compute_frequency_distribution(
+def _compute_valid_sampling_mask(
+    probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
+) -> torch.Tensor:
+    """Compute the boolean mask of tokens that are valid to sample.
+
+    Allows for tie-breaking in top_k (any token with prob >= the k-th largest is accepted)
+    and for numerical precision in top_p (tokens within eps of the nucleus boundary are accepted).
+    """
+    if probs.dim() == 1:
+        probs = probs.unsqueeze(0)
+
+    batch_size, vocab_size = probs.shape
+    device = probs.device
+
+    if method == "none":
+        return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)
+
+    if method in ["top_k", "top_k_top_p"]:
+        if "top_k" not in params:
+            raise ValueError(f"top_k parameter required for {method} but not found")
+
+        top_k_param = params["top_k"]
+        for i in range(batch_size):
+            k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())
+
+            if 0 < k < vocab_size:
+                sorted_probs, _ = torch.sort(probs[i], descending=True)
+                # k-th largest value (0-indexed, so k-1)
+                pivot = sorted_probs[k - 1]
+                mask[i] = probs[i] >= pivot  # tie-breaking handling
+
+    # Apply top_p mask with epsilon tolerance
+    if method in ["top_p", "top_k_top_p"]:
+        if "top_p" not in params:
+            raise ValueError(f"top_p parameter required for {method} but not found")
+
+        top_p_param = params["top_p"]
+        for i in range(batch_size):
+            p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())
+
+            if 0 < p < 1:
+                sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
+                cumsum = torch.cumsum(sorted_probs, dim=0)
+
+                # Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
+                nucleus_mask = cumsum <= (p + eps)
+
+                if not nucleus_mask.any():
+                    nucleus_mask[0] = True
+
+                # Map back to original indices
+                p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
+                p_mask[sorted_indices[nucleus_mask]] = True
+
+                mask[i] = mask[i] & p_mask
+
+    return mask
+
+
+def _sample_token_distributions(
     runnable: Runnable,
     inputs: Dict[str, Any],
     device: str,
     defn: Definition,
-    num_trials: int = 10000,
+    num_trials: int = 500000,
 ) -> torch.Tensor:
-    batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
+    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
     vocab_size = inputs["probs"].shape[-1]
 
-    counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))
-    trials_needed = (num_trials + batch_size - 1) // batch_size
-    total_samples_collected = 0
+    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
+    target_batch_size = 10000
+    repeat_count = max(1, target_batch_size // original_batch_size)
+    actual_batch_size = repeat_count * original_batch_size
+
+    padded_inputs = {}
+    for key, value in inputs.items():
+        if isinstance(value, torch.Tensor) and value.dim() > 0:
+            if key == "probs":
+                # For probs, repeat the entire batch
+                if value.dim() == 1:
+                    value = value.unsqueeze(0)
+                # Repeat the entire batch repeat_count times
+                padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            elif key in ["top_k", "top_p"]:
+                # For sampling parameters, repeat the entire batch
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count)
+            else:
+                # For other tensors, repeat the entire batch along the batch dimension
+                if value.dim() == 0:
+                    padded_value = value.unsqueeze(0).repeat(actual_batch_size)
+                else:
+                    padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
+            padded_inputs[key] = padded_value
+        else:
+            # For non-tensor inputs, keep as is
+            padded_inputs[key] = value
+
+    counters = torch.zeros(
+        (original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
+    )
+
+    trials_needed = (num_trials + repeat_count - 1) // repeat_count
+    total_samples_per_batch = 0
 
     for _ in range(trials_needed):
         with torch.no_grad():
-            out = runnable(**inputs)
+            out = runnable(**padded_inputs)
 
         output_names = list(defn.outputs.keys())
         output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
@@ -229,118 +349,19 @@ def _compute_frequency_distribution(
         samples = out_normalized["samples"]
 
         if samples.dim() == 0:
+            # Single sample - assign to first batch element
             sample_idx = samples.item()
-            counter[sample_idx] += 1
-            total_samples_collected += 1
-        else:  # Batch of samples
-            for i in range(samples.numel()):
-                sample_idx = samples.flatten()[i].item()
-                counter[sample_idx] += 1
-                total_samples_collected += 1
-
-    frequency = counter.float() / total_samples_collected
-    return frequency
-
-
-def _check_thresholding(
-    samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
-) -> bool:
-    """Check if samples conform to the specified thresholding method.
-
-    Parameters
-    ----------
-    samples : torch.Tensor
-        Sampled token indices.
-    probs : torch.Tensor
-        Probability distribution used for sampling.
-    method : str
-        Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
-    params : Dict[str, Any]
-        Sampling parameters (top_k, top_p values).
-
-    Returns
-    -------
-    bool
-        True if samples are valid, False otherwise.
-    """
-    batch_size, vocab_size = probs.shape
-    device = probs.device
-
-    for i in range(batch_size):
-        prob_row = probs[i]
-        sample = samples[i].item()
-
-        if method == "top_k":
-            if "top_k" not in params:
-                raise ValueError("top_k parameter is required for top_k thresholding but not found")
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-                if mask_top_k[sample] != 1:
-                    return False
-
-        elif method == "top_p":
-            if "top_p" not in params:
-                raise ValueError("top_p parameter is required for top_p thresholding but not found")
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < p < 1:
-                eps = 1e-4  # numerical stability
-                sorted_probs, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs, dim=0)
-                valid_mask = cdf > (1 - p) - eps
-                valid_indices = indices[valid_mask]
-
-                if sample not in valid_indices:
-                    return False
-
-        elif method == "top_k_top_p":
-            if "top_k" not in params or "top_p" not in params:
-                raise ValueError(
-                    "top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
-                )
-            k = (
-                int(params["top_k"][i].item())
-                if params["top_k"].dim() > 0
-                else int(params["top_k"].item())
-            )
-            p = (
-                float(params["top_p"][i].item())
-                if params["top_p"].dim() > 0
-                else float(params["top_p"].item())
-            )
-
-            if 0 < k < vocab_size:
-                sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
-                pivot = sorted_prob_desc[k - 1]
-                mask_top_k = (prob_row >= pivot).int()
-            else:
-                mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            if 0 < p < 1:
-                eps = 1e-4
-                sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
-                cdf = torch.cumsum(sorted_probs_asc, dim=0)
-                mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
-                valid_p_mask = cdf > (1 - p) - eps
-                mask_top_p[indices[valid_p_mask]] = 1
-            else:
-                mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)
-
-            joint_mask = torch.minimum(mask_top_k, mask_top_p)
-
-            if joint_mask[sample] != 1:
-                return False
-
-    return True
+            counters[0, sample_idx] += 1
+            total_samples_per_batch += 1
+        else:
+            # Samples repeat the original batch; map each back to its source row
+            samples_flat = samples.flatten()
+            for i in range(samples_flat.numel()):
+                batch_idx = i % original_batch_size
+                sample_idx = samples_flat[i].item()
+                counters[batch_idx, sample_idx] += 1
+            total_samples_per_batch += repeat_count
+
+    # frequencies: [original_batch_size, vocab_size]
+    frequencies = counters.float() / total_samples_per_batch
+    return frequencies