@@ -581,7 +581,7 @@ def forward(self, a, b, c):
581581 m = torch.jit.load(buffer)
582582 new_res = m(a, b, c)
583583 FileCheck().check_not("aten::relu(") \
584- .check("aten::add_relu (") \
584+ .check("aten::_add_relu (") \
585585 .run(m.graph)
586586 torch.testing.assert_allclose(orig_res, new_res)
587587
@@ -600,7 +600,7 @@ def forward(self, a, b, c):
600600 m = torch.jit.load(buffer)
601601 new_res = m(a, b, c)
602602 FileCheck().check_not("aten::relu_(") \
603- .check("aten::add_relu (") \
603+ .check("aten::_add_relu (") \
604604 .run(m.graph)
605605 torch.testing.assert_allclose(orig_res, new_res)
606606
@@ -631,10 +631,10 @@ def forward(self, a, b):
631631 new_res = m(a_copy, b)
632632 FileCheck().check_not("aten::add_(") \
633633 .check_not("aten::relu_(") \
634- .check("aten::add_relu_ (") \
634+ .check("aten::_add_relu_ (") \
635635 .run(m.graph)
636636 torch.testing.assert_allclose(orig_res, new_res)
637- # Since add_relu_ does inplace mutation ensure
637+ # Since _add_relu_ does inplace mutation ensure
638638 # a_copy is modified
639639 torch.testing.assert_allclose(orig_res, a_copy)
640640
@@ -669,10 +669,10 @@ def forward(self, a, b):
669669 new_res = m(a_copy, b)
670670 FileCheck().check_not("aten::add(") \
671671 .check_not("aten::relu_(") \
672- .check("aten::add_relu (") \
672+ .check("aten::_add_relu (") \
673673 .run(m.graph)
674674 torch.testing.assert_allclose(orig_res, new_res)
675- # Since add_relu_ with out=a does inplace mutation ensure
675+ # Since _add_relu_ with out=a does inplace mutation ensure
676676 # a_copy is modified
677677 torch.testing.assert_allclose(orig_res, a_copy)
678678
0 commit comments