Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/arm/test/ops/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,6 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
"case",
test_cases,
xfails={
"zero_args_one_output": "Since the submodules have no input, the tracer fails finding a fake tensor mode,"
" and traces the graph with real tensors, which tosa.RESCALE can't handle.",
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
},
Expand Down
8 changes: 4 additions & 4 deletions exir/pass_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -593,14 +594,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:
), "Multiple fake tensor mode detected."
fake_tensor_mode = i.fake_mode
if fake_tensor_mode is None:
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
fake_tensor_mode = nullcontext() # type: ignore[assignment]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@angelayi @tugsbayasgalan what is the difference between None and nullcontext() for fake tensor mode

fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
dispatcher_mode = nullcontext() # type: ignore[assignment]
else:
fake_tensor_mode.allow_non_fake_inputs = True
self.tracer.fake_tensor_mode = fake_tensor_mode
dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
self.fake_tensor_mode = self.tracer.fake_tensor_mode
self.tracer.fake_tensor_mode = fake_tensor_mode
self.fake_tensor_mode = fake_tensor_mode

with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
result = self.call_submodule(graph_module, tuple(inputs))
Expand Down
Loading