[https://nvbugs/5970614][fix] Sync CTA before PDL trigger in quantize_with_block_size #14668

Open: tianyuxbear wants to merge 2 commits into NVIDIA:main from tianyuxbear:fix/5970614

@tianyuxbear (Collaborator) commented May 28, 2026

Summary

Fixes a PDL (Programmatic Dependent Launch) race in quantize_with_block_size that intermittently corrupts NVFP4 GEMM outputs and degrades GSM8K accuracy for DeepSeek-R1 NVFP4 on GB300 + PP=4 + MTP (nvbug 5970614).

cudaTriggerProgrammaticLaunchCompletion() only signals that the CTA has reached the trigger point — it does not flush prior stores to global memory. Memory visibility for the secondary kernel must be provided either by the producer (a fence before the trigger) or by the consumer (wait_on_dependent_grids() before its first dependent load). In the current NVFP4 path neither side does so: the producer lacks a fence, and the sm103 blockscaled GEMM's main_sf_load warp branch is missing the corresponding wait_on_dependent_grids() (tracked separately in NVIDIA/cutlass#3279).
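As a rough illustration of the consumer-side option, the sketch below has the secondary kernel wait before its first dependent load. It uses the CUDA runtime device call cudaGridDependencySynchronize(), which is assumed here to play the same role as the CUTLASS wait_on_dependent_grids() helper named above. The kernel names, buffers, and sizes are hypothetical and are not taken from the TensorRT-LLM or CUTLASS sources; an sm_90 or newer build is assumed.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical producer: writes `buf`, then signals the PDL trigger with no fence.
__global__ void producer_sketch(float* buf, int n)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
    {
        buf[idx] = 1.0f;
    }
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

// Hypothetical consumer: waits on the primary grid before touching `buf`.
__global__ void consumer_sketch(float const* buf, float* out, int n)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Work that does not depend on the producer could run here and overlap with it.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaGridDependencySynchronize(); // must precede the first dependent load
#endif
    if (idx < n)
    {
        out[idx] = buf[idx] + 1.0f;
    }
}

int main()
{
    int const n = 1 << 20;
    float* buf = nullptr;
    float* out = nullptr;
    cudaMalloc(&buf, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    dim3 const block(256);
    dim3 const grid((n + 255) / 256);

    producer_sketch<<<grid, block, 0, stream>>>(buf, n);

    // Opt the secondary launch into programmatic dependent launch.
    cudaLaunchAttribute attr[1];
    attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr[0].val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t cfg = {};
    cfg.gridDim = grid;
    cfg.blockDim = block;
    cfg.dynamicSmemBytes = 0;
    cfg.stream = stream;
    cfg.attrs = attr;
    cfg.numAttrs = 1;

    cudaError_t err = cudaLaunchKernelEx(&cfg, consumer_sketch, static_cast<float const*>(buf), out, n);
    if (err == cudaSuccess)
    {
        err = cudaStreamSynchronize(stream);
    }
    std::printf("consumer_sketch: %s\n", cudaGetErrorString(err));

    cudaStreamDestroy(stream);
    cudaFree(buf);
    cudaFree(out);
    return err == cudaSuccess ? 0 : 1;
}
```

If the wait were dropped from consumer_sketch, nothing in this sketch would order the load of buf after the producer's stores, which is the situation the paragraph above describes.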

Compounding this, the PDL trigger is counted per CTA with at-least-once semantics: a single warp reaching the trigger marks its whole CTA as "trigger reached", even if peer warps in the same CTA are still writing sf_out / out. Once every CTA has been marked, the driver launches the secondary kernel, which TMA-loads partial data. The resulting NaNs then propagate through the DeepSeek-R1 forward pass and corrupt output tokens.

Fix: insert __syncthreads(); __threadfence(); immediately before the trigger, so all warps in the CTA reach the same program point and all their stores are made globally visible before PDL completion is signaled.
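A minimal sketch of that ordering, assuming an sm_90 or newer build. The kernel name, its arguments, and the placeholder store are hypothetical stand-ins, not the actual quantize_with_block_size code; only the barrier, fence, trigger sequence at the tail reflects the change described here.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for the tail of a PDL producer kernel.
// It is not the real quantize_with_block_size; only the ordering matters.
__global__ void producer_tail_sketch(float const* in, float* out, int n)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
    {
        out[idx] = in[idx] * 0.5f; // placeholder for the quantized stores
    }

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // Kept outside the bounds check so every thread in the CTA reaches the barrier.
    __syncthreads(); // all warps in this CTA have issued their stores
    __threadfence(); // those stores are ordered device-wide before the trigger
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

int main()
{
    int const n = 1 << 20;
    float* in = nullptr;
    float* out = nullptr;
    cudaMalloc(&in, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));
    cudaMemset(in, 0, n * sizeof(float));

    producer_tail_sketch<<<(n + 255) / 256, 256>>>(in, out, n);
    cudaError_t const err = cudaDeviceSynchronize();
    std::printf("producer_tail_sketch: %s\n", cudaGetErrorString(err));

    cudaFree(in);
    cudaFree(out);
    return err == cudaSuccess ? 0 : 1;
}
```

The barrier has to sit where every thread of the CTA executes it; placing it inside a divergent branch would be undefined behavior for __syncthreads().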

In-tree precedent

This mirrors an existing pattern in cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh:865-867, which uses __syncthreads(); membar.gl; immediately before the same cudaTriggerProgrammaticLaunchCompletion() call. __threadfence() is the CUDA intrinsic for membar.gl, so the two forms are semantically equivalent.

Evidence

The race was characterized on release/1.2, where the timing window reproduces it deterministically enough to measure. A cache-bypass ld.global.cv.u8 probe on the SF buffer detects 0x7f poison fill (from the cudaMallocAsync pool) as a direct race indicator:
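For readers unfamiliar with this kind of probe, the following is a hypothetical reconstruction of the idea, not the diagnostic that was actually run inside the cutlass consumer. It assumes that a byte still equal to 0x7f means the producer has not written it yet; the helper and kernel names, the buffer layout, and the host-side setup that emulates the poison fill with cudaMemset are all made up for illustration.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// One-byte load with the .cv cache operator, so the value is fetched again
// rather than served from a possibly stale cached line.
__device__ __forceinline__ uint32_t load_u8_cv(uint8_t const* ptr)
{
    uint32_t value;
    asm volatile("ld.global.cv.u8 %0, [%1];"
                 : "=r"(value)
                 : "l"(__cvta_generic_to_global(ptr))
                 : "memory");
    return value;
}

// Hypothetical probe: counts bytes of `sf` that still hold the poison value.
__global__ void poison_probe_sketch(uint8_t const* sf, int n, unsigned long long* hits)
{
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n && load_u8_cv(sf + idx) == 0x7fu)
    {
        atomicAdd(hits, 1ULL);
    }
}

int main()
{
    int const n = 1 << 16;
    uint8_t* sf = nullptr;
    unsigned long long* hits = nullptr;
    cudaMalloc(&sf, n);
    cudaMalloc(&hits, sizeof(unsigned long long));

    cudaMemset(sf, 0x7f, n);     // emulate the poison fill
    cudaMemset(sf, 0x00, n / 2); // pretend the producer has written the first half
    cudaMemset(hits, 0, sizeof(unsigned long long));

    poison_probe_sketch<<<(n + 255) / 256, 256>>>(sf, n, hits);

    unsigned long long hostHits = 0;
    cudaError_t const err = cudaMemcpy(&hostHits, hits, sizeof(hostHits), cudaMemcpyDeviceToHost);
    std::printf("status: %s, poison bytes: %llu of %d (%.2f%%)\n", cudaGetErrorString(err), hostHits, n,
        100.0 * static_cast<double>(hostHits) / n);

    cudaFree(sf);
    cudaFree(hits);
    return err == cudaSuccess ? 0 : 1;
}
```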

| Config | GSM8K | Poison probe |
| --- | --- | --- |
| Baseline (no fix) | 91.93 (FAIL, threshold 92.217) | 2.31% (~74k / 3.2M samples) |
| With this fix | 95.34 (PASS, ref 95.42) | 0 / 2.4M |

On main the race is latent: the GSM8K test no longer fails even without this fix (10/10 PASS, mean 94.95 ± 0.25 with autotuner off; with the fix mean 95.14 ± 0.19). Code/scheduler changes since release/1.2 appear to have narrowed the producer-to-consumer window enough that the race no longer trips the GSM8K threshold — but the underlying defect is unchanged, and any future change that widens the window (cutlass bump, scheduler reordering, kernel fusion) can re-expose it.

The full investigation branch — probes, race-rate measurement, per-run logs — lives on my fork at tianyuxbear/TensorRT-LLM:fix/5970614-bak, which is based on release/1.2 (where the race is deterministically reproducible).

The probe code used to measure the race is not part of this PR; it was a one-off diagnostic inside the cutlass consumer.

Relationship to the test-case waive on main

On main the affected test case (TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp]) is currently waived for an unrelated OOM issue (nvbug 6018046); that is why the deterministic accuracy gap above was measured on release/1.2 instead.

This PR is standalone with respect to that waive: the change is a CTA-internal __syncthreads(); __threadfence(); immediately before cudaTriggerProgrammaticLaunchCompletion(), with no shared code surface with the OOM root cause. It can be merged without waiting for nvbug 6018046 to be resolved, and once that waive is eventually lifted, this fence will already be in place to keep the race from re-tripping the test should any future change widen the producer-to-consumer window again.

Companion cutlass fix

The consumer-side fix at NVIDIA/cutlass#3279 addresses the missing wait_on_dependent_grids() in the sm103 blockscaled GEMM main_sf_load warp branch (the other half of the race described in Summary). Either fix alone closes the race; both are correct individually:

  • The cutlass fix protects every PDL producer routed through that GEMM.
  • This producer-side fix protects every PDL consumer downstream of quantize_with_block_size, regardless of which cutlass revision trtllm pulls in.

Risk

  • Surface: one producer kernel (quantize_with_block_size), covering NVFP4 / FP8 / MXFP8 paths.
  • Per-CTA cost: one __syncthreads() + one __threadfence() at the very end of the kernel (cold path). No effect on PDL launch overlap — the trigger still happens; it just happens after the CTA has drained.
  • No API or behavior change for callers.

@tianyuxbear requested a review from a team as a code owner May 28, 2026 04:52

@coderabbitai (Bot, Contributor) commented May 28, 2026

📝 Walkthrough

The PR addresses a PDL (Programmatic Dependent Launch) race in NVIDIA's TensorRT-LLM by adding a CTA-wide barrier and memory fence to the quantization kernel. It also adjusts the DeepSeek-R1 NVFP4 test configuration and removes the corresponding test waiver.

Changes

NVFP4 DeepSeek-R1 Throughput Stabilization

| Layer / File(s) | Summary |
| --- | --- |
| Kernel synchronization barrier and fence (cpp/tensorrt_llm/kernels/quantization.cuh) | CTA-wide barrier and threadfence are inserted in quantize_with_block_size before launch completion to drain stores and ensure visibility to downstream GEMM consumers, fixing nvbug 5970614. |
| Test configuration adjustment and waiver removal (tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/waives.txt) | Batch size reduced to 8 with inline documentation of the cuBLASLt workspace headroom issue; the corresponding SKIP waiver is removed so the test now executes. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

  • NVIDIA/TensorRT-LLM#14504: Also modifies tests/integration/test_lists/waives.txt by changing SKIP/waiver entries for integration tests.

Suggested reviewers

  • mikeiovine
  • dongfengy
  • jieli-matrix
  • xinhe-nv
🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title correctly identifies the primary fix: synchronizing the CTA before PDL trigger in quantize_with_block_size, with proper bug reference format. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | PR description comprehensively explains the PDL race condition, root cause, fix mechanism, evidence, and risk assessment with references to existing patterns and companion fixes. |


@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
cpp/tensorrt_llm/kernels/quantization.cuh (1)

2-2: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Update the copyright header year on this modified source file.

This file was modified, but the NVIDIA header still ends at 2023. Please update it to include the latest modification year.

Suggested fix
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2019-2026, NVIDIA CORPORATION.  All rights reserved.

As per coding guidelines: **/*.{cpp,cc,h,hpp,py,cu,cuh}: Include NVIDIA copyright header on ALL new files; update year on modified files.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tensorrt_llm/kernels/quantization.cuh` at line 2, Update the NVIDIA
copyright header at the top of the modified source file by changing the year
range from "2019-2023" to include the current modification year (e.g.
"2019-2026"); edit the top-of-file header comment in quantization.cuh (the
initial copyright comment block) so the year range reflects the latest
modification year.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 04887fc1-1f9e-4ec2-8b75-a6af13d6921c

📥 Commits

Reviewing files that changed from the base of the PR and between 59d4369 and f8d3481.

📒 Files selected for processing (3)
  • cpp/tensorrt_llm/kernels/quantization.cuh
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53276 [ run ] triggered by Bot. Commit: 43e46dc

@tensorrt-cicd (Collaborator)

PR_Github #53276 [ run ] completed with state SUCCESS. Commit: 43e46dc
/LLM/main/L0_MergeRequest_PR pipeline #42466 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53316 [ run ] triggered by Bot. Commit: 8613e8c

@tensorrt-cicd (Collaborator)

PR_Github #53316 [ run ] completed with state SUCCESS. Commit: 8613e8c
/LLM/main/L0_MergeRequest_PR pipeline #42502 completed with status: 'FAILURE'

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53440 [ run ] triggered by Bot. Commit: 40d0115

@tensorrt-cicd (Collaborator)

PR_Github #53440 [ run ] completed with state SUCCESS. Commit: 40d0115
/LLM/main/L0_MergeRequest_PR pipeline #42608 completed with status: 'FAILURE'

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53533 [ run ] triggered by Bot. Commit: fec41f6

@tensorrt-cicd (Collaborator)

PR_Github #53533 [ run ] completed with state SUCCESS. Commit: fec41f6
/LLM/main/L0_MergeRequest_PR pipeline #42686 completed with status: 'FAILURE'

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53601 [ run ] triggered by Bot. Commit: fec41f6

@tensorrt-cicd (Collaborator)

PR_Github #53601 [ run ] completed with state SUCCESS. Commit: fec41f6
/LLM/main/L0_MergeRequest_PR pipeline #42746 completed with status: 'FAILURE'

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53784 [ run ] triggered by Bot. Commit: 13363c1

@tensorrt-cicd (Collaborator)

PR_Github #53784 [ run ] completed with state SUCCESS. Commit: 13363c1
/LLM/main/L0_MergeRequest_PR pipeline #42902 completed with status: 'FAILURE'

@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #53854 [ run ] triggered by Bot. Commit: ed9b3e3

@tensorrt-cicd (Collaborator)

PR_Github #53854 [ run ] completed with state SUCCESS. Commit: ed9b3e3
/LLM/main/L0_MergeRequest_PR pipeline #42962 completed with status: 'FAILURE'

…_with_block_size

Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
…e_with_block_size

Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
@tianyuxbear (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #54115 [ run ] triggered by Bot. Commit: cb88f5a

@tensorrt-cicd (Collaborator)

PR_Github #54115 [ run ] completed with state SUCCESS. Commit: cb88f5a
/LLM/main/L0_MergeRequest_PR pipeline #43198 completed with status: 'FAILURE'
