341 changes: 188 additions & 153 deletions flashinfer_bench/bench/evaluators/sampling.py
@@ -59,10 +59,14 @@ def build_baseline(
) # convert logits to probs for sampling
inputs.append(inp)

freq_dist = _compute_frequency_distribution(
ref_runnable, inp, device, defn, num_trials=50000
)
outputs.append({"frequency_distribution": freq_dist})
thresholding_method = _detect_thresholding_method(defn)
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)

masked_probs = inp["probs"] * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)

outputs.append({"expected_probs": expected_probs})
Comment on lines 58 to 65
⚠️ Potential issue | 🟠 Major

Guard against zero-sum masked probs to avoid NaNs.

If valid_mask zeros out all tokens (edge-case parameters), masked_probs.sum can be 0, producing NaNs in expected_probs. Clamp the denominator.

-        masked_probs = inp["probs"] * valid_mask.float()
-        expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)
+        masked_probs = inp["probs"] * valid_mask.float()
+        denom = masked_probs.sum(dim=-1, keepdim=True)
+        # Avoid NaNs if no tokens survive; fall back to uniform over valid_mask
+        denom = torch.where(denom > 0, denom, torch.ones_like(denom))
+        expected_probs = masked_probs / denom
+        # If denom was 0, distribute uniformly across valid tokens
+        zero_rows = (masked_probs.sum(dim=-1, keepdim=True) == 0)
+        if zero_rows.any():
+            uniform = valid_mask.float() / valid_mask.float().sum(dim=-1, keepdim=True).clamp_min(1)
+            expected_probs = torch.where(zero_rows, uniform, expected_probs)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 62 to 69, the code
divides by masked_probs.sum which can be zero if valid_mask zeros out all
tokens; change the denominator to a clamped value to avoid NaNs by computing the
sum with keepdim=True and then applying .clamp_min(eps) (use a small constant
like 1e-12) before dividing so expected_probs = masked_probs / denom_clamped,
ensuring you preserve shapes and device when creating eps.
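
A minimal sketch of the clamp-based variant described in the prompt above; the helper name and the 1e-12 epsilon are illustrative assumptions, not code from this PR.

import torch

def normalize_masked_probs(
    probs: torch.Tensor, valid_mask: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
    # Zero out invalid tokens, then renormalize with a clamped denominator
    # so a fully masked row produces zeros instead of NaNs.
    masked = probs * valid_mask.float()
    denom = masked.sum(dim=-1, keepdim=True).clamp_min(eps)
    return masked / denom

With this variant a fully masked row yields an all-zero distribution rather than NaNs; the torch.where fallback in the diff above instead spreads the mass uniformly over any surviving valid tokens.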


latencies: List[float] = []
for inp in inputs:
@@ -94,15 +98,20 @@ def check_correctness(
log_path: str,
device: str,
) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
ref_freq = ref_outputs[0]["frequency_distribution"]
vocab_size = ref_freq.shape[0]
expected_probs = ref_outputs[0]["expected_probs"]
vocab_size = expected_probs.shape[-1]

inp = inputs[0]
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}

output_names = list(defn.outputs.keys())
output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}

# Compute valid sampling mask based on thresholding
thresholding_method = _detect_thresholding_method(defn)
probs = inp["probs"]
valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)

# Validate correct sampling token set
for _ in range(cfg.sampling_validation_trials):
try:
Expand Down Expand Up @@ -137,27 +146,34 @@ def check_correctness(
correctness=correctness,
)

# Validate thresholding
thresholding_method = _detect_thresholding_method(defn)
probs = inp["probs"]
if not _check_thresholding(samples, probs, thresholding_method, params):
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
)
print(message, file=sys.stderr)
return correctness, make_eval(
status=EvaluationStatus.INCORRECT_NUMERICAL,
device=device,
log_path=log_path,
correctness=correctness,
)
# Validate thresholding - check samples are within valid mask
if samples.dim() == 0:
samples_flat = samples.unsqueeze(0)
else:
samples_flat = samples.flatten()

batch_size = valid_mask.shape[0]
for i in range(len(samples_flat)):
batch_idx = i % batch_size
sample_idx = samples_flat[i].item()
if not valid_mask[batch_idx, sample_idx]:
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
)
print(message, file=sys.stderr)
Comment on lines 155 to 160
⚠️ Potential issue | 🔴 Critical

Avoid truthiness on 0‑dim Torch bool Tensor.

Using if not valid_mask[...] can raise “Boolean value of Tensor is ambiguous.” Convert to Python bool.

-                if not valid_mask[batch_idx, sample_idx]:
+                if not bool(valid_mask[batch_idx, sample_idx].item()):
📝 Committable suggestion


Suggested change
-                if not valid_mask[batch_idx, sample_idx]:
+                if not bool(valid_mask[batch_idx, sample_idx].item()):
                     correctness = Correctness(
                         max_relative_error=float("inf"), max_absolute_error=float("inf")
                     )
                     message = (
                         f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
                     )
                     print(message, file=sys.stderr)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 159 to 166, the
condition uses truthiness on a 0‑dim Torch bool Tensor (if not
valid_mask[batch_idx, sample_idx]) which can raise “Boolean value of Tensor is
ambiguous.” Convert the tensor to a Python bool by calling .item() (and .cpu()
if it may be on GPU) before negation, and use that boolean in the if check so
the branch evaluates correctly.
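
A tiny sketch of that conversion, with made-up indices and mask purely for illustration:

import torch

valid_mask = torch.tensor([[True, False], [False, True]])
batch_idx, sample_idx = 0, 1

# .cpu() is only needed if the mask may live on a GPU; .item() yields a Python bool.
is_valid = bool(valid_mask[batch_idx, sample_idx].cpu().item())
if not is_valid:
    print("sample falls outside the valid mask")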

return correctness, make_eval(
status=EvaluationStatus.INCORRECT_NUMERICAL,
device=device,
log_path=log_path,
correctness=correctness,
)

try:
sol_freq = _compute_frequency_distribution(
sol_runnable, inp, device, defn, num_trials=50000
sol_freqs = _sample_token_distributions(
sol_runnable, inp, device, defn, num_trials=500000
)
torch.cuda.synchronize(device)
except Exception:
Comment on lines +169 to 173
⚠️ Potential issue | 🟡 Minor

The second torch.cuda.synchronize call has the same issue; apply the same guard.

Mirror the earlier fix after sampling distributions.

-            torch.cuda.synchronize(device)
+            _dev = torch.device(device)
+            if _dev.type == "cuda":
+                torch.cuda.synchronize(_dev)
🧰 Tools
🪛 Ruff (0.14.1)

179-179: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 175 to 179, the
second torch.cuda.synchronize(device) call needs the same protective guard as
the earlier synchronize to avoid raising on non-CUDA or unavailable CUDA
devices; wrap the synchronize in the same conditional/try-except used previously
(e.g., only call if device.type == "cuda" and torch.cuda.is_available(), or call
inside a try: ... except Exception: pass block) so any CUDA sync errors are
handled identically to the earlier fix.
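
A small sketch of that guard, assuming device arrives as a string such as "cuda:0" or "cpu"; the helper name is hypothetical:

import torch

def maybe_synchronize(device: str) -> None:
    # Only synchronize when the target really is a CUDA device and CUDA is usable.
    dev = torch.device(device)
    if dev.type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(dev)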

@@ -166,13 +182,29 @@ def check_correctness(
status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
)

# total variation distance
tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)

numerical_incorrect = tvd > cfg.sampling_tvd_threshold
batch_size = expected_probs.shape[0]
tvds = []
max_abs_errors = []
max_rel_errors = []

for i in range(batch_size):
tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
tvds.append(tvd_i)

max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
max_abs_errors.append(max_abs_i)
max_rel_errors.append(max_rel_i)

# Use the worst (max) TVD and errors across all batch elements
max_tvd = max(tvds)
max_abs = max(max_abs_errors)
max_rel = max(max_rel_errors)

numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
correctness = Correctness(
max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
max_relative_error=max_rel,
max_absolute_error=max_abs,
extra={"tvd": max_tvd, "tvds_per_batch": tvds}
)
Comment on lines 197 to 202
⚠️ Potential issue | 🟠 Major

Protect against missing cfg.sampling_tvd_threshold.

If cfg.sampling_tvd_threshold is None or missing, the comparison raises a TypeError. Either provide a default or fail fast with a clear message.

-        numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+        tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+        if tvd_thresh is None:
+            raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+        numerical_incorrect = max_tvd > tvd_thresh
📝 Committable suggestion


Suggested change
-        numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
+        tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
+        if tvd_thresh is None:
+            raise ValueError("cfg.sampling_tvd_threshold must be set for sampling evaluation")
+        numerical_incorrect = max_tvd > tvd_thresh
         correctness = Correctness(
             max_relative_error=max_rel,
             max_absolute_error=max_abs,
             extra={"tvd": max_tvd, "tvds_per_batch": tvds}
         )
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 203 to 208, the
code compares max_tvd against cfg.sampling_tvd_threshold without handling the
case where that config value is None or missing; add a guard before the
comparison that checks if cfg.sampling_tvd_threshold is None and fail fast with
a clear ValueError (or optionally set a documented default threshold) so the
comparison never raises a TypeError; update the subsequent comparison to use the
validated threshold variable.
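
A sketch of the alternative mentioned above, a documented default instead of failing fast; the helper name and the 0.02 value are assumptions, not taken from the repository:

def resolve_tvd_threshold(cfg, default: float = 0.02) -> float:
    # Fall back to a documented default when the config omits the threshold.
    tvd_thresh = getattr(cfg, "sampling_tvd_threshold", None)
    return default if tvd_thresh is None else float(tvd_thresh)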

if numerical_incorrect:
return correctness, make_eval(
@@ -201,23 +233,125 @@ def _detect_thresholding_method(defn: Definition) -> str:
return "none" # no thresholding


def _compute_frequency_distribution(
def _compute_valid_sampling_mask(
probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 1e-5
) -> torch.Tensor:
"""
For tie-breaking in top_k (allows any token with prob >= k-th largest)
and numerical precision in top_p (allows tokens within eps of nucleus boundary).
"""
if probs.dim() == 1:
probs = probs.unsqueeze(0)

batch_size, vocab_size = probs.shape
device = probs.device

if method == "none":
return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)

mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)

if method in ["top_k", "top_k_top_p"]:
if "top_k" not in params:
raise ValueError(f"top_k parameter required for {method} but not found")

top_k_param = params["top_k"]
for i in range(batch_size):
k = (
int(top_k_param[i].item())
if top_k_param.dim() > 0
else int(top_k_param.item())
)

if 0 < k < vocab_size:
sorted_probs, _ = torch.sort(probs[i], descending=True)
# k-th largest value (0-indexed, so k-1)
pivot = sorted_probs[k - 1]
mask[i] = probs[i] >= pivot # tie-breaking handling

# Apply top_p mask with epsilon tolerance
if method in ["top_p", "top_k_top_p"]:
if "top_p" not in params:
raise ValueError(f"top_p parameter required for {method} but not found")

top_p_param = params["top_p"]
for i in range(batch_size):
p = (
float(top_p_param[i].item())
if top_p_param.dim() > 0
else float(top_p_param.item())
)

if 0 < p < 1:
sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
cumsum = torch.cumsum(sorted_probs, dim=0)

# Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
nucleus_mask = cumsum <= (p + eps)

if not nucleus_mask.any():
nucleus_mask[0] = True

# Map back to original indices
p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
p_mask[sorted_indices[nucleus_mask]] = True

mask[i] = mask[i] & p_mask

return mask


def _sample_token_distributions(
runnable: Runnable,
inputs: Dict[str, Any],
device: str,
defn: Definition,
num_trials: int = 10000,
num_trials: int = 500000,
) -> torch.Tensor:
batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]
counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))

trials_needed = (num_trials + batch_size - 1) // batch_size
total_samples_collected = 0

# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size
Comment on lines 297 to 303
Contributor
critical

There's a potential ZeroDivisionError here. If inputs["probs"] has a shape like (0, vocab_size), original_batch_size will be 0, causing a crash on line 316 when calculating repeat_count.

Additionally, if original_batch_size is larger than target_batch_size, repeat_count will be 0, leading to an actual_batch_size of 0. This will create 0-sized tensors and likely cause issues in the runnable.

I suggest handling the original_batch_size == 0 case explicitly and ensuring repeat_count is at least 1 to prevent these issues.

    original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
    vocab_size = inputs["probs"].shape[-1]

    if original_batch_size == 0:
        return torch.empty((0, vocab_size), dtype=torch.float32, device=torch.device(device))

    # Repeat entire input batch to fill up to target_batch_size for efficient sampling
    target_batch_size = 10000
    repeat_count = max(1, target_batch_size // original_batch_size)
    actual_batch_size = repeat_count * original_batch_size


⚠️ Potential issue | 🔴 Critical

repeat_count can be 0 when original_batch_size > 10_000, which makes actual_batch_size 0 and leads to division by zero (when computing trials_needed) and empty batches.

Ensure at least one repeat; prefer ceil to keep high utilization.

-    target_batch_size = 10000
-    repeat_count = target_batch_size // original_batch_size
-    actual_batch_size = repeat_count * original_batch_size
+    import math
+    target_batch_size = 10000
+    repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
+    actual_batch_size = repeat_count * original_batch_size

Also consider making target_batch_size configurable.

📝 Committable suggestion


Suggested change
-    target_batch_size = 10000
-    repeat_count = target_batch_size // original_batch_size
-    actual_batch_size = repeat_count * original_batch_size
+    import math
+    target_batch_size = 10000
+    repeat_count = max(1, math.ceil(target_batch_size / original_batch_size))
+    actual_batch_size = repeat_count * original_batch_size
🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 315-318, the
current computation uses integer division target_batch_size //
original_batch_size which can yield 0 when original_batch_size > 10000 (leading
to empty batches) and underutilization; change to use a ceiling division and
ensure at least one repeat: compute repeat_count = max(1, ceil(target_batch_size
/ original_batch_size)) (or equivalent integer math), then set actual_batch_size
= repeat_count * original_batch_size, and expose target_batch_size as a
configurable parameter (with validation to be a positive int) so it can be tuned
instead of hardcoding 10000.
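
A sketch of the configurable variant described above; the function name, parameter names, and validation messages are assumptions for illustration:

import math

def compute_repeat_count(original_batch_size: int, target_batch_size: int = 10_000) -> int:
    if target_batch_size <= 0:
        raise ValueError("target_batch_size must be a positive int")
    if original_batch_size <= 0:
        raise ValueError("original_batch_size must be a positive int")
    # Ceiling division, never fewer than one repeat of the original batch.
    return max(1, math.ceil(target_batch_size / original_batch_size))

actual_batch_size would then be compute_repeat_count(original_batch_size) * original_batch_size.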

padded_inputs = {}
for key, value in inputs.items():
if isinstance(value, torch.Tensor) and value.dim() > 0:
if key == "probs":
# For probs, repeat the entire batch
if value.dim() == 1:
value = value.unsqueeze(0)
# Repeat the entire batch repeat_count times
padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
elif key in ["top_k", "top_p"]:
# For sampling parameters, repeat the entire batch
if value.dim() == 0:
padded_value = value.unsqueeze(0).repeat(actual_batch_size)
else:
padded_value = value.repeat(repeat_count)
else:
# For other tensors, repeat entire batch along batch dimension
if value.dim() == 0:
padded_value = value.unsqueeze(0).repeat(actual_batch_size)
else:
padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
padded_inputs[key] = padded_value
else:
# For non-tensor inputs, keep as is
padded_inputs[key] = value

counters = torch.zeros(
(original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
)

trials_needed = (num_trials + actual_batch_size - 1) // actual_batch_size
total_samples_per_batch = 0

for _ in range(trials_needed):
with torch.no_grad():
out = runnable(**inputs)
out = runnable(**padded_inputs)

output_names = list(defn.outputs.keys())
output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
@@ -229,118 +363,19 @@ def _compute_frequency_distribution(
samples = out_normalized["samples"]

if samples.dim() == 0:
# Single sample - assign to first batch element
sample_idx = samples.item()
counter[sample_idx] += 1
total_samples_collected += 1
else: # Batch of samples
for i in range(samples.numel()):
sample_idx = samples.flatten()[i].item()
counter[sample_idx] += 1
total_samples_collected += 1

frequency = counter.float() / total_samples_collected
return frequency


def _check_thresholding(
samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
) -> bool:
"""Check if samples conform to the specified thresholding method.

Parameters
----------
samples : torch.Tensor
Sampled token indices.
probs : torch.Tensor
Probability distribution used for sampling.
method : str
Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
params : Dict[str, Any]
Sampling parameters (top_k, top_p values).

Returns
-------
bool
True if samples are valid, False otherwise.
"""
batch_size, vocab_size = probs.shape
device = probs.device

for i in range(batch_size):
prob_row = probs[i]
sample = samples[i].item()

if method == "top_k":
if "top_k" not in params:
raise ValueError("top_k parameter is required for top_k thresholding but not found")
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
if mask_top_k[sample] != 1:
return False

elif method == "top_p":
if "top_p" not in params:
raise ValueError("top_p parameter is required for top_p thresholding but not found")
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < p < 1:
eps = 1e-4 # numerical stability
sorted_probs, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs, dim=0)
valid_mask = cdf > (1 - p) - eps
valid_indices = indices[valid_mask]

if sample not in valid_indices:
return False

elif method == "top_k_top_p":
if "top_k" not in params or "top_p" not in params:
raise ValueError(
"top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
)
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
else:
mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)

if 0 < p < 1:
eps = 1e-4
sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs_asc, dim=0)
mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
valid_p_mask = cdf > (1 - p) - eps
mask_top_p[indices[valid_p_mask]] = 1
else:
mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)

joint_mask = torch.minimum(mask_top_k, mask_top_p)

if joint_mask[sample] != 1:
return False

return True
counters[0, sample_idx] += 1
total_samples_per_batch += 1
Comment on lines 231 to +355
Contributor
high

When samples.dim() == 0, the code assumes a batch size of 1 and assigns the sample to the first batch element's counter. This is incorrect if original_batch_size > 1, as it would misattribute samples and lead to an incorrect frequency distribution for all batch items. The runnable should be expected to return a batch of samples matching actual_batch_size. If it returns a scalar when a batch is expected, it's a contract violation that should be flagged with an error.

        if samples.dim() == 0:
            if actual_batch_size != 1:
                raise ValueError(
                    f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
                )
            # Single sample - assign to first batch element
            sample_idx = samples.item()
            counters[0, sample_idx] += 1
            total_samples_per_batch += 1

else:
# slice and accumulate per original batch element
samples_flat = samples.flatten()
for i in range(samples_flat.numel()):
batch_idx = i % original_batch_size
sample_idx = samples_flat[i].item()
counters[batch_idx, sample_idx] += 1
total_samples_per_batch += repeat_count

# [batch_size, vocab_size]
frequencies = counters.float() / total_samples_per_batch
return frequencies