[Thunderfx report] Adds Thunder segmentation and fusion information #1747
Conversation
Exciting! Looking forward to more comments on the classes and functions. Your PR descriptions always do an excellent job of showcasing the work.
Overall this looks really great. Let's keep the focus on the new classes and methods in this PR, and then let's follow up in another PR on the details of the timing information. @kshitij12345 has been reviewing various timers carefully and can probably help, too.
Hi @mruberry, would you mind taking a first look at the draft APIs to see if they are suitable? Also, about the nvFusion report: it's easy to get the nvFusion repro method into a standalone file, but getting the nvFusion subsymbols to run with torch.compile in a standalone script is a bit tricky. We don't have a way to write the symbols into a runnable python function, so I think if we want a script it must start from the python function of the fx graph, something like:
so currently there's …
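(For illustration only, a rough sketch of the kind of standalone script described above, starting from the python function of the fx graph and running it with torch.compile; the function body and inputs here are hypothetical, not taken from this PR.)

import torch

def fx_graph_fn(x):
    # placeholder for the python function recovered from the fx graph
    y = (x * x).sum()
    z = x @ x
    return z.sin() + y.cos()

compiled_fn = torch.compile(fx_graph_fn)
out = compiled_fn(torch.randn(16, 16, device="cuda"))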
Regarding this, I was recently looking for something similar and this is roughly what seemed to work:

import torch
import thunder
from thunder.dev_utils.utils import _benchmark_fusion_region_with_nvfuser_and_torch_compile
from thunder.core.trace import TraceCtx, tracectx
from thunder.core.codeutils import SigInfo
from thunder.dynamo.utils import arg_like
def fn(x):
    y = (x * x).sum()
    z = x @ x
    return z.sin() + y.cos()
jfn = thunder.jit(fn, nv_store_fusion_inputs=True)
jfn(torch.randn(16, 16, device="cuda"))
trc = thunder.last_traces(jfn)[-1]
def create_python_callable_from_bsym(bsym):
    trace = TraceCtx()
    si = SigInfo(bsym.sym.name)
    si.args = [(v.name, None) for v in bsym.flat_args]
    trace._siginfo = si
    trace.siginfo()
    # trace.args = bsym.flat_args
    trace.bound_symbols = list(bsym.subsymbols)
    with tracectx(trace):
        thunder.prims.python_return(bsym.output)
    return trace.python(include_decorators=False)
template_torch_compile = '''
{python_func}
from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
import thunder.examine
# inputs
x = torch.randn(16, 16, device="cuda")
inputs = {inputs}
jfn = thunder.jit({func_name})
o = jfn(*inputs)
trc = thunder.last_traces(jfn)[-1]
fusion_symbols = thunder.examine.get_fusion_symbols(trc)
assert len(fusion_symbols) == 1
bsym = fusion_symbols[0]
# print(bsym) # Verified visually it looks the same
torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
o = torch_compiled_callable(*inputs)
'''
for bsym in trc.bound_symbols:
    if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
        python_func = create_python_callable_from_bsym(bsym)
        with open("torch_compile_repro.py", "w") as f:
            nvfuser_callable = bsym._call_ctx[bsym.sym.name]
            inputs = nvfuser_callable.last_inputs
            inputs = "[" + "".join(arg_like(inp) for inp in inputs) + "]"
            program = template_torch_compile.format(python_func=python_func, func_name=bsym.sym.name, inputs=inputs)
            f.write(program)
        break

Generated Repro:

import thunder
import thunder.core.prims as prims
import torch
def nvFusion0(x):
    # /opt/pytorch/lightning-thunder/test.py:9: y = (x * x).sum()
    t0 = prims.mul(x, x) # t0: "cuda:0 f32[16, 16]"
    y = prims.sum(t0, (0, 1)) # y: "cuda:0 f32[]"
    return [y]
from thunder.executors.torch_compile import make_compiled as make_torch_compile_callable
import thunder.examine
# inputs
x = torch.randn(16, 16, device="cuda")
inputs = [torch.testing.make_tensor((16, 16), dtype=torch.float32, device='cuda:0', requires_grad=False, low=-2.9306769371032715, high=2.9876084327697754,).as_strided(torch.Size([16, 16]), (16, 1)),]
jfn = thunder.jit(nvFusion0)
o = jfn(*inputs)
trc = thunder.last_traces(jfn)[-1]
fusion_symbols = thunder.examine.get_fusion_symbols(trc)
assert len(fusion_symbols) == 1
bsym = fusion_symbols[0]
# print(bsym) # Verified visually it looks the same
torch_compiled_callable = make_torch_compile_callable(bsym.subsymbols, bsym.flat_args, bsym.flat_outs)
o = torch_compiled_callable(*inputs)

The repro takes a roundabout way to create a torch.compile callable for an nvFusion, but it seems to work. Maybe this could help come up with a cleaner approach.
Starting with @kshitij12345's technique and throwing an exception if a problem is encountered sounds like a great start. FYI @IvanYashchuk. I'm sure over time we can make the approach @kshitij12345 recommends more modular. It seems like it has the right concepts.
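A minimal sketch of what "throwing an exception if a problem is encountered" could look like wrapped around the loop above. The names write_torch_compile_repro and ReproGenerationError are hypothetical; the body reuses create_python_callable_from_bsym, template_torch_compile, and arg_like from the earlier snippet.

class ReproGenerationError(RuntimeError):
    pass

def write_torch_compile_repro(trc, path="torch_compile_repro.py"):
    # Raises instead of silently producing no repro file.
    try:
        for bsym in trc.bound_symbols:
            if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
                python_func = create_python_callable_from_bsym(bsym)
                nvfuser_callable = bsym._call_ctx[bsym.sym.name]
                inputs = nvfuser_callable.last_inputs
                inputs = "[" + "".join(arg_like(inp) for inp in inputs) + "]"
                program = template_torch_compile.format(
                    python_func=python_func, func_name=bsym.sym.name, inputs=inputs
                )
                with open(path, "w") as f:
                    f.write(program)
                return path
        raise ReproGenerationError("no nvFusion region found in the trace")
    except ReproGenerationError:
        raise
    except Exception as e:
        raise ReproGenerationError(f"failed to generate torch.compile repro: {e}") from e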
Overall this looks good to me. I just have a few suggestions and questions regarding naming. Thanks @kiya00!
Hi @mruberry @IvanYashchuk @kshitij12345, it's ready to review.
Cool! This is a great step forward in that it represents the remaining objects we're currently interested in: segments and fusions. Fantastic.
I wrote a comment about what I think the next step is — improving the modularity and quality of our benchmark numbers. @kshitij12345 should be able to help there.
Once we have the timings working for fx graphs, segments, and fusions as expected, then I think we'll just need a few refinements to begin using this tool in our benchmark runs. In particular, we should review how correctness issues are reported, how we can identify significant performance deviations and only generate reports when they're found, and how we can summarize split reasons.
How does that sound, @kiya00?
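As one possible shape for "only generate reports when significant performance deviations are found", a small hypothetical helper; the name, argument names, and 10% threshold are placeholders, not part of this PR.

def should_generate_report(thunder_time_ms: float, baseline_time_ms: float, threshold: float = 0.10) -> bool:
    # Flag a segment/fusion when Thunder is more than `threshold` slower than the baseline timing.
    if baseline_time_ms <= 0:
        return True
    relative_slowdown = (thunder_time_ms - baseline_time_ms) / baseline_time_ms
    return relative_slowdown > threshold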
Conflicts will need to be resolved before this can be merged.
Hi @mruberry, thanks for the suggestions. The code in the constructor of …
Before submitting
What does this PR do?
Fixes #1741.
Adds ThunderFXGraphReport, ThunderSplitGraphReport, ThunderFusionReport, and analyze_with_thunder to support querying Thunder split graph and nvFusion information.

An example (a rough sketch of such usage follows):
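For orientation, a sketch of what such an example might look like. The ThunderCompiler entry point and torch._dynamo.graph_break exist today; the reporting call itself is only sketched in a comment because its exact signature is defined by this PR.

import torch
from thunder.dynamo import ThunderCompiler

def foo(x):
    y = torch.sin(x)
    torch._dynamo.graph_break()  # force a graph break so there is more than one fx graph
    return torch.cos(y).sum()

cfoo = torch.compile(foo, backend=ThunderCompiler())
cfoo(torch.randn(16, 16, device="cuda"))

# Hypothetical reporting step (signature assumed; see the PR for the real API):
# reports = analyze_with_thunder(...)  # yields ThunderFXGraphReport / ThunderSplitGraphReport / ThunderFusionReport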
generates files:
The nvFusion repro script is taken from nvFuser's repro_script_for.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃