331 changes: 176 additions & 155 deletions flashinfer_bench/bench/evaluators/sampling.py
@@ -53,16 +53,16 @@ def build_baseline(
outputs: List[Dict[str, torch.Tensor]] = []

inp = gen_inputs(defn, workload, device=device, stensors=loaded_stensors)
if "probs" in inp:
inp["probs"] = torch.softmax(
inp["probs"], dim=-1
) # convert logits to probs for sampling
inputs.append(inp)

freq_dist = _compute_frequency_distribution(
ref_runnable, inp, device, defn, num_trials=50000
)
outputs.append({"frequency_distribution": freq_dist})
thresholding_method = _detect_thresholding_method(defn)
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}
valid_mask = _compute_valid_sampling_mask(inp["probs"], thresholding_method, params)

masked_probs = inp["probs"] * valid_mask.float()
expected_probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True)

outputs.append({"expected_probs": expected_probs})

latencies: List[float] = []
for inp in inputs:
@@ -94,15 +94,20 @@ def check_correctness(
log_path: str,
device: str,
) -> Tuple[Optional[Correctness], Optional[Evaluation]]:
ref_freq = ref_outputs[0]["frequency_distribution"]
vocab_size = ref_freq.shape[0]
expected_probs = ref_outputs[0]["expected_probs"]
vocab_size = expected_probs.shape[-1]

inp = inputs[0]
params = {k: inp[k] for k in ["top_k", "top_p"] if k in inp}

output_names = list(defn.outputs.keys())
output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}

# Compute valid sampling mask based on thresholding
thresholding_method = _detect_thresholding_method(defn)
probs = inp["probs"]
valid_mask = _compute_valid_sampling_mask(probs, thresholding_method, params)

# Validate correct sampling token set
for _ in range(cfg.sampling_validation_trials):
try:
Expand Down Expand Up @@ -137,27 +142,32 @@ def check_correctness(
correctness=correctness,
)

# Validate thresholding
thresholding_method = _detect_thresholding_method(defn)
probs = inp["probs"]
if not _check_thresholding(samples, probs, thresholding_method, params):
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = (
f"Samples {samples.tolist()} does not meet {thresholding_method} thresholding"
)
print(message, file=sys.stderr)
return correctness, make_eval(
status=EvaluationStatus.INCORRECT_NUMERICAL,
device=device,
log_path=log_path,
correctness=correctness,
)
# Validate thresholding: check that each sample falls within the valid mask
if samples.dim() == 0:
samples_flat = samples.unsqueeze(0)
else:
samples_flat = samples.flatten()

batch_size = valid_mask.shape[0]
for i in range(len(samples_flat)):
batch_idx = i % batch_size
sample_idx = samples_flat[i].item()
if not valid_mask[batch_idx, sample_idx]:
correctness = Correctness(
max_relative_error=float("inf"), max_absolute_error=float("inf")
)
message = f"Sample {sample_idx} is outside valid {thresholding_method} mask for batch {batch_idx}"
print(message, file=sys.stderr)
return correctness, make_eval(
status=EvaluationStatus.INCORRECT_NUMERICAL,
device=device,
log_path=log_path,
correctness=correctness,
)

try:
sol_freq = _compute_frequency_distribution(
sol_runnable, inp, device, defn, num_trials=50000
sol_freqs = _sample_token_distributions(
sol_runnable, inp, device, defn, num_trials=500000
)
torch.cuda.synchronize(device)
except Exception:
Comment on lines +169 to 173

⚠️ Potential issue | 🟡 Minor

The second CUDA synchronize has the same issue; apply the same guard.

Mirror the earlier fix after sampling distributions.

-            torch.cuda.synchronize(device)
+            _dev = torch.device(device)
+            if _dev.type == "cuda":
+                torch.cuda.synchronize(_dev)
🧰 Tools
🪛 Ruff (0.14.1)

179-179: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In flashinfer_bench/bench/evaluators/sampling.py around lines 175 to 179, the
second torch.cuda.synchronize(device) call needs the same protective guard as
the earlier synchronize to avoid raising on non-CUDA or unavailable CUDA
devices; wrap the synchronize in the same conditional/try-except used previously
(e.g., only call if device.type == "cuda" and torch.cuda.is_available(), or call
inside a try: ... except Exception: pass block) so any CUDA sync errors are
handled identically to the earlier fix.
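
A minimal sketch of the guard the comment describes, factored into a helper; the name _sync_if_cuda is illustrative and not part of this PR:

import torch

def _sync_if_cuda(device: str) -> None:
    # Only synchronize when the target is a CUDA device and CUDA is usable,
    # so the call cannot raise on CPU-only or CUDA-less runs.
    dev = torch.device(device)
    if dev.type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(dev)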

@@ -166,13 +176,29 @@ def check_correctness(
status=EvaluationStatus.RUNTIME_ERROR, device=device, log_path=log_path
)

# total variation distance
tvd = 0.5 * torch.sum(torch.abs(sol_freq - ref_freq)).item()
max_abs, max_rel, _, _ = compute_error_stats(sol_freq, ref_freq, cfg)
batch_size = expected_probs.shape[0]
tvds = []
max_abs_errors = []
max_rel_errors = []

for i in range(batch_size):
tvd_i = 0.5 * torch.sum(torch.abs(sol_freqs[i] - expected_probs[i])).item()
tvds.append(tvd_i)

max_abs_i, max_rel_i, _, _ = compute_error_stats(sol_freqs[i], expected_probs[i], cfg)
max_abs_errors.append(max_abs_i)
max_rel_errors.append(max_rel_i)

numerical_incorrect = tvd > cfg.sampling_tvd_threshold
# Use the worst (max) TVD and errors across all batch elements
max_tvd = max(tvds)
max_abs = max(max_abs_errors)
max_rel = max(max_rel_errors)

numerical_incorrect = max_tvd > cfg.sampling_tvd_threshold
correctness = Correctness(
max_relative_error=max_rel, max_absolute_error=max_abs, extra={"tvd": tvd}
max_relative_error=max_rel,
max_absolute_error=max_abs,
extra={"tvd": max_tvd, "tvds_per_batch": tvds},
)
if numerical_incorrect:
return correctness, make_eval(
@@ -201,23 +227,117 @@ def _detect_thresholding_method(defn: Definition) -> str:
return "none" # no thresholding


def _compute_frequency_distribution(
def _compute_valid_sampling_mask(
probs: torch.Tensor, method: str, params: Dict[str, Any], eps: float = 5e-2
) -> torch.Tensor:
"""
For tie-breaking in top_k (allows any token with prob >= k-th largest)
and numerical precision in top_p (allows tokens within eps of nucleus boundary).
"""
if probs.dim() == 1:
probs = probs.unsqueeze(0)

batch_size, vocab_size = probs.shape
device = probs.device

if method == "none":
return torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)

mask = torch.ones((batch_size, vocab_size), dtype=torch.bool, device=device)

if method in ["top_k", "top_k_top_p"]:
if "top_k" not in params:
raise ValueError(f"top_k parameter required for {method} but not found")

top_k_param = params["top_k"]
for i in range(batch_size):
k = int(top_k_param[i].item()) if top_k_param.dim() > 0 else int(top_k_param.item())

if 0 < k < vocab_size:
sorted_probs, _ = torch.sort(probs[i], descending=True)
# k-th largest value (0-indexed, so k-1)
pivot = sorted_probs[k - 1]
mask[i] = probs[i] >= pivot # tie-breaking handling

# Apply top_p mask with epsilon tolerance
if method in ["top_p", "top_k_top_p"]:
if "top_p" not in params:
raise ValueError(f"top_p parameter required for {method} but not found")

top_p_param = params["top_p"]
for i in range(batch_size):
p = float(top_p_param[i].item()) if top_p_param.dim() > 0 else float(top_p_param.item())

if 0 < p < 1:
sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
cumsum = torch.cumsum(sorted_probs, dim=0)

# Find tokens in nucleus (cumsum <= p + eps for numerical tolerance)
nucleus_mask = cumsum <= (p + eps)

if not nucleus_mask.any():
nucleus_mask[0] = True

# Map back to original indices
p_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
p_mask[sorted_indices[nucleus_mask]] = True

mask[i] = mask[i] & p_mask

return mask


def _sample_token_distributions(
runnable: Runnable,
inputs: Dict[str, Any],
device: str,
defn: Definition,
num_trials: int = 10000,
num_trials: int = 500000,
) -> torch.Tensor:
batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
original_batch_size = inputs["probs"].shape[0] if inputs["probs"].dim() > 1 else 1
vocab_size = inputs["probs"].shape[-1]
counter = torch.zeros(vocab_size, dtype=torch.int64, device=torch.device(device))

trials_needed = (num_trials + batch_size - 1) // batch_size
total_samples_collected = 0
# Repeat entire input batch to fill up to target_batch_size for efficient sampling
target_batch_size = 10000
repeat_count = target_batch_size // original_batch_size
actual_batch_size = repeat_count * original_batch_size

padded_inputs = {}
for key, value in inputs.items():
if isinstance(value, torch.Tensor) and value.dim() > 0:
if key == "probs":
# For probs, repeat the entire batch
if value.dim() == 1:
value = value.unsqueeze(0)
# Repeat the entire batch repeat_count times
padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
elif key in ["top_k", "top_p"]:
# For sampling parameters, repeat the entire batch
if value.dim() == 0:
padded_value = value.unsqueeze(0).repeat(actual_batch_size)
else:
padded_value = value.repeat(repeat_count)
else:
# For other tensors, repeat entire batch along batch dimension
if value.dim() == 0:
padded_value = value.unsqueeze(0).repeat(actual_batch_size)
else:
padded_value = value.repeat(repeat_count, *([1] * (value.dim() - 1)))
padded_inputs[key] = padded_value
else:
# For non-tensor inputs, keep as is
padded_inputs[key] = value

counters = torch.zeros(
(original_batch_size, vocab_size), dtype=torch.int64, device=torch.device(device)
)

trials_needed = (num_trials + repeat_count - 1) // repeat_count
total_samples_per_batch = 0

for _ in range(trials_needed):
with torch.no_grad():
out = runnable(**inputs)
out = runnable(**padded_inputs)

output_names = list(defn.outputs.keys())
output_dtypes = {k: dtype_str_to_torch_dtype(v.dtype) for k, v in defn.outputs.items()}
@@ -229,118 +349,19 @@ def _compute_frequency_distribution(
samples = out_normalized["samples"]

if samples.dim() == 0:
# Single sample - assign to first batch element
sample_idx = samples.item()
counter[sample_idx] += 1
total_samples_collected += 1
else: # Batch of samples
for i in range(samples.numel()):
sample_idx = samples.flatten()[i].item()
counter[sample_idx] += 1
total_samples_collected += 1

frequency = counter.float() / total_samples_collected
return frequency


def _check_thresholding(
samples: torch.Tensor, probs: torch.Tensor, method: str, params: Dict[str, Any]
) -> bool:
"""Check if samples conform to the specified thresholding method.

Parameters
----------
samples : torch.Tensor
Sampled token indices.
probs : torch.Tensor
Probability distribution used for sampling.
method : str
Thresholding method: "top_k", "top_p", "top_k_top_p", or "none".
params : Dict[str, Any]
Sampling parameters (top_k, top_p values).

Returns
-------
bool
True if samples are valid, False otherwise.
"""
batch_size, vocab_size = probs.shape
device = probs.device

for i in range(batch_size):
prob_row = probs[i]
sample = samples[i].item()

if method == "top_k":
if "top_k" not in params:
raise ValueError("top_k parameter is required for top_k thresholding but not found")
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
if mask_top_k[sample] != 1:
return False

elif method == "top_p":
if "top_p" not in params:
raise ValueError("top_p parameter is required for top_p thresholding but not found")
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < p < 1:
eps = 1e-4 # numerical stability
sorted_probs, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs, dim=0)
valid_mask = cdf > (1 - p) - eps
valid_indices = indices[valid_mask]

if sample not in valid_indices:
return False

elif method == "top_k_top_p":
if "top_k" not in params or "top_p" not in params:
raise ValueError(
"top_k and top_p parameters are both required for top_k_top_p thresholding but not found"
)
k = (
int(params["top_k"][i].item())
if params["top_k"].dim() > 0
else int(params["top_k"].item())
)
p = (
float(params["top_p"][i].item())
if params["top_p"].dim() > 0
else float(params["top_p"].item())
)

if 0 < k < vocab_size:
sorted_prob_desc, _ = torch.sort(prob_row, descending=True)
pivot = sorted_prob_desc[k - 1]
mask_top_k = (prob_row >= pivot).int()
else:
mask_top_k = torch.ones(vocab_size, dtype=torch.int32, device=device)

if 0 < p < 1:
eps = 1e-4
sorted_probs_asc, indices = torch.sort(prob_row, descending=False)
cdf = torch.cumsum(sorted_probs_asc, dim=0)
mask_top_p = torch.zeros(vocab_size, dtype=torch.int32, device=device)
valid_p_mask = cdf > (1 - p) - eps
mask_top_p[indices[valid_p_mask]] = 1
else:
mask_top_p = torch.ones(vocab_size, dtype=torch.int32, device=device)

joint_mask = torch.minimum(mask_top_k, mask_top_p)

if joint_mask[sample] != 1:
return False

return True
counters[0, sample_idx] += 1
total_samples_per_batch += 1
Comment on lines 231 to +355
Contributor

high

When samples.dim() == 0, the code assumes a batch size of 1 and assigns the sample to the first batch element's counter. This is incorrect if original_batch_size > 1, as it would misattribute samples and lead to an incorrect frequency distribution for all batch items. The runnable should be expected to return a batch of samples matching actual_batch_size. If it returns a scalar when a batch is expected, it's a contract violation that should be flagged with an error.

        if samples.dim() == 0:
            if actual_batch_size != 1:
                raise ValueError(
                    f"Expected a batch of samples (size {actual_batch_size}), but got a scalar."
                )
            # Single sample - assign to first batch element
            sample_idx = samples.item()
            counters[0, sample_idx] += 1
            total_samples_per_batch += 1

else:
# Map each sample back to its original batch element and accumulate
samples_flat = samples.flatten()
for i in range(samples_flat.numel()):
batch_idx = i % original_batch_size
sample_idx = samples_flat[i].item()
counters[batch_idx, sample_idx] += 1
total_samples_per_batch += repeat_count

# [batch_size, vocab_size]
frequencies = counters.float() / total_samples_per_batch
return frequencies
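
For context, a minimal sketch of how the pieces introduced in this diff compose: the valid mask, the renormalized expected distribution, and the per-row TVD check. The tensors below are made up, and _compute_valid_sampling_mask is assumed to be importable as defined above:

import torch

# Toy batch: 2 rows over a vocab of 4 tokens.
probs = torch.softmax(torch.randn(2, 4), dim=-1)
params = {"top_k": torch.tensor([2, 3])}

# Boolean mask of tokens a correct top-k sampler may emit (tie-aware).
mask = _compute_valid_sampling_mask(probs, "top_k", params)

# Renormalize over the valid set: this is the expected sampling distribution.
masked = probs * mask.float()
expected = masked / masked.sum(dim=-1, keepdim=True)

# Per-row total variation distance between observed frequencies and expected;
# the evaluator compares the worst row against cfg.sampling_tvd_threshold.
observed = expected.clone()  # stand-in for _sample_token_distributions(...)
tvd = 0.5 * torch.sum(torch.abs(observed - expected), dim=-1)
print(tvd)  # ~0 for this stand-in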