Simple PyTorch graph capturing.
Please install graphviz first:
$ apt-get install graphviz
Clone and install this package:
$ pip install .
Examples:
from torchgraph import dispatch_capture, aot_capture, compile_capture
import torch
import torch.nn as nn
model = nn.Sequential(
    nn.Conv2d(16, 32, 3),
    nn.BatchNorm2d(32),
    nn.SiLU(),
).cuda()
x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")
# Capture joint forward and backward graph through dispatch
dispatch_capture(model, x)
# Capture separate forward and backward graphs through PyTorch AOTAutograd
aot_capture(model, x)
# Capture forward graph through PyTorch compile
compile_capture(model, x)
You'll find the captured graphs in .svg format in the current folder.
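To confirm that the captures produced output, you can list the .svg files in the working directory. The following is a minimal sketch that uses only the Python standard library. The helper name find_captured_graphs is hypothetical and not part of torchgraph. Because the names of the files torchgraph writes are not documented here, the helper simply matches every *.svg file in the folder.

```python
from pathlib import Path
import shutil


def find_captured_graphs(folder="."):
    """Return the .svg files found directly in `folder`, sorted by name.

    Hypothetical helper, not part of torchgraph. It matches any *.svg file,
    since the exact output file names are an unknown here.
    """
    return sorted(Path(folder).glob("*.svg"), key=lambda p: p.name)


if __name__ == "__main__":
    # Graphviz provides the `dot` executable; warn if it is not on PATH.
    if shutil.which("dot") is None:
        print("Graphviz 'dot' was not found on PATH; rendering may fail.")

    graphs = find_captured_graphs()
    if not graphs:
        print("No .svg files found in the current folder.")
    for path in graphs:
        print(f"{path.name} ({path.stat().st_size} bytes)")
```

Run it from the same folder in which the capture functions were called. An empty listing most likely means a capture failed or wrote its output elsewhere.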