diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py
index 3ca6d09a52..418407ae1e 100644
--- a/intermediate_source/torch_export_tutorial.py
+++ b/intermediate_source/torch_export_tutorial.py
@@ -489,6 +489,7 @@ def forward(self, w, x, y, z):
 # specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens
 # at runtime when we export this linear layer:
 
+torch._logging.set_logs(dynamic=0)
 ep = export(
     torch.nn.Linear(4, 3),
     (torch.randn(1, 4),),
@@ -591,6 +592,30 @@ def forward(self, x, y):
     "bool_val": None,
 }
 
+######################################################################
+# (experimental) Avoiding 0/1 specialization
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Export provides an experimental option to avoid specializing on size 0/1 sample inputs. Users can set `torch.fx.experimental._config.backed_size_oblivious = True` to enable this behavior.
+# This allows the compiler to allocate a [0, inf] range for symbols, and to assume general-case semantics wherever compiler decisions differ between size 0/1 and sizes >= 2.
+# This can lead to behavior that diverges between eager mode and the exported program on size 0/1 inputs - for example, in broadcasting decisions, we assume input shapes are not specialized to 1,
+# and therefore that broadcasting does not apply (even if it does apply for the particular sample inputs). The same logic applies to other semantics (e.g. contiguity) and to size-0 tensors.
+#
+# The exact semantics under this flag are a work in progress, and usage is recommended only when the user is certain their model does not rely on 0/1-specialized semantics.
+# For now, export users can enable this with:
+
+class Foo(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y  # nothing special about size 0/1 here
+
+x = torch.randn(0, 1)
+y = torch.randn(1)
+dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO), "y": (Dim.AUTO,)}
+with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+    ep = export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
+ep.module()(torch.randn(8, 1), torch.randn(1))
+ep.module()(torch.randn(5, 6), torch.randn(6))
+
 ######################################################################
 # Data-dependent errors
 # ---------------------
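
For contrast with the example added in the diff, here is a minimal sketch (not part of the diff) of the default behavior the new section is working around: without `backed_size_oblivious`, export specializes on the size 0/1 sample inputs, so the same `Foo` should reject the shapes the oblivious program accepts. The `try`/`except` wrapper and the printed message are illustrative.

```python
import torch
from torch.export import export, Dim

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x + y

x = torch.randn(0, 1)
y = torch.randn(1)
dynamic_shapes = {"x": (Dim.AUTO, Dim.AUTO), "y": (Dim.AUTO,)}

# Default behavior: the 0/1 sample sizes are hardcoded into the program,
# so the exported module guards on them at call time.
ep = export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
try:
    ep.module()(torch.randn(8, 1), torch.randn(1))  # 8 != specialized size 0
except Exception as e:  # broad catch: the exact error type is a guard failure
    print(f"rejected by the 0/1-specialized program: {e}")
```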