
[Thunderfx report] Adds Thunder segmentation and fusion information #1747


Merged
merged 10 commits into main from thunderfx_report3 on Feb 7, 2025

Conversation

kiya00
Collaborator

@kiya00 kiya00 commented Feb 5, 2025

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1741.

Adds ThunderFXGraphReport, ThunderSplitGraphReport, ThunderFusionReport, and analyze_thunder_splits to support querying Thunder split-graph and nvFusion information.
An example:

from pathlib import Path

import torch

from thunder.dynamo.report import (
    fx_report, FXReport, ThunderFXGraphReport, FXGraphReport,
    ThunderSplitGraphReport, ThunderFusionReport, analyze_thunder_splits,
)

x = torch.ones(2, 2, device="cuda", requires_grad=True)

# Dynamo segments `foo` into two graphs. Each graph contains one Thunder-split graph, 
# and each Thunder-split graph has one nvFusion region.
def foo(x):
    x = x.exp()
    torch._dynamo.graph_break()
    y = torch.sinc(x) + torch.cos(x)
    return y + 1

# If using `torch.compile` alone, you can stop here and query the reports in `FXReport`. 
# For more details, see the example in :func:`fx_report`.
results: FXReport = fx_report(foo, x)

tmp_path = Path("tmp_repro")
fx_graph_report: FXGraphReport
for idx, fx_graph_report in enumerate(results.fx_graph_reports):
    # `ThunderFXGraphReport` extends `FXGraphReport`, providing the ability to save 
    # reproduction/benchmark scripts for the original FX graph. Additionally, it 
    # includes information about Thunder-split subgraphs in `subgraph_reports`.
    thunder_fx_graph_report: ThunderFXGraphReport = analyze_thunder_splits(fx_graph_report)
    # Saves a reproduction script for the original FX graph
    thunder_fx_graph_report.write_thunder_repro(tmp_path)

    thunder_split_report: ThunderSplitGraphReport
    for thunder_split_report in thunder_fx_graph_report.subgraph_reports:
        split_folder = tmp_path / str(idx)
        thunder_split_report.write_eager_repro(split_folder)
        thunder_split_report.write_thunder_repro(split_folder)
        thunder_split_report.write_inductor_repro(split_folder)

        # If you are only interested in the Thunder-split FX graph, you can stop here. 
        # If you want to inspect Thunder traces and nvFusion regions further, explicitly call 
        # `ThunderSplitGraphReport.create_fusion_reports()` and analyze as shown below.
        thunder_split_report.create_fusion_reports()
        print(f"fwd_trace:\n{thunder_split_report.fwd_trc}\n")
        print(f"bwd_trace:\n{thunder_split_report.bwd_trc}\n")
        nvfusion_report: ThunderFusionReport
        for nvfusion_report in thunder_split_report.fusion_reports:
            nvfusion_report.write_nvfuser_repro(split_folder / "nvfusion")
            nvfusion_report.write_inductor_repro(split_folder / "nvfusion")
            bench_data = nvfusion_report.run_benchmark()
            print(bench_data)
            print("---"*10)

This generates the following files:

.
├── 0
│   ├── graph0_thunder_0_repro_eager.py
│   ├── graph0_thunder_0_repro_thunder.py
│   ├── graph0_thunder_0_repro_torchcompile.py
│   └── nvfusion
│       ├── graph0_thunder_0_nvFusion0_backward_repro_inductor.py
│       ├── graph0_thunder_0_nvFusion0_backward_repro_nvfuser.py
│       ├── graph0_thunder_0_nvFusion0_forward_repro_inductor.py
│       └── graph0_thunder_0_nvFusion0_forward_repro_nvfuser.py
├── 1
│   ├── graph1_thunder_1_repro_eager.py
│   ├── graph1_thunder_1_repro_thunder.py
│   ├── graph1_thunder_1_repro_torchcompile.py
│   └── nvfusion
│       ├── graph1_thunder_1_nvFusion0_backward_repro_inductor.py
│       ├── graph1_thunder_1_nvFusion0_backward_repro_nvfuser.py
│       ├── graph1_thunder_1_nvFusion0_forward_repro_inductor.py
│       └── graph1_thunder_1_nvFusion0_forward_repro_nvfuser.py
├── graph0_repro_thunder.py
└── graph1_repro_thunder.py

The nvFusion repro scripts are generated with nvFuser's repro_script_for.
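For context, scripts produced by repro_script_for follow the standard nvFuser Python-frontend layout; a rough, hand-written sketch of that shape is shown below (illustrative only; the actual ops, shapes, dtypes, and inputs depend on the fusion and the nvFuser version):

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    # fused region, e.g. the exp() from graph0 in the example above
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.ops.exp(T0)
    fd.add_output(T1)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [torch.ones(2, 2, device="cuda")]
fd.execute(inputs)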

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@kiya00 kiya00 requested a review from IvanYashchuk February 5, 2025 17:02
Collaborator

@mruberry mruberry left a comment


Exciting! Looking forward to more comments on the classes and functions. Your PR descriptions always do an excellent job of showcasing the work.

Overall this looks really great. Let's keep the focus on the new classes and methods in this PR, and then let's follow-up in another PR on the details of the timing information. @kshitij12345 has been reviewing various timers carefully and can probably help, too

@mruberry mruberry requested a review from kshitij12345 February 5, 2025 17:09
@kiya00
Collaborator Author

kiya00 commented Feb 5, 2025

Hi @mruberry, would you mind taking a first look at the draft APIs to see if they are suitable? Also, about the nvFusion report: it's easy to get the nvFusion repro into a standalone file, but getting the nvFusion subsymbols to run with torch.compile in a standalone script is a bit tricky, because we don't have a way to write the symbols into a runnable Python function. I think if we want a script it must start from the Python function of the FX graph, something like:

def gm_function():
    ...

trc = thunder.jit(gm_function)
# for each nvFusion region
for nvfusion_bsym in trc:
    callable_for_torchcompile = thunder.executors.torch_compile.make_compiled(nvfusion_bsym.subsymbols, ...)
    # run each nvFusion region using torch.compile
    callable_for_torchcompile(*inputs)

So currently there's ThunderFusionReport.run_inductor_repro to run it without saving a file, and likewise ThunderFusionReport.run_benchmark, which just calls _benchmark_fusion_region_with_nvfuser_and_torch_compile and prints out the timing data.
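For reference, a minimal usage sketch of those two methods, assuming a ThunderFusionReport obtained as in the PR example above:

# nvfusion_report is a ThunderFusionReport, e.g. taken from thunder_split_report.fusion_reports
nvfusion_report.run_inductor_repro()          # runs the fusion region via torch.compile without writing a file
bench_data = nvfusion_report.run_benchmark()  # wraps _benchmark_fusion_region_with_nvfuser_and_torch_compile
print(bench_data)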

@kshitij12345
Collaborator

but getting the nvfusion subsymbols to run with torch.compile in a standalone script is a bit tricky

Regarding this, recently I was looking for something similar and this is roughly what seemed to work.

import torch
import thunder
from thunder.dev_utils.utils import _benchmark_fusion_region_with_nvfuser_and_torch_compile
from thunder.core.trace import TraceCtx, tracectx
from thunder.core.codeutils import SigInfo
from thunder.dynamo.utils import arg_like

def fn(x):
    y = (x * x).sum()
    z = x @ x
    return z.sin() + y.cos()

jfn = thunder.jit(fn, nv_store_fusion_inputs=True)

jfn(torch.randn(16, 16, device="cuda"))

trc = thunder.last_traces(jfn)[-1]

def create_python_callable_from_bsym(bsym):
    trace = TraceCtx()
    si = SigInfo(bsym.sym.name)
    si.args = [(v.name, None) for v in bsym.flat_args]
    trace._siginfo = si
    trace.siginfo()
    # trace.args = bsym.flat_args
    trace.bound_symbols = list(bsym.subsymbols)

    with tracectx(trace):
        thunder.prims.python_return(bsym.output)

    return trace.python(include_decorators=False)

template_torch_compile = '''
{python_func}

from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
import thunder.examine

# inputs
x = torch.randn(16, 16, device="cuda")
inputs = {inputs}


jfn = thunder.jit({func_name})
o = jfn(*inputs)

trc = thunder.last_traces(jfn)[-1]
fusion_symbols = thunder.examine.get_fusion_symbols(trc)
assert len(fusion_symbols) == 1
bsym = fusion_symbols[0]

# print(bsym)  # Verified visually it looks the same
torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
o = torch_compiled_callable(*inputs)
'''

for bsym in trc.bound_symbols:
    if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
        python_func = create_python_callable_from_bsym(bsym)
        with open("torch_compile_repro.py", "w") as f:
            nvfuser_callable = bsym._call_ctx[bsym.sym.name]
            inputs = nvfuser_callable.last_inputs
            inputs = "[" + "".join(arg_like(inp) for inp in inputs) + "]"
            program = template_torch_compile.format(python_func=python_func, func_name=bsym.sym.name, inputs=inputs)
            f.write(program)
        break

Generated Repro

import thunder
import thunder.core.prims as prims
import torch

def nvFusion0(x):
  # /opt/pytorch/lightning-thunder/test.py:9: 	    y = (x * x).sum()
  t0 = prims.mul(x, x)  # t0: "cuda:0 f32[16, 16]"
  y = prims.sum(t0, (0, 1))  # y: "cuda:0 f32[]"
  return [y]

from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
import thunder.examine

# inputs
x = torch.randn(16, 16, device="cuda")
inputs = [torch.testing.make_tensor((16, 16), dtype=torch.float32,  device='cuda:0', requires_grad=False, low=-2.9306769371032715, high=2.9876084327697754,).as_strided(torch.Size([16, 16]), (16, 1)),]


jfn = thunder.jit(nvFusion0)
o = jfn(*inputs)

trc = thunder.last_traces(jfn)[-1]
fusion_symbols = thunder.examine.get_fusion_symbols(trc)
assert len(fusion_symbols) == 1
bsym = fusion_symbols[0]

# print(bsym)  # Verified visually it looks the same
torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
o = torch_compiled_callable(*inputs)

The repro takes a roundabout way to create a torch.compile callable for a nvFusion but it seems to work. Maybe this could help come up with a cleaner approach.
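For comparison, a more direct in-process variant of the same idea, sketched under the assumptions of the snippet above (same trc, nv_store_fusion_inputs=True); this is not what the PR ships, just a simplification of the roundabout path:

from thunder.executors.torch_compile import make_compiled

for bsym in trc.bound_symbols:
    if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
        nvfuser_callable = bsym._call_ctx[bsym.sym.name]
        inputs = nvfuser_callable.last_inputs                  # available because of nv_store_fusion_inputs=True
        torch_compiled_callable = make_compiled(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
        out = torch_compiled_callable(*inputs)                 # run the nvFusion region via torch.compile
        break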

@mruberry
Collaborator

mruberry commented Feb 5, 2025

Starting with @kshitij12345's technique and throwing an exception if a problem is encountered sounds like a great start. FYI @IvanYashchuk. I'm sure over time we can make the approach @kshitij12345 recommends more modular. It seems like it has the right concepts.

Collaborator

@kshitij12345 kshitij12345 left a comment


Overall this looks good to me. I just have a few suggestions and questions regarding naming. Thanks @kiya00

@kiya00 kiya00 marked this pull request as ready for review February 7, 2025 14:53
@kiya00 kiya00 requested review from lantiga and t-vi as code owners February 7, 2025 14:53
@kiya00
Collaborator Author

kiya00 commented Feb 7, 2025

Hi @mruberry @IvanYashchuk @kshitij12345, it's ready for review.
In this update I use the snippet from @kshitij12345 to save the inductor repro script for nvFusion bsyms, and integrate _benchmark_fusion_region_with_nvfuser_and_torch_compile for benchmarking nvFusion.

Collaborator

@mruberry mruberry left a comment


Cool! This is a great step forward in that it represents the remaining objects we're currently interested in: segments and fusions. Fantastic.

I wrote a comment about what I think the next step is — improving the modularity and quality of our benchmark numbers. @kshitij12345 should be able to help there.

Once we have the timings working for fx graphs, segments, and fusions as expected, then I think we'll just need a few refinements to begin using this tool in our benchmark runs. In particular, we should review how correctness issues are reported, how we can identify significant performance deviations and only generate reports when they're found, and how we can summarize split reasons.

How does that sound, @kiya00?
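As a self-contained illustration of the "only generate reports for significant performance deviations" idea, here is a minimal sketch; the helper and the 10% threshold are hypothetical and not part of this PR:

def significant_deviation(nvfuser_ms: float, inductor_ms: float, threshold: float = 0.10) -> bool:
    # True when the two timings differ by more than `threshold` relative to the faster one
    faster = min(nvfuser_ms, inductor_ms)
    return abs(nvfuser_ms - inductor_ms) / faster > threshold

print(significant_deviation(1.23, 1.51))  # True: ~23% apart, worth writing a report
print(significant_deviation(1.00, 1.05))  # False: within 10%, skip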

@mruberry
Collaborator

mruberry commented Feb 7, 2025

Conflicts will need to be resolved before this can be merged

@kiya00 kiya00 force-pushed the thunderfx_report3 branch from c901a3e to 5504f38 on February 7, 2025 17:23
@kiya00
Collaborator Author

kiya00 commented Feb 7, 2025

Hi @mruberry, thanks for the suggestions. The code in the constructor of ThunderFXGraphReport has been moved into the factory function analyze_thunder_splits.
As for the timer, I'll think about it, ask @kshitij12345 on Monday, and get back to you.
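A minimal sketch of the factory-function pattern described here, with simplified, hypothetical stand-ins (the real classes live in thunder.dynamo.report): the splitting analysis runs in the factory rather than in the report's constructor.

from dataclasses import dataclass, field

@dataclass
class FXGraphReportSketch:                              # hypothetical stand-in for FXGraphReport
    graph_name: str

@dataclass
class ThunderFXGraphReportSketch(FXGraphReportSketch):  # hypothetical stand-in for ThunderFXGraphReport
    subgraph_reports: list = field(default_factory=list)

def analyze_thunder_splits_sketch(report: FXGraphReportSketch) -> ThunderFXGraphReportSketch:
    # The splitting/analysis work happens here, not in __init__, mirroring the refactor above.
    subgraph_reports = []                               # placeholder: would hold per-split reports
    return ThunderFXGraphReportSketch(report.graph_name, subgraph_reports=subgraph_reports)

print(analyze_thunder_splits_sketch(FXGraphReportSketch("graph0")))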

@kiya00 kiya00 merged commit bb61300 into main Feb 7, 2025
49 checks passed
@kiya00 kiya00 deleted the thunderfx_report3 branch February 7, 2025 17:53
Successfully merging this pull request may close these issues.

[Reporting tool] Adds the Thunder segmentation and fusion information to the FXReport