@@ -5511,11 +5511,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dim0_x = torch.export.Dim("dim0_x", min=3)
         dim1_x = torch.export.Dim("dim1_x", max=8000)
         dynamic_shapes = {"x": (dim0_x, dim1_x)}
-        em = torch.export._trace._export(
+        em = torch.export.export(
             m,
             (a,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         em.module()(torch.randn(4, 3))
         with self.assertRaisesRegex(
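Note: the change applied throughout this commit swaps the private torch.export._trace._export entry point for the public torch.export.export, and renames the flag allow_complex_guards_as_runtime_asserts to prefer_deferred_runtime_asserts_over_guards. A minimal self-contained sketch of the new call, reusing the dims from the hunk above (the module body is a stand-in, since the test's real module is defined outside this hunk):

    import torch

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2  # stand-in body

    dim0_x = torch.export.Dim("dim0_x", min=3)
    dim1_x = torch.export.Dim("dim1_x", max=8000)
    em = torch.export.export(
        M(),
        (torch.randn(4, 3),),
        dynamic_shapes={"x": (dim0_x, dim1_x)},
        # complex shape guards become runtime assert nodes in the graph
        # instead of trace-time specializations
        prefer_deferred_runtime_asserts_over_guards=True,
    )
    em.module()(torch.randn(4, 3))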
@@ -13399,7 +13399,7 @@ def forward(self, x):
 
     def test_disable_forced_specializations_ok(self):
         # check that we don't force specialization, and defer to runtime asserts
-        # with allow_complex_guards_as_runtime_asserts=True to successfully export
+        # with prefer_deferred_runtime_asserts_over_guards=True to successfully export
         # case 1: modulo guards
         from torch.export import dims
 
@@ -13409,11 +13409,11 @@ def forward(self, x):
 
         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(8, 7))
         self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
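The asserted (7, 4, 2) output for an (8, 7) input is consistent with a reshape along the lines of the sketch below; treat the body as an inference, since the Mod4Reshape class is defined outside this hunk. The point of the test is that the resulting divisibility guard is deferred to a runtime assert instead of forcing a specialization:

    class Mod4Reshape(torch.nn.Module):
        def forward(self, x):
            # (8, 7) -> (7, 4, 2): valid only when x.numel() is divisible
            # by 4 * (x.shape[0] - 1), a guard too complex to keep at trace
            # time without specializing, hence the deferred runtime assert
            return x.reshape(x.shape[0] - 1, 4, -1)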
@@ -13443,11 +13443,11 @@ def forward(self, x, y, z):
 
         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     FreeReshape(),
                     inputs,
                     dynamic_shapes=dynamic_shapes,
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13484,11 +13484,11 @@ def forward(self, x, y):
             "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
             "y": (Dim("dy", min=8),),
         }
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Reshape3d(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
         self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13610,11 +13610,11 @@ def forward(self, x):
         model = Model()
         x = torch.rand(1024, 20, 16)
         dynamic_shapes = {"x": {0: Dim("batch")}}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             model,
             (x,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         with self.assertRaisesRegex(
             RuntimeError,
@@ -13687,11 +13687,11 @@ def forward(self, x, y):
 
         inputs = (torch.randn(6), torch.randn(12))
         dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # check forward pass
         out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13726,7 +13726,7 @@ def forward(self, x, y):
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         ).run_decompositions()
 
         self.assertEqual(
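Here run_decompositions() returns a new ExportedProgram lowered toward the core ATen opset, and the truncated assertEqual above inspects that decomposed graph. A generic inspection idiom (a sketch, not the test's exact check):

    # the decomposed program still exposes an FX graph to walk
    targets = [node.target for node in ep.graph.nodes]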
@@ -14138,11 +14138,11 @@ def forward(self, x, y):
 
         inputs = (torch.randn(5), torch.randn(3))
         shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # count 2 pow nodes, 2 sym_size.int nodes
         self.assertEqual(
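A sketch of the counting idiom the truncated assertEqual presumably applies; the sym_size overload is known, while the exact pow overload counted by the test is not shown in this hunk:

    targets = [node.target for node in ep.graph.nodes]
    # symbolic dimension reads appear as aten.sym_size.int call targets
    num_sym_size = targets.count(torch.ops.aten.sym_size.int)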
@@ -14941,11 +14941,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     ModConstraint(),
                     (torch.randn(3, 4),),
                     dynamic_shapes={"x": (dynamic, dynamic)},
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(
@@ -14959,7 +14959,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 for node in ep.graph.nodes
             ].count(True)
             if private_api:
-                self.assertEqual(num_asserts, 7)
+                self.assertEqual(num_asserts, 6)
                 with self.assertRaisesRegex(
                     RuntimeError,
                     r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",