
Commit ae4915d

avikchaudhuri authored and facebook-github-bot committed
kill allow_complex_guards_as_runtime_asserts (#161794)
Summary: [reland] Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was an export-only concept).

Test Plan: updated tests

Rollback Plan:

Reviewed By: zhxchen17

Differential Revision: D81334984
1 parent d647185 commit ae4915d
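
The user-visible effect of the change: callers that previously reached the flag through the private `torch.export._trace._export(..., allow_complex_guards_as_runtime_asserts=True)` now pass the surviving flag to the public entry point, as the updated tests below do. A minimal sketch of the new spelling (the toy module and shapes are illustrative, not taken from this diff):

    import torch
    from torch.export import Dim

    class M(torch.nn.Module):  # toy module, for illustration only
        def forward(self, x):
            return x + 1

    ep = torch.export.export(
        M(),
        (torch.randn(4, 3),),
        dynamic_shapes={"x": (Dim("dx"), Dim("dy"))},
        # the flag that remains; allow_complex_guards_as_runtime_asserts is gone
        prefer_deferred_runtime_asserts_over_guards=True,
    )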

10 files changed: +48 −68 lines


test/dynamo/test_activation_checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def runtime_wrapper(*runtime_args):
         dynamic_shapes=None,
         preserve_module_call_signature=(),
         restore_fqn=False,
-        allow_complex_guards_as_runtime_asserts=False,
+        prefer_deferred_runtime_asserts_over_guards=False,
         _log_export_usage=False,
     )
     # NOTE: this is necessary for rng to be added to the exported graph

test/dynamo/test_misc.py

Lines changed: 2 additions & 2 deletions
@@ -10903,8 +10903,8 @@ def test_shape_env_equal_constructor(self):
 ShapeEnv not equal: field values don't match:

 ==> settings: values don't match.
-  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
-  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
+  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
+  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
 """,
 )
 self._replay_and_check(main)

test/export/test_export.py

Lines changed: 19 additions & 19 deletions
@@ -5511,11 +5511,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dim0_x = torch.export.Dim("dim0_x", min=3)
         dim1_x = torch.export.Dim("dim1_x", max=8000)
         dynamic_shapes = {"x": (dim0_x, dim1_x)}
-        em = torch.export._trace._export(
+        em = torch.export.export(
             m,
             (a,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         em.module()(torch.randn(4, 3))
         with self.assertRaisesRegex(
@@ -13399,7 +13399,7 @@ def forward(self, x):

     def test_disable_forced_specializations_ok(self):
         # check that we don't force specialization, and defer to runtime asserts
-        # with allow_complex_guards_as_runtime_asserts=True to successfully export
+        # with prefer_deferred_runtime_asserts_over_guards=True to successfully export
         # case 1: modulo guards
         from torch.export import dims
@@ -13409,11 +13409,11 @@ def forward(self, x):

         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(8, 7))
         self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@@ -13443,11 +13443,11 @@ def forward(self, x, y, z):

         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     FreeReshape(),
                     inputs,
                     dynamic_shapes=dynamic_shapes,
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13484,11 +13484,11 @@ def forward(self, x, y):
             "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
             "y": (Dim("dy", min=8),),
         }
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Reshape3d(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
         self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13610,11 +13610,11 @@ def forward(self, x):
         model = Model()
         x = torch.rand(1024, 20, 16)
         dynamic_shapes = {"x": {0: Dim("batch")}}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             model,
             (x,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         with self.assertRaisesRegex(
             RuntimeError,
@@ -13687,11 +13687,11 @@ def forward(self, x, y):

         inputs = (torch.randn(6), torch.randn(12))
         dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # check forward pass
         out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13726,7 +13726,7 @@ def forward(self, x, y):
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         ).run_decompositions()

         self.assertEqual(
@@ -14138,11 +14138,11 @@ def forward(self, x, y):

         inputs = (torch.randn(5), torch.randn(3))
         shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # count 2 pow nodes, 2 sym_size.int nodes
         self.assertEqual(
@@ -14941,11 +14941,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     ModConstraint(),
                     (torch.randn(3, 4),),
                     dynamic_shapes={"x": (dynamic, dynamic)},
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(
@@ -14959,7 +14959,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             for node in ep.graph.nodes
         ].count(True)
         if private_api:
-            self.assertEqual(num_asserts, 7)
+            self.assertEqual(num_asserts, 6)
         with self.assertRaisesRegex(
             RuntimeError,
             r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",

torch/_dynamo/config.py

Lines changed: 0 additions & 6 deletions
@@ -258,12 +258,6 @@
 # hybrid backed unbacked symints
 prefer_deferred_runtime_asserts_over_guards = False

-# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
-# range constraints + dims + derived dims language, we raise constraint violation
-# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
-# and allows complex guards as runtime assertions in the graph.
-allow_complex_guards_as_runtime_asserts = False
-
 # By default, dynamo will treat all ints as backed SymInts, which means (1) it
 # will wait to see the int change over multiple runs before generalizing and
 # (2) it will still always 0/1 specialize an int. When true, this knob
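
Downstream code that toggled the deleted knob needs a one-line migration. A sketch, assuming the caller was setting the flag directly on the dynamo config module:

    import torch._dynamo.config as dynamo_config

    # before this commit:
    # dynamo_config.allow_complex_guards_as_runtime_asserts = True

    # after: the surviving, equivalent knob
    dynamo_config.prefer_deferred_runtime_asserts_over_guards = True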

torch/_dynamo/eval_frame.py

Lines changed: 0 additions & 2 deletions
@@ -1734,7 +1734,6 @@ def export(
     same_signature: bool = True,
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
     _log_export_usage: bool = True,
     constraints: Optional[list[Constraint]] = None,
     **extra_kwargs: Any,
@@ -1961,7 +1960,6 @@ def fakify_with_ambient(
         capture_dynamic_output_shape_ops=True,
         capture_scalar_outputs=True,
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
     ),
     _compiling_state_context(),
 ):
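
For dynamo's export API the kwarg simply disappears from the signature; deferred-assert behavior is requested through the flag that remains. A sketch of the call shape with a toy function (assumed here, not from the diff):

    import torch

    def f(x):
        return x.reshape(-1) + 1

    # torch._dynamo.export(...) returns a callable; invoking it performs the trace
    gm, guards = torch._dynamo.export(
        f,
        prefer_deferred_runtime_asserts_over_guards=True,
    )(torch.randn(3, 4))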

torch/_dynamo/output_graph.py

Lines changed: 0 additions & 1 deletion
@@ -468,7 +468,6 @@ def __init__(
     allow_scalar_outputs=config.capture_scalar_outputs,
     allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
     prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
-    allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
     co_fields=self.co_fields,
 )


torch/_export/non_strict_utils.py

Lines changed: 2 additions & 3 deletions
@@ -330,7 +330,7 @@ def make_fake_inputs(
     args,
     kwargs,
     dynamic_shapes,
-    allow_complex_guards_as_runtime_asserts=False,
+    prefer_deferred_runtime_asserts_over_guards=False,
 ):
     """
     Given an nn module, example inputs, and constraints, return a new fake mode,
@@ -382,8 +382,7 @@ def make_fake_inputs(
         shape_env=ShapeEnv(
             tracked_fakes=[],
             co_fields=co_fields,
-            prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
             trace_asserts=True,
         ),
         allow_non_fake_inputs=True,
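
With the settings collapsed to a single knob, constructing a ShapeEnv the way make_fake_inputs now does looks like this (a minimal sketch; ShapeEnv accepts more settings than shown):

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv(
        tracked_fakes=[],
        prefer_deferred_runtime_asserts_over_guards=True,
        trace_asserts=True,
    )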

torch/export/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ def export_for_training(
         dynamic_shapes,
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
-        allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
     )


@@ -282,7 +282,7 @@ def export(
             strict=strict,
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
-            allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         )
     except Exception as e:
         draft_export_msg = (

torch/export/_trace.py

Lines changed: 12 additions & 15 deletions
@@ -750,7 +750,7 @@ def _export_to_torch_ir(
     *,
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,
@@ -810,10 +810,7 @@ def _export_to_torch_ir(
         assume_static_by_default=True,
         tracing_mode="symbolic",
         disable_constraint_solver=disable_constraint_solver,
-        # currently the following 2 flags are tied together for export purposes,
-        # but untangle for sake of dynamo export api
-        prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=_log_export_usage,
         same_signature=same_signature,
     )(
@@ -1402,7 +1399,7 @@ def _strict_export(
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
     preserve_module_call_signature: tuple[str, ...],
     orig_in_spec: TreeSpec,
-    allow_complex_guards_as_runtime_asserts: bool,
+    prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
 ) -> ExportArtifact:
     """
@@ -1416,7 +1413,7 @@ def _strict_export(
         dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         restore_fqn=False,  # don't need to restore because we will do it later
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
     )

@@ -1859,7 +1856,7 @@ def _non_strict_export(
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
     preserve_module_call_signature: tuple[str, ...],
     orig_in_spec: TreeSpec,
-    allow_complex_guards_as_runtime_asserts: bool,
+    prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
 ) -> ExportArtifact:
     """
@@ -1956,7 +1953,7 @@ def forward(self, *args, **kwargs):
         args,
         kwargs,
         dynamic_shapes,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,  # for shape env initialization
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
     )

     fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
@@ -2037,7 +2034,7 @@ def _export_for_training(
     *,
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2062,7 +2059,7 @@ def _export_for_training(
         dynamic_shapes=dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         orig_in_spec=orig_in_spec,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _to_aten_func=_export_to_aten_ir_make_fx,
     )

@@ -2124,7 +2121,7 @@ def _export(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2155,7 +2152,7 @@ def _export(
         preserve_module_call_signature: A list of submodule paths for which the original
             calling conventions are preserved as metadata.

-        allow_complex_guards_as_runtime_asserts:
+        prefer_deferred_runtime_asserts_over_guards:
         With the current dynamic shapes language for dims and derived dims, we can run into constraints
         that are not expressible with the language. For example, flattening a matrix and adding to a vector,
         both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
@@ -2199,7 +2196,7 @@ def _export(
             dynamic_shapes,
             strict=strict,
             preserve_module_call_signature=preserve_module_call_signature,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         )
         dtrace_structured("exported_program", payload_fn=lambda: str(ep))
         return ep
@@ -2224,7 +2221,7 @@ def _export(
         dynamic_shapes=dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         orig_in_spec=original_in_spec,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _to_aten_func=functools.partial(
             _export_to_aten_ir,
             pre_dispatch=pre_dispatch,
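
The docstring's example is easy to reproduce: flattening a fully dynamic matrix and adding a fully dynamic vector needs the guard Eq(s0*s1, s2), which the dims language cannot express, so export either specializes or, with the flag set, defers the check into the graph. A sketch (module and shapes assumed for illustration):

    import torch
    from torch.export import Dim

    class FlattenAdd(torch.nn.Module):
        def forward(self, x, y):
            return x.reshape([-1]) + y  # emits the guard Eq(s0*s1, s2)

    ep = torch.export.export(
        FlattenAdd(),
        (torch.randn(6, 2), torch.randn(12)),
        dynamic_shapes={"x": (Dim("dx0"), Dim("dx1")), "y": (Dim("dy"),)},
        prefer_deferred_runtime_asserts_over_guards=True,  # keep s0*s1 == s2 as a runtime assert
    )
    ep.module()(torch.randn(3, 4), torch.randn(12))  # 3*4 == 12: the deferred assert passes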
