diff --git a/torchtitan/experiments/kernels/moe_sorting/benchmark.py b/torchtitan/experiments/kernels/moe_sorting/benchmark.py new file mode 100644 index 000000000..1fbe31017 --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/benchmark.py @@ -0,0 +1,519 @@ +import argparse +import os +import time +from typing import Any, Dict, List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import torch +from tabulate import tabulate + +try: + import token_sorting_cuda +except ImportError: + print(f"Unable to import token_sorting extension") + raise + + +def pytorch_sort_tokens(topk_ids, x, n_experts): + """Original PyTorch implementation for comparison""" + with torch.no_grad(): + # [seq_len, n_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], n_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + + return sorted_tokens, idxs, tokens_per_expert + + +def cuda_sort_tokens(topk_ids, x, n_experts): + """CUDA optimized implementation""" + sorted_tokens, sorted_indices, tokens_per_expert = ( + token_sorting_cuda.sort_tokens_by_expert(topk_ids, x, n_experts) + ) + return sorted_tokens, sorted_indices, tokens_per_expert + + +def verify_implementations( + seq_len: int, hidden_dim: int, n_experts: int, k: int +) -> bool: + """Verify that PyTorch and CUDA implementations produce identical results""" + print( + f"\nVerifying implementations for {n_experts} experts, {hidden_dim} features:" + ) + + # Create random input data + torch.manual_seed(2020) + device = torch.device("cuda") + + # Generate expert IDs, ensuring they're valid indices + topk_ids = torch.randint( + 0, n_experts, (seq_len, k), device=device, dtype=torch.int64 + ) + x = torch.randn(seq_len, hidden_dim, device=device) + + # Run implementations + pt_sorted, pt_indices, pt_counts = pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_sorted, cuda_indices, cuda_counts = cuda_sort_tokens(topk_ids, x, n_experts) + + # Verify tokens per expert counts + counts_match = torch.allclose(pt_counts, cuda_counts) + + # Verify sorted tokens + tokens_match = torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + + # Verify indices match + indices_match = torch.equal(pt_indices, cuda_indices) + + # Print results + print(f" Token counts match: {counts_match}") + print(f" Sorted tokens match: {tokens_match}") + print(f" Indices match: {indices_match}") + + overall_match = counts_match and tokens_match and indices_match + + if not overall_match: + print("\nDetailed diagnostics:") + if not counts_match: + diff = torch.abs(pt_counts - cuda_counts) + max_diff = torch.max(diff).item() + print(f" Max count difference: {max_diff}") + print(f" First few PyTorch counts: {pt_counts[:5]}") + print(f" First few CUDA counts: {cuda_counts[:5]}") + + if not tokens_match: + diff = torch.abs(pt_sorted - cuda_sorted) + max_diff = torch.max(diff).item() + max_diff_idx = torch.argmax(diff.view(-1)).item() + print(f" Max token difference: {max_diff} at index {max_diff_idx}") + + if not indices_match: + print(f" First few PyTorch indices: {pt_indices[:5]}") + print(f" First few CUDA indices: {cuda_indices[:5]}") + + return overall_match + + +def benchmark_implementations( + seq_len: int, hidden_dim: int, n_experts: int, k: int, num_runs: int = 10 +) -> Dict[str, Any]: + """Benchmark PyTorch vs CUDA implementations""" + # Create random input data + 
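+    # Inputs are generated with a fixed seed so every run and configuration sees
+    # the same data; timings below use CUDA events with a synchronize before
+    # reading elapsed_time (reported in milliseconds), after a short warmup.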
torch.manual_seed(42) + device = torch.device("cuda") + + topk_ids = torch.randint( + 0, n_experts, (seq_len, k), device=device, dtype=torch.int64 + ) + x = torch.randn(seq_len, hidden_dim, device=device) + + # Warmup + for _ in range(3): + pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_sort_tokens(topk_ids, x, n_experts) + + # Benchmark PyTorch + torch.cuda.synchronize() + pt_times = [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + pytorch_sort_tokens(topk_ids, x, n_experts) + end_event.record() + + torch.cuda.synchronize() + pt_times.append(start_event.elapsed_time(end_event)) + + pt_avg_time = sum(pt_times) / len(pt_times) + pt_std_time = torch.tensor(pt_times).std().item() + + # Benchmark CUDA + torch.cuda.synchronize() + cuda_times = [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + cuda_sort_tokens(topk_ids, x, n_experts) + end_event.record() + + torch.cuda.synchronize() + cuda_times.append(start_event.elapsed_time(end_event)) + + cuda_avg_time = sum(cuda_times) / len(cuda_times) + cuda_std_time = torch.tensor(cuda_times).std().item() + + # Calculate speedup + speedup = pt_avg_time / cuda_avg_time + + results = { + "seq_len": seq_len, + "hidden_dim": hidden_dim, + "n_experts": n_experts, + "k": k, + "pytorch_time": pt_avg_time, + "pytorch_std": pt_std_time, + "cuda_time": cuda_avg_time, + "cuda_std": cuda_std_time, + "speedup": speedup, + } + + print(f" PyTorch: {pt_avg_time:.3f} ± {pt_std_time:.3f} ms") + print(f" CUDA: {cuda_avg_time:.3f} ± {cuda_std_time:.3f} ms") + print(f" Speedup: {speedup:.2f}x") + + return results + + +def run_benchmarks( + expert_counts: List[int], + hidden_dims: List[int], + seq_lens: List[int], + k_values: List[int], + num_runs: int = 10, + verify: bool = True, +) -> pd.DataFrame: + """Run benchmarks for various configurations""" + all_results = [] + + # Ensure output directory exists + os.makedirs("benchmark_results", exist_ok=True) + + for seq_len in seq_lens: + for n_experts in expert_counts: + for hidden_dim in hidden_dims: + for k in k_values: + print(f"\n{'-'*80}") + print( + f"Benchmarking: seq_len={seq_len}, hidden_dim={hidden_dim}, experts={n_experts}, k={k}" + ) + + if verify: + verification_passed = verify_implementations( + seq_len, hidden_dim, n_experts, k + ) + if not verification_passed: + print( + f"WARNING: Verification failed for this configuration!" 
+ ) + + results = benchmark_implementations( + seq_len, hidden_dim, n_experts, k, num_runs + ) + all_results.append(results) + + # Save incremental results to avoid losing data if something crashes + pd.DataFrame(all_results).to_csv( + "benchmark_results/incremental_results.csv", index=False + ) + + # Convert to DataFrame + results_df = pd.DataFrame(all_results) + + # Save results + results_df.to_csv("benchmark_results/benchmark_results.csv", index=False) + + return results_df + + +def create_plots(results_df: pd.DataFrame): + """Create various plots from benchmark results""" + # Create experts vs time plots for each hidden_dim + for hidden_dim in results_df["hidden_dim"].unique(): + for k in results_df["k"].unique(): + for seq_len in results_df["seq_len"].unique(): + # Filter data + df_filtered = results_df[ + (results_df["hidden_dim"] == hidden_dim) + & (results_df["k"] == k) + & (results_df["seq_len"] == seq_len) + ] + + if df_filtered.empty: + continue + + plt.figure(figsize=(10, 6)) + plt.errorbar( + df_filtered["n_experts"], + df_filtered["pytorch_time"], + yerr=df_filtered["pytorch_std"], + marker="o", + label="PyTorch", + ) + plt.errorbar( + df_filtered["n_experts"], + df_filtered["cuda_time"], + yerr=df_filtered["cuda_std"], + marker="s", + label="CUDA", + ) + + plt.xscale("log", base=2) + plt.yscale("log") + plt.xlabel("Number of Experts") + plt.ylabel("Time (ms)") + plt.title( + f"Execution Time vs. Experts (hidden_dim={hidden_dim}, k={k}, seq_len={seq_len})" + ) + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.legend() + plt.tight_layout() + + plt.savefig( + f"benchmark_results/time_vs_experts_h{hidden_dim}_k{k}_seq{seq_len}.png", + dpi=300, + ) + plt.close() + + # Create speedup plot + plt.figure(figsize=(10, 6)) + plt.plot(df_filtered["n_experts"], df_filtered["speedup"], marker="o") + plt.axhline(y=1.0, color="r", linestyle="--", alpha=0.5) + plt.xscale("log", base=2) + plt.xlabel("Number of Experts") + plt.ylabel("Speedup (PyTorch / CUDA)") + plt.title( + f"Speedup vs. Experts (hidden_dim={hidden_dim}, k={k}, seq_len={seq_len})" + ) + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.tight_layout() + + plt.savefig( + f"benchmark_results/speedup_vs_experts_h{hidden_dim}_k{k}_seq{seq_len}.png", + dpi=300, + ) + plt.close() + + # Create hidden_dim vs time plots for each expert count + for n_experts in results_df["n_experts"].unique(): + for k in results_df["k"].unique(): + for seq_len in results_df["seq_len"].unique(): + # Filter data + df_filtered = results_df[ + (results_df["n_experts"] == n_experts) + & (results_df["k"] == k) + & (results_df["seq_len"] == seq_len) + ] + + if df_filtered.empty: + continue + + plt.figure(figsize=(10, 6)) + plt.errorbar( + df_filtered["hidden_dim"], + df_filtered["pytorch_time"], + yerr=df_filtered["pytorch_std"], + marker="o", + label="PyTorch", + ) + plt.errorbar( + df_filtered["hidden_dim"], + df_filtered["cuda_time"], + yerr=df_filtered["cuda_std"], + marker="s", + label="CUDA", + ) + + plt.xscale("log", base=2) + plt.yscale("log") + plt.xlabel("Hidden Dimension") + plt.ylabel("Time (ms)") + plt.title( + f"Execution Time vs. 
Hidden Dim (experts={n_experts}, k={k}, seq_len={seq_len})" + ) + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.legend() + plt.tight_layout() + + plt.savefig( + f"benchmark_results/time_vs_hidden_e{n_experts}_k{k}_seq{seq_len}.png", + dpi=300, + ) + plt.close() + + # Create a summary heatmap of speedups + plt.figure(figsize=(12, 8)) + + # For simplicity, fix k=2 and seq_len=2048 for the heatmap + if 2 in results_df["k"].values and 2048 in results_df["seq_len"].values: + df_heatmap = results_df[ + (results_df["k"] == 1) & (results_df["seq_len"] == 4096) + ] + + # Create pivot table for heatmap + heatmap_data = df_heatmap.pivot( + index="n_experts", columns="hidden_dim", values="speedup" + ) + + # Plot heatmap + plt.imshow(heatmap_data, cmap="viridis", aspect="auto", interpolation="nearest") + plt.colorbar(label="Speedup (PyTorch / CUDA)") + + # Set labels + plt.xlabel("Hidden Dimension") + plt.ylabel("Number of Experts") + plt.title("Speedup Heatmap (k=1, seq_len=4096)") + + # Set ticks + plt.xticks(range(len(heatmap_data.columns)), heatmap_data.columns) + plt.yticks(range(len(heatmap_data.index)), heatmap_data.index) + + plt.tight_layout() + plt.savefig("benchmark_results/speedup_heatmap.png", dpi=300) + plt.close() + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark and verify token sorting implementations" + ) + parser.add_argument( + "--seq-lens", + type=str, + default="2048,4096,8192", + help="Comma-separated list of sequence lengths to test", + ) + parser.add_argument( + "--hidden-dims", + type=str, + default="1024,4096,8192", + help="Comma-separated list of hidden dimensions to test", + ) + parser.add_argument( + "--expert-counts", + type=str, + default="16,64,128,256,512", + help="Comma-separated list of expert counts to test", + ) + parser.add_argument( + "--k-values", + type=str, + default="1,2,6", + help="Comma-separated list of k values (experts per token) to test", + ) + parser.add_argument( + "--runs", type=int, default=10, help="Number of runs for each benchmark" + ) + parser.add_argument( + "--skip-verify", action="store_true", help="Skip verification step" + ) + parser.add_argument( + "--quick", + action="store_true", + help="Run a quick benchmark with reduced parameter sets", + ) + args = parser.parse_args() + + # Parse arguments + if args.quick: + # Use a reduced set of parameters for quick testing + seq_lens = [2048] + hidden_dims = [1024] + expert_counts = [16, 128, 512] + k_values = [2] + else: + seq_lens = [int(x) for x in args.seq_lens.split(",")] + hidden_dims = [int(x) for x in args.hidden_dims.split(",")] + expert_counts = [int(x) for x in args.expert_counts.split(",")] + k_values = [int(x) for x in args.k_values.split(",")] + + print("=" * 80) + print("Token Sorting Benchmark") + print("=" * 80) + print(f"Sequence Lengths: {seq_lens}") + print(f"Hidden Dimensions: {hidden_dims}") + print(f"Expert Counts: {expert_counts}") + print(f"K Values: {k_values}") + print(f"Runs per test: {args.runs}") + print(f"Skip verification: {args.skip_verify}") + print("=" * 80) + + # Check CUDA availability + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available. 
This benchmark requires a GPU.") + return 1 + + # Print CUDA device info + device = torch.cuda.current_device() + print(f"Using GPU: {torch.cuda.get_device_name(device)}") + print(f"CUDA Capability: {torch.cuda.get_device_capability(device)}") + print( + f"CUDA Memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB" + ) + print("=" * 80) + + # Run benchmarks + results_df = run_benchmarks( + expert_counts=expert_counts, + hidden_dims=hidden_dims, + seq_lens=seq_lens, + k_values=k_values, + num_runs=args.runs, + verify=not args.skip_verify, + ) + + # Create plots + create_plots(results_df) + + # Print summary table + print("\n" + "=" * 100) + print("Summary Results (k=2):") + + # For simplicity, show only k=2 in the summary table + if 2 in results_df["k"].values: + summary_df = results_df[results_df["k"] == 2] + + # Create a pivot table for better readability + for seq_len in seq_lens: + if seq_len in summary_df["seq_len"].values: + seq_df = summary_df[summary_df["seq_len"] == seq_len] + + print(f"\nSequence Length: {seq_len}") + + summary_data = [] + for n_experts in expert_counts: + if n_experts not in seq_df["n_experts"].values: + continue + + row = [n_experts] + for hidden_dim in hidden_dims: + if hidden_dim not in seq_df["hidden_dim"].values: + continue + + # Get the speedup for this configuration + speedup = seq_df[ + (seq_df["n_experts"] == n_experts) + & (seq_df["hidden_dim"] == hidden_dim) + ]["speedup"].values + + if len(speedup) > 0: + row.append(f"{speedup[0]:.2f}x") + else: + row.append("N/A") + + summary_data.append(row) + + headers = ["Experts"] + [ + f"Hidden={dim}" + for dim in hidden_dims + if dim in seq_df["hidden_dim"].values + ] + print(tabulate(summary_data, headers=headers, tablefmt="grid")) + + print("\nBenchmark complete! 
Results saved to benchmark_results/ directory.") + return 0 + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/kernels/moe_sorting/debug_sorting.py b/torchtitan/experiments/kernels/moe_sorting/debug_sorting.py new file mode 100644 index 000000000..7f8651131 --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/debug_sorting.py @@ -0,0 +1,429 @@ +import argparse +import time + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import torch +from tabulate import tabulate + +try: + import token_sorting_cuda +except ImportError: + print(f"unable to import token_sorting_cuda extension...") + raise + +# temp verify + +print(f"Main Function signature: {token_sorting_cuda.sort_tokens_by_expert.__doc__}") + +import argparse +import time + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import token_sorting_cuda +import torch +from tabulate import tabulate + + +def pytorch_sort_tokens(topk_ids, x, n_experts): + """Original PyTorch implementation for comparison""" + with torch.no_grad(): + # [seq_len, n_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], n_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens_shape = idxs.shape + x.shape[1:] + sorted_tokens = x[idxs // topk_ids.shape[1]] + + return sorted_tokens, idxs, tokens_per_expert + + +def cuda_sort_tokens(topk_ids, x, n_experts, use_parallel_scan=False): + """CUDA optimized implementation""" + # Ensure tensor types are compatible with CUDA implementation + # The CUDA implementation expects int32 tensors internally but handles int64 conversion + sorted_tokens, sorted_indices, tokens_per_expert = ( + token_sorting_cuda.sort_tokens_by_expert( + topk_ids, x, n_experts, use_parallel_scan + ) + ) + return sorted_tokens, sorted_indices, tokens_per_expert + + +def verify_implementations(seq_len, hidden_dim, n_experts, k=1): + """Verify that PyTorch and CUDA implementations produce identical results""" + print(f"\nVerifying implementations for {n_experts} experts:") + + # Create random input data + torch.manual_seed(42) + device = torch.device("cuda") + + # Generate expert IDs, ensuring they're valid indices + topk_ids = torch.randint( + 0, n_experts, (seq_len, k), device=device, dtype=torch.int64 + ) + x = torch.randn(seq_len, hidden_dim, device=device) + + # Run implementations + pt_sorted, pt_indices, pt_counts = pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_sorted, cuda_indices, cuda_counts = cuda_sort_tokens(topk_ids, x, n_experts) + + # Verify tokens per expert counts + counts_match = torch.allclose(pt_counts, cuda_counts) + + # Verify sorted tokens + # Note: For k>1, the shape of sorted tokens may differ between implementations + # So we need to check if the content matches when reshaping + if pt_sorted.shape[0] == cuda_sorted.shape[0]: + tokens_match = torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + else: + print( + f" Warning: Shape mismatch - PyTorch: {pt_sorted.shape}, CUDA: {cuda_sorted.shape}" + ) + tokens_match = False + + # Check if indices map correctly - regenerate the original features + # and see if they match the input + if pt_indices.shape[0] == cuda_indices.shape[0]: + # Map back to original token features + pt_original = torch.zeros_like(x) + cuda_original = torch.zeros_like(x) + + # Create mask for valid indices (to handle k>1 case) + valid_pt_indices = pt_indices // k < 
seq_len + valid_cuda_indices = cuda_indices < seq_len + + # Reconstruct using valid indices only + pt_reconstructed = x[pt_indices[valid_pt_indices] // k] + cuda_reconstructed = x[cuda_indices[valid_cuda_indices]] + + # Check if reconstructed features are close + reconstruction_match = torch.allclose( + pt_reconstructed, cuda_reconstructed, rtol=1e-5, atol=1e-5 + ) + else: + print( + f" Warning: Indices shape mismatch - PyTorch: {pt_indices.shape}, CUDA: {cuda_indices.shape}" + ) + reconstruction_match = False + + # Print results + print(f" Token counts match: {counts_match}") + print(f" Sorted tokens match: {tokens_match}") + print(f" Reconstruction match: {reconstruction_match}") + + overall_match = counts_match and tokens_match and reconstruction_match + + # For deeper verification, output the first few tokens and indices + if not overall_match: + print("\nDetailed debugging info:") + print(" PyTorch tokens per expert:", pt_counts.cpu().numpy()) + print(" CUDA tokens per expert:", cuda_counts.cpu().numpy()) + + print("\n First 5 PyTorch sorted tokens:") + print(pt_sorted[:5, :3].cpu().numpy()) + print(" First 5 CUDA sorted tokens:") + print(cuda_sorted[:5, :3].cpu().numpy()) + + print("\n First 10 PyTorch indices:") + print(pt_indices[:10].cpu().numpy()) + print(" First 10 CUDA indices:") + print(cuda_indices[:10].cpu().numpy()) + + return overall_match + + +def benchmark_implementations( + seq_len, hidden_dim, n_experts, k=1, num_runs=10, verify=True +): + """Benchmark PyTorch vs CUDA implementations""" + # Create random input data + torch.manual_seed(2020) + device = torch.device("cuda") + + topk_ids = torch.randint( + 0, n_experts, (seq_len, k), device=device, dtype=torch.int64 + ) + x = torch.randn(seq_len, hidden_dim, device=device) + + # Verify if requested + if verify: + match = verify_implementations(seq_len, hidden_dim, n_experts, k) + if not match: + print(f" WARNING: Verification failed for {n_experts} experts with k={k}!") + + # Warmup + for _ in range(3): + pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_sort_tokens(topk_ids, x, n_experts) + + results = {} + + # Benchmark PyTorch + torch.cuda.synchronize() + pt_times = [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + pytorch_sort_tokens(topk_ids, x, n_experts) + end_event.record() + + torch.cuda.synchronize() + pt_times.append(start_event.elapsed_time(end_event)) + + pt_avg_time = sum(pt_times) / len(pt_times) + pt_std_time = torch.tensor(pt_times).std().item() + + # Benchmark CUDA + torch.cuda.synchronize() + cuda_times = [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + cuda_sort_tokens(topk_ids, x, n_experts) + end_event.record() + + torch.cuda.synchronize() + cuda_times.append(start_event.elapsed_time(end_event)) + + cuda_avg_time = sum(cuda_times) / len(cuda_times) + cuda_std_time = torch.tensor(cuda_times).std().item() + + # Calculate speedup + speedup = pt_avg_time / cuda_avg_time + + results = { + "n_experts": n_experts, + "k": k, + "pytorch_time": pt_avg_time, + "pytorch_std": pt_std_time, + "cuda_time": cuda_avg_time, + "cuda_std": cuda_std_time, + "speedup": speedup, + } + + print(f"\nBenchmark Results for {n_experts} experts, k={k}:") + print(f" PyTorch: {pt_avg_time:.3f} ± {pt_std_time:.3f} ms") + print(f" CUDA: {cuda_avg_time:.3f} ± {cuda_std_time:.3f} ms") + print(f" Speedup: 
{speedup:.2f}x") + + return results + + +def plot_results(results_list): + """Generate plots from benchmark results""" + results_df = pd.DataFrame(results_list) + + # Create directory for plots + import os + + os.makedirs("benchmark_results", exist_ok=True) + + # Group results by k value + k_values = results_df["k"].unique() + + for k in k_values: + k_results = results_df[results_df["k"] == k] + + # 1. Execution Time vs Number of Experts + plt.figure(figsize=(10, 6)) + plt.errorbar( + k_results["n_experts"], + k_results["pytorch_time"], + yerr=k_results["pytorch_std"], + marker="o", + label="PyTorch", + ) + plt.errorbar( + k_results["n_experts"], + k_results["cuda_time"], + yerr=k_results["cuda_std"], + marker="s", + label="CUDA", + ) + plt.xscale("log", base=2) + plt.yscale("log") + plt.xlabel("Number of Experts") + plt.ylabel("Execution Time (ms)") + plt.title(f"Execution Time vs Number of Experts (k={k})") + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.legend() + plt.tight_layout() + plt.savefig(f"benchmark_results/execution_time_k{k}.png", dpi=300) + + # 2. Speedup vs Number of Experts + plt.figure(figsize=(10, 6)) + plt.plot(k_results["n_experts"], k_results["speedup"], marker="o") + plt.axhline(y=1.0, color="r", linestyle="--", alpha=0.5) + plt.xscale("log", base=2) + plt.xlabel("Number of Experts") + plt.ylabel("Speedup Factor (PyTorch / CUDA)") + plt.title(f"CUDA Speedup vs Number of Experts (k={k})") + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.tight_layout() + plt.savefig(f"benchmark_results/speedup_k{k}.png", dpi=300) + + # 3. Impact of k on speedup + if len(k_values) > 1: + plt.figure(figsize=(10, 6)) + for n_exp in results_df["n_experts"].unique(): + n_exp_results = results_df[results_df["n_experts"] == n_exp] + plt.plot( + n_exp_results["k"], + n_exp_results["speedup"], + marker="o", + label=f"{n_exp} experts", + ) + + plt.axhline(y=1.0, color="r", linestyle="--", alpha=0.5) + plt.xlabel("k value (experts per token)") + plt.ylabel("Speedup Factor (PyTorch / CUDA)") + plt.title("Impact of k on CUDA Speedup") + plt.grid(True, which="both", ls="--", alpha=0.5) + plt.legend() + plt.tight_layout() + plt.savefig("benchmark_results/k_impact.png", dpi=300) + + return results_df + + +def verify_k_values(seq_len, hidden_dim, n_experts): + """Verify the implementation for different k values""" + print("\n" + "=" * 80) + print(f"Verifying implementation for different k values (experts per token)") + print("=" * 80) + + k_values = [1, 2, 4, 8] + results = [] + + for k in k_values: + print(f"\nTesting k={k}:") + match = verify_implementations(seq_len, hidden_dim, n_experts, k) + results.append({"k": k, "match": match}) + + print("\nSummary results for different k values:") + for res in results: + status = "✓ PASS" if res["match"] else "✗ FAIL" + print(f"k={res['k']}: {status}") + + return all(res["match"] for res in results) + + +def main(): + parser = argparse.ArgumentParser( + description="Verify and benchmark token sorting implementations" + ) + parser.add_argument("--seq-len", type=int, default=2048, help="Sequence length") + parser.add_argument( + "--hidden-dim", type=int, default=1024, help="Hidden dimension size" + ) + parser.add_argument( + "--runs", type=int, default=10, help="Number of runs for timing" + ) + parser.add_argument( + "--skip-verify", action="store_true", help="Skip verification step" + ) + parser.add_argument( + "--k", type=int, default=1, help="Number of expert assignments per token" + ) + parser.add_argument( + "--verify-k", 
action="store_true", help="Verify different k values" + ) + parser.add_argument( + "--experts", + type=str, + default="16,64,128,256,512", + help="Comma-separated list of expert counts to test", + ) + args = parser.parse_args() + + print("=" * 80) + print(f"Token Sorting Benchmark") + print( + f"Sequence Length: {args.seq_len}, Hidden Dimension: {args.hidden_dim}, k: {args.k}" + ) + print("=" * 80) + + # Verify different k values if requested + if args.verify_k: + all_k_pass = verify_k_values( + args.seq_len, args.hidden_dim, 64 + ) # Use 64 experts for k verification + if not all_k_pass: + print("\nWARNING: Some k values failed verification!") + + # Parse expert counts + expert_counts = [int(x) for x in args.experts.split(",")] + + # Run benchmarks + results = [] + for n_experts in expert_counts: + result = benchmark_implementations( + args.seq_len, + args.hidden_dim, + n_experts, + k=args.k, + num_runs=args.runs, + verify=not args.skip_verify, + ) + results.append(result) + + # Generate plots + results_df = plot_results(results) + + # Print summary table + print("\n" + "=" * 100) + print("Summary Results:") + + summary_data = [] + for _, row in results_df.iterrows(): + summary_data.append( + [ + row["n_experts"], + row["k"], + f"{row['pytorch_time']:.2f} ± {row['pytorch_std']:.2f}", + f"{row['cuda_time']:.2f} ± {row['cuda_std']:.2f}", + f"{row['speedup']:.2f}x", + ] + ) + + headers = ["Experts", "k", "PyTorch (ms)", "CUDA (ms)", "Speedup"] + print(tabulate(summary_data, headers=headers, tablefmt="grid")) + print("=" * 100) + + print("\nBenchmark complete! Plots saved to benchmark_results/ directory.") + + +if __name__ == "__main__": + + seq_len = 8 + hidden_dim = 4 + n_experts = 4 + k = 2 + topk_ids = torch.randint(0, n_experts, (seq_len, k), device="cuda") + x = torch.randn(seq_len, hidden_dim, device="cuda") + + # Compare results + pt_result = pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_result = cuda_sort_tokens(topk_ids, x, n_experts) + + print(f"{pt_result=}") + print(f"{cuda_result=}") + + # same = torch.allclose(pt_result, cuda_result) + # print(f"{same=}") + + # main() diff --git a/torchtitan/experiments/kernels/moe_sorting/dist/token_sorting-0.0.0-py3.12-linux-x86_64.egg b/torchtitan/experiments/kernels/moe_sorting/dist/token_sorting-0.0.0-py3.12-linux-x86_64.egg new file mode 100644 index 000000000..9fef68f1a Binary files /dev/null and b/torchtitan/experiments/kernels/moe_sorting/dist/token_sorting-0.0.0-py3.12-linux-x86_64.egg differ diff --git a/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h b/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h new file mode 100644 index 000000000..0c7bf30ce --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/moe_kernel_utils.h @@ -0,0 +1,72 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Cuda kernel utils file for MoE related kernels + * basically let's not reinvent the wheel for core functions... 
+ * ======================
+ * cdiv
+ * grid_1d
+ * grid_2d
+ * calc_shared_memory_size
+ * =======================
+
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <cuda_runtime.h>
+
+namespace moe_kernel_utils {
+/**
+ * cdiv - Ceiling division - grid and block size calc support
+ *
+ * @param numerator Number of elements to process
+ * @param denominator Number of elements per thread/block
+ * @return Ceiling of the division (usually number of blocks needed)
+ */
+inline int cdiv(int numerator, int denominator) {
+  return (numerator + denominator - 1) / denominator;
+}
+
+/**
+ * grid_1d - calculate 1D grid size with upper limit
+ *
+ * @param elements Number of elements to process
+ * @param threads_per_block Number of threads per block
+ * @param max_blocks Upper limit of blocks (defaults to 256 for now)
+ * @return optimal number of blocks for the 1D grid
+ */
+inline int grid_1d(int elements, int threads_per_block, int max_blocks = 256) {
+  return std::min(max_blocks, cdiv(elements, threads_per_block));
+}
+
+/**
+ * grid_2d - calculate 2D grid based on input dimensions (x, y)
+ * @param dim_x 1st dimension size - usually rows
+ * @param dim_y 2nd dimension (usually features/columns)
+ * @param block_dim_x Number of threads per block in x dimension
+ * @param block_dim_y Number of threads per block in y dimension
+ * @return dim3 with grid dimensions
+ */
+inline dim3 grid_2d(int dim_x, int dim_y, int block_dim_x, int block_dim_y) {
+  return dim3(cdiv(dim_x, block_dim_x), cdiv(dim_y, block_dim_y));
+}
+
+/**
+ * calc_shared_memory_size - calculate shared memory size needed for a given
+ * type and count
+ *
+ * @param T Type to allocate for
+ * @param count Num elements
+ * @return Size in bytes for shared memory allocation
+ */
+template <typename T> inline size_t calc_shared_memory_size(int count) {
+  return count * sizeof(T);
+}
+} // namespace moe_kernel_utils
diff --git a/torchtitan/experiments/kernels/moe_sorting/setup.py b/torchtitan/experiments/kernels/moe_sorting/setup.py
new file mode 100644
index 000000000..660e85e32
--- /dev/null
+++ b/torchtitan/experiments/kernels/moe_sorting/setup.py
@@ -0,0 +1,34 @@
+import os
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+extra_compile_args = {
+    "cxx": ["-O3"],
+    "nvcc": [
+        "-O3",
+        "--gpu-architecture=sm_90",  # H100
+        "--use_fast_math",
+        "--extended-lambda",
+    ],
+}
+
+# Source files
+sources = [
+    "token_sorting_kernels.cu",
+]  # "moe_kernel_utils.h"]
+
+setup(
+    name="token_sorting_cuda",
+    version="0.1",
+    description="CUDA-accelerated token sorting for Mixture of Experts models",
+    ext_modules=[
+        CUDAExtension(
+            name="token_sorting_cuda",
+            sources=sources,
+            extra_compile_args=extra_compile_args,
+        ),
+    ],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/torchtitan/experiments/kernels/moe_sorting/simpletest.py b/torchtitan/experiments/kernels/moe_sorting/simpletest.py
new file mode 100644
index 000000000..3fdf73f35
--- /dev/null
+++ b/torchtitan/experiments/kernels/moe_sorting/simpletest.py
@@ -0,0 +1,272 @@
+"""
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD 3-Clause license found in the
+// LICENSE file in the root directory of this source tree.
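+
+Quick correctness checks for the token_sorting_cuda extension: a small
+hand-checkable case and a random case are compared against the PyTorch
+reference sort, with extra diagnostics printed when the results diverge.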
+""" + +import torch + +try: + import token_sorting_cuda +except ImportError: + print(f"unable to import token_sorting_cuda extension...") + raise + +import argparse + +import numpy as np + + +def pytorch_sort_tokens(topk_ids, x, n_experts): + """Original PyTorch implementation for comparison""" + with torch.no_grad(): + # [seq_len, n_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], n_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens_shape = idxs.shape + x.shape[1:] + sorted_tokens = x[idxs // topk_ids.shape[1]] + + return sorted_tokens, idxs, tokens_per_expert + + +def cuda_sort_tokens(topk_ids, x, n_experts): + """CUDA optimized implementation""" + # topk_int_ids = topk_ids.to(torch.int32) + # topk_int_ids = topk_ids.to(torch.int32) + # print(f"Original dtype: {topk_ids.dtype}, Converted dtype: {topk_int_ids.dtype}") + + # print(f"{topk_int_ids=}, {x=}, {n_experts=}") + + sorted_tokens, sorted_indices, tokens_per_expert = ( + token_sorting_cuda.sort_tokens_by_expert(topk_ids, x, n_experts) + ) + + return sorted_tokens, sorted_indices, tokens_per_expert + + +def test_simple_case(): + """Test with a simple example where we know the expected output""" + device = torch.device("cuda") + + # Create small test case + seq_len = 4 + k = 2 + hidden_dim = 3 + n_experts = 3 + + # Create expert assignments: [[0,1], [1,2], [0,2], [1,0]] + topk_ids = torch.tensor( + [[0, 1], [1, 2], [0, 2], [1, 0]], device=device, dtype=torch.int64 + ) + + # Create token features with recognizable values + x = torch.tensor( + [ + [1.0, 1.1, 1.2], # token 0 + [2.0, 2.1, 2.2], # token 1 + [3.0, 3.1, 3.2], # token 2 + [4.0, 4.1, 4.2], # token 3 + ], + device=device, + dtype=torch.float32, + ) + + print("\n===== SIMPLE TEST CASE =====") + print(f"Input topk_ids:\n{topk_ids}") + print(f"Input tokens:\n{x}") + + # Run implementations + pt_sorted, pt_indices, pt_counts = pytorch_sort_tokens(topk_ids, x, n_experts) + + cuda_sorted, cuda_indices, cuda_counts = cuda_sort_tokens(topk_ids, x, n_experts) + + # Display results + print("\nToken counts per expert:") + print(f"PyTorch: {pt_counts}") + print(f"CUDA: {cuda_counts}") + print(f"Match: {torch.allclose(pt_counts, cuda_counts)}") + + print("\nSorted indices:") + print(f"PyTorch: {pt_indices}") + print(f"CUDA: {cuda_indices}") + print(f"Shapes match: {pt_indices.shape == cuda_indices.shape}") + + print("\nSorted tokens (first few):") + print(f"PyTorch:\n{pt_sorted[:5]}") + print(f"CUDA:\n{cuda_sorted[:5]}") + print(f"Shapes match: {pt_sorted.shape == cuda_sorted.shape}") + + if pt_sorted.shape == cuda_sorted.shape: + tokens_match = torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + print(f"Values match: {tokens_match}") + + overall_match = ( + torch.allclose(pt_counts, cuda_counts) + and pt_indices.shape == cuda_indices.shape + and pt_sorted.shape == cuda_sorted.shape + and torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + ) + + print(f"\nOverall match: {overall_match}") + return overall_match + + +def test_random_case(seq_len=16, hidden_dim=8, n_experts=4, k=2): + """Test with random inputs of specified dimensions""" + torch.manual_seed(42) # For reproducibility + device = torch.device("cuda") + + # Create random inputs + topk_ids = torch.randint( + 0, n_experts, (seq_len, k), device=device, dtype=torch.int64 + ) + x = torch.randn(seq_len, hidden_dim, device=device) + + print(f"\n===== RANDOM 
TEST CASE =====") + print(f"seq_len={seq_len}, hidden_dim={hidden_dim}, n_experts={n_experts}, k={k}") + + # Run implementations + pt_sorted, pt_indices, pt_counts = pytorch_sort_tokens(topk_ids, x, n_experts) + cuda_sorted, cuda_indices, cuda_counts = cuda_sort_tokens(topk_ids, x, n_experts) + + # Display results + print("\nToken counts per expert:") + print(f"PyTorch: {pt_counts}") + print(f"CUDA: {cuda_counts}") + print(f"Match: {torch.allclose(pt_counts, cuda_counts)}") + + print("\nSorted indices shapes:") + print(f"PyTorch: {pt_indices.shape}") + print(f"CUDA: {cuda_indices.shape}") + print(f"Match: {pt_indices.shape == cuda_indices.shape}") + + print("\nSorted tokens shapes:") + print(f"PyTorch: {pt_sorted.shape}") + print(f"CUDA: {cuda_sorted.shape}") + print(f"Match: {pt_sorted.shape == cuda_sorted.shape}") + + if pt_sorted.shape == cuda_sorted.shape: + tokens_match = torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + print(f"Values match: {tokens_match}") + + overall_match = ( + torch.allclose(pt_counts, cuda_counts) + and pt_indices.shape == cuda_indices.shape + and pt_sorted.shape == cuda_sorted.shape + and torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5) + ) + + print(f"\nOverall match: {overall_match}") + return overall_match + + +def debug_equality(pt_sorted, cuda_sorted, pt_indices, cuda_indices): + """Debug why tensors might not be equal""" + print("\n===== DEBUGGING EQUALITY =====") + + if pt_sorted.shape != cuda_sorted.shape: + print(f"Shape mismatch: PyTorch {pt_sorted.shape} vs CUDA {cuda_sorted.shape}") + return + + # Check for NaN or Inf values + print(f"PyTorch has NaN: {torch.isnan(pt_sorted).any()}") + print(f"CUDA has NaN: {torch.isnan(cuda_sorted).any()}") + print(f"PyTorch has Inf: {torch.isinf(pt_sorted).any()}") + print(f"CUDA has Inf: {torch.isinf(cuda_sorted).any()}") + + # Check differences + if not torch.allclose(pt_sorted, cuda_sorted, rtol=1e-5, atol=1e-5): + diff = torch.abs(pt_sorted - cuda_sorted) + max_diff = torch.max(diff).item() + max_diff_idx = torch.argmax(diff.view(-1)).item() + print(f"Max difference: {max_diff} at index {max_diff_idx}") + + # Find rows with largest differences + row_diffs = torch.sum(diff, dim=1) + top_diff_rows = torch.topk(row_diffs, min(5, len(row_diffs))) + print("Top 5 rows with largest differences:") + for i, idx in enumerate(top_diff_rows.indices): + print(f"Row {idx}:") + print(f" PyTorch: {pt_sorted[idx]}") + print(f" CUDA: {cuda_sorted[idx]}") + print(f" Diff: {diff[idx]}") + + # Check if indices are different + if not torch.equal(pt_indices, cuda_indices): + print("\nIndices don't match") + print(f"First 10 PyTorch indices: {pt_indices[:10]}") + print(f"First 10 CUDA indices: {cuda_indices[:10]}") + + # Check distribution of indices + print(f"\nPyTorch indices min: {pt_indices.min()}, max: {pt_indices.max()}") + print(f"CUDA indices min: {cuda_indices.min()}, max: {cuda_indices.max()}") + + # Check uniqueness of indices + pt_unique = torch.unique(pt_indices) + cuda_unique = torch.unique(cuda_indices) + print(f"PyTorch unique indices count: {len(pt_unique)}") + print(f"CUDA unique indices count: {len(cuda_unique)}") + + +def main(): + parser = argparse.ArgumentParser(description="Test token sorting implementations") + parser.add_argument("--seq-len", type=int, default=16, help="Sequence length") + parser.add_argument( + "--hidden-dim", type=int, default=8, help="Hidden dimension size" + ) + parser.add_argument("--experts", type=int, default=4, help="Number of experts") + parser.add_argument( 
+ "--k", type=int, default=2, help="Number of expert assignments per token" + ) + args = parser.parse_args() + + print("=" * 50) + print("Token Sorting Tests") + print("=" * 50) + + # Run the simple test case first + simple_match = test_simple_case() + + # Run the random test case with configurable dimensions + random_match = test_random_case( + seq_len=args.seq_len, + hidden_dim=args.hidden_dim, + n_experts=args.experts, + k=args.k, + ) + + if not simple_match or not random_match: + print("\n⚠️ Some tests failed. Collecting debug information...") + + # Run a debug test case and collect detailed comparison + device = torch.device("cuda") + topk_ids = torch.randint( + 0, args.experts, (args.seq_len, args.k), device=device, dtype=torch.int64 + ) + x = torch.randn(args.seq_len, args.hidden_dim, device=device) + + pt_sorted, pt_indices, pt_counts = pytorch_sort_tokens( + topk_ids, x, args.experts + ) + cuda_sorted, cuda_indices, cuda_counts = cuda_sort_tokens( + topk_ids, x, args.experts + ) + + debug_equality(pt_sorted, cuda_sorted, pt_indices, cuda_indices) + + print("\n" + "=" * 50) + print(f"Simple test result: {'✅ PASS' if simple_match else '❌ FAIL'}") + print(f"Random test result: {'✅ PASS' if random_match else '❌ FAIL'}") + print("=" * 50) + + return 0 if simple_match and random_match else 1 + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/kernels/moe_sorting/speedup_heatmap.png b/torchtitan/experiments/kernels/moe_sorting/speedup_heatmap.png new file mode 100644 index 000000000..4a9a815d1 Binary files /dev/null and b/torchtitan/experiments/kernels/moe_sorting/speedup_heatmap.png differ diff --git a/torchtitan/experiments/kernels/moe_sorting/token_sorting_cuda.cpython-312-x86_64-linux-gnu.so b/torchtitan/experiments/kernels/moe_sorting/token_sorting_cuda.cpython-312-x86_64-linux-gnu.so new file mode 100755 index 000000000..c5b0bac2a Binary files /dev/null and b/torchtitan/experiments/kernels/moe_sorting/token_sorting_cuda.cpython-312-x86_64-linux-gnu.so differ diff --git a/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu b/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu new file mode 100644 index 000000000..e41ac9eb3 --- /dev/null +++ b/torchtitan/experiments/kernels/moe_sorting/token_sorting_kernels.cu @@ -0,0 +1,300 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Token sorting for MoE Models + * + */ + +#include "moe_kernel_utils.h" +#include +#include +#include +#include + +// Our utility namespace +using namespace moe_kernel_utils; + +// +// CUDA Kernels +// + +// count tokens per expert +__global__ void optimized_count_tokens_kernel( + int *tokens_per_expert, // output: count of tokens per expert [n_experts] + const int64_t *topk_ids, // input: expert assignments [seq_len, k] + int seq_len, // sequence length + int n_experts, + int k, // top-k experts per token + int hidden_dim) { + + // for local counters + extern __shared__ unsigned int s_expert_counts[]; + + // Initialize shared memory to zero + for (int i = threadIdx.x; i < n_experts; i += blockDim.x) { + s_expert_counts[i] = 0; + } + __syncthreads(); + + // Adjust tokens per thread based on hidden dimension size to reduce register + // pressure for larger feature dimensions + const int tokens_per_thread = + (hidden_dim <= 1024) ? 4 : ((hidden_dim <= 4096) ? 
2 : 1); + const int token_stride = blockDim.x * gridDim.x; + + // Registers to track seen experts for each token + // Use 32-bit words to track up to 32 experts per word + // This handles up to 512 experts with 16 words + unsigned int seen_experts[16] = {0}; + + for (int base_idx = blockIdx.x * blockDim.x + threadIdx.x; base_idx < seq_len; + base_idx += token_stride) { + +#pragma unroll + for (int t = 0; t < tokens_per_thread; t++) { + const int token_idx = base_idx + t * token_stride; + if (token_idx >= seq_len) + break; + +// Reset the seen experts tracker for this token +#pragma unroll + for (int w = 0; w < 16; w++) { + seen_experts[w] = 0; + } + +// Process all expert assignments for this token +#pragma unroll + for (int j = 0; j < k; j++) { + const int expert_id = static_cast(topk_ids[token_idx * k + j]); + if (expert_id >= 0 && expert_id < n_experts) { + // Mark this expert as seen + const int word_idx = expert_id >> 5; // expert_id / 32 + const int bit_idx = expert_id & 31; // expert_id % 32 + seen_experts[word_idx] |= (1U << bit_idx); + } + } + +// Update shared memory counters with seen experts +#pragma unroll + for (int w = 0; w < 16; w++) { + unsigned int word = seen_experts[w]; + while (word) { + // Find the least significant set bit + unsigned int bit_pos = __ffs(word) - 1; + + // Calculate expert id + int expert_id = (w << 5) | bit_pos; // (w * 32) + bit_pos + if (expert_id < n_experts) { + // Increment counter for this expert + // Use atomic since multiple threads may update same + // counter + atomicAdd(&s_expert_counts[expert_id], 1); + } + + // Clear the processed bit and continue with next set bit + word &= ~(1U << bit_pos); + } + } + } + } + + // Make sure all threads in the block have updated shared memory + __syncthreads(); + + // Contribute local counts to global counts with coalesced memory access + for (int i = threadIdx.x; i < n_experts; i += blockDim.x) { + if (s_expert_counts[i] > 0) { + atomicAdd(&tokens_per_expert[i], s_expert_counts[i]); + } + } +} + +// Optimized gather kernel for large feature dimensions +template +__global__ void gather_sorted_tokens_kernel_large( + scalar_t *sorted_tokens, // output: sorted token features + const scalar_t *input_tokens, // input: original token features + const int64_t *sort_indices, // input: indices from argsort + int total_elements, // total number of elements + int hidden_dim, // hidden dimension size + int k // k value for integer division +) { + // Calculate global thread ID + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Calculate token and feature indices + int token_idx = idx / hidden_dim; + int feat_idx = idx % hidden_dim; + + if (token_idx < total_elements) { + // Integer division to get token index from flattened index + int src_idx = static_cast(sort_indices[token_idx]) / k; + sorted_tokens[token_idx * hidden_dim + feat_idx] = + input_tokens[src_idx * hidden_dim + feat_idx]; + } +} + +// reorder (gather) +template +__global__ void gather_sorted_tokens_kernel( + scalar_t *sorted_tokens, // output: sorted token features + const scalar_t *input_tokens, // input: original token features + const int64_t *sort_indices, // input: indices from argsort + int total_elements, + int hidden_dim, // hidden dimension size + int k // k value for integer division +) { + // Calculate global thread indices + int token_idx = blockIdx.x * blockDim.x + threadIdx.x; + int feat_idx = blockIdx.y * blockDim.y + threadIdx.y; + + if (token_idx < total_elements && feat_idx < hidden_dim) { + // Integer division to get token 
index from flattened index + int src_idx = static_cast(sort_indices[token_idx]) / k; + sorted_tokens[token_idx * hidden_dim + feat_idx] = + input_tokens[src_idx * hidden_dim + feat_idx]; + } +} + +////////////////////////////////////////////////////////////////////////////// +// C++/CUDA wrapper functions +////////////////////////////////////////////////////////////////////////////// + +std::tuple +sort_tokens_by_expert_cuda(torch::Tensor topk_ids, torch::Tensor x, + int n_experts) { + + auto device = topk_ids.device(); + int seq_len = topk_ids.size(0); + int k = topk_ids.size(1); + int hidden_dim = x.size(1); + int total_elements = seq_len * k; + + // Validate inputs + TORCH_CHECK(topk_ids.device().is_cuda(), "topk_ids must be a CUDA tensor"); + TORCH_CHECK(x.device().is_cuda(), "input tensor must be a CUDA tensor"); + TORCH_CHECK(topk_ids.dim() == 2, "topk_ids must be a 2D tensor"); + TORCH_CHECK(x.dim() == 2, "input tensor must be a 2D tensor"); + TORCH_CHECK(n_experts <= 512, "Maximum number of experts supported is 512"); + + // Always use int64 for topk_ids to match PyTorch + auto topk_ids_int64 = topk_ids; + if (topk_ids.scalar_type() != torch::kInt64) { + topk_ids_int64 = topk_ids.to(torch::kInt64); + } + topk_ids_int64 = topk_ids_int64.contiguous(); + + // Step 1: Count tokens per expert using specialized CUDA kernel + // Use int32 for token counts to avoid atomicAdd issues + auto tokens_per_expert = + torch::zeros({n_experts}, torch::dtype(torch::kInt32).device(device)); + + // Optimize kernel launch parameters + const int count_threads = 256; + + const int count_blocks = + std::min(256, (seq_len + count_threads - 1) / count_threads); + + int shared_mem_size = n_experts * sizeof(int); + + // Make sure shared memory size is reasonable + int max_shared_mem; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, + device.index()); + if (shared_mem_size > max_shared_mem) { + // Fall back to a smaller value + shared_mem_size = max_shared_mem; + } + + // Launch optimized counting kernel + optimized_count_tokens_kernel<<>>( + tokens_per_expert.data_ptr(), topk_ids_int64.data_ptr(), + seq_len, n_experts, k, + hidden_dim // Pass hidden_dim for dynamic optimization + ); + + // Check for errors + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + printf("CUDA error in counting kernel: %s\n", cudaGetErrorString(error)); + throw std::runtime_error("CUDA kernel execution failed"); + } + + // Step 2: Use PyTorch's argsort for now... 
+ auto flattened_topk = topk_ids_int64.reshape({-1}); + auto sort_indices = flattened_topk.argsort(); + + // Step 3: Gather the token features + auto sorted_tokens = torch::empty({total_elements, hidden_dim}, x.options()); + + // kernel strategy based on hidden dimension size + if (hidden_dim > 2048) { + // For large feature dimensions, use a 1D kernel with better memory + // coalescing + int block_size = 256; + int num_blocks = + (total_elements * hidden_dim + block_size - 1) / block_size; + + // Launch gather kernel for large dimensions + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x.scalar_type(), "gather_sorted_tokens_cuda_large", ([&] { + gather_sorted_tokens_kernel_large + <<>>(sorted_tokens.data_ptr(), + x.data_ptr(), + sort_indices.data_ptr(), + total_elements, hidden_dim, k); + })); + } else { + // For smaller dimensions, use the 2D kernel + const int gather_token_threads = 16; + const int gather_feature_threads = 16; + + dim3 gather_threads(gather_token_threads, gather_feature_threads); + dim3 gather_blocks((total_elements + gather_threads.x - 1) / + gather_threads.x, + (hidden_dim + gather_threads.y - 1) / gather_threads.y); + + // Launch standard gather kernel + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x.scalar_type(), "gather_sorted_tokens_cuda", ([&] { + gather_sorted_tokens_kernel + <<>>( + sorted_tokens.data_ptr(), x.data_ptr(), + sort_indices.data_ptr(), total_elements, hidden_dim, + k); + })); + } + + // Check for errors + error = cudaGetLastError(); + if (error != cudaSuccess) { + printf("CUDA error in gather_sorted_tokens_kernel: %s\n", + cudaGetErrorString(error)); + throw std::runtime_error("CUDA kernel execution failed"); + } + + // Convert token counts back to match input type + torch::Tensor tokens_per_expert_out; + if (topk_ids.scalar_type() == torch::kInt64) { + tokens_per_expert_out = tokens_per_expert.to(torch::kInt64); + } else { + tokens_per_expert_out = tokens_per_expert; + } + + return std::make_tuple(sorted_tokens, sort_indices, tokens_per_expert_out); +} + +////////////////////////////////////////////////////////////////////////////// +// Python bindings +////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sort_tokens_by_expert", &sort_tokens_by_expert_cuda, + "Sort tokens by expert assignment (CUDA)", py::arg("topk_ids"), + py::arg("x"), py::arg("n_experts")); +}
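For reference, a minimal usage sketch of the new extension (assuming it has been built from this directory first, e.g. with pip install -e . or python setup.py install, and that a CUDA device is available):

    import torch
    import token_sorting_cuda

    seq_len, hidden_dim, n_experts, k = 2048, 1024, 64, 2
    topk_ids = torch.randint(0, n_experts, (seq_len, k), device="cuda", dtype=torch.int64)
    x = torch.randn(seq_len, hidden_dim, device="cuda")

    # Returns the token features replicated k times and grouped by expert, the
    # argsort indices over the flattened [seq_len * k] expert ids, and the
    # per-expert token counts.
    sorted_tokens, sort_indices, tokens_per_expert = token_sorting_cuda.sort_tokens_by_expert(
        topk_ids, x, n_experts
    )
    assert sorted_tokens.shape == (seq_len * k, hidden_dim)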