
Commit 4a63cab

bdhirsh authored and pytorchmergebot committed
[cudagraphs] Fix issue in collecting static_input_idxs (#152287)

Related to #152275.

Pull Request resolved: #152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison
Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
1 parent bce7f0a commit 4a63cab
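For context, a minimal sketch (not part of the commit, assuming a CUDA device is available) of the behavior this fix pins down: a tensor marked with torch._dynamo.mark_static_address keeps a stable data pointer, so cudagraphs can record it as a static input and only copy the remaining inputs on each call. Both APIs used here are exercised by the tests added below.

    import torch

    mod = torch.nn.Linear(2, 2).cuda()

    def fn(x, y):
        return torch.cos(x) + mod(y)

    compiled = torch.compile(fn, mode="reduce-overhead")

    # y promises to keep the same storage across calls
    y = torch.randn(2, 2, device="cuda")
    torch._dynamo.mark_static_address(y)

    for _ in range(5):
        # A fresh x each iteration changes its data pointer; with correctly
        # collected static_input_idxs this must not re-record the cudagraph.
        x = torch.randn(2, 2, device="cuda")
        compiled(x, y)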

File tree: 6 files changed (91 additions, 13 deletions)

test/dynamo/test_subclasses.py (2 additions, 2 deletions)

@@ -2135,9 +2135,9 @@ def inner_compile(
             extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
         ):
             if dynamic:
-                self.assertEqual(static_input_idxs, [0, 1, 2, 3, 4])
+                self.assertEqual(static_input_idxs, [2, 3, 4])
             else:
-                self.assertEqual(static_input_idxs, [0, 1, 2])
+                self.assertEqual(static_input_idxs, [1, 2])
             return gm

         compiler = functools.partial(compile_fx, inner_compile=inner_compile)

test/inductor/test_cudagraph_trees.py (34 additions, 0 deletions)

@@ -2426,6 +2426,40 @@ def fn(x, y):
         self.run_static_input_param_test(fn, 4)
         self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_no_rerecord_with_mark_static_address(self):
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        mod = Mod().cuda()
+
+        def fn_eager(x, marked_static_y):
+            return torch.cos(x) + mod(marked_static_y)
+
+        with torch.device("cuda"):
+            fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")
+
+        # y is marked static
+        y = torch.randn(2, 2)
+        torch._dynamo.mark_static_address(y)
+
+        # Changing the data pointer of x should not lead to re-records
+        for _ in range(5):
+            x = torch.randn(2, 2, requires_grad=True)
+            res = fn_compiled(x, y)
+            res.sum().backward()
+            x.grad = None
+            mod.linear.weight.grad = None
+            mod.linear.bias.grad = None
+        # One forward and one backward
+        self.assertEqual(self.get_manager().new_graph_id().id, 2)
+
     def test_tensor_constant_mutation(self):
         class Foo(torch.nn.Module):
             def __init__(self) -> None:

test/inductor/test_inductor_freezing.py (30 additions, 0 deletions)

@@ -1,5 +1,6 @@
 # Owner(s): ["module: inductor"]
 import contextlib
+import copy
 import functools
 import importlib
 import itertools
@@ -375,6 +376,35 @@ def foo(mod, x):
         ):
             mod(x)

+    def test_static_indices_cudagraph(self):
+        if self.device != "cuda":
+            return
+
+        mod1 = torch.nn.Sequential(
+            torch.nn.Linear(2, 2).to(self.device), torch.nn.Linear(2, 2).to(self.device)
+        )
+        mod2 = copy.deepcopy(mod1)
+
+        def fn(x, y, mod):
+            x.add_(1)
+            getattr(mod, "0").bias.add_(2)
+            getattr(mod, "1").weight.add_(3)
+            return mod(x) + y
+
+        x1 = torch.randn(2, 2, device=self.device)
+        y1 = torch.randn(2, 2, device=self.device)
+        x2 = x1.clone()
+        y2 = y1.clone()
+
+        opt_fn = torch.compile(fn, mode="reduce-overhead")
+
+        with torch.no_grad():
+            ref = fn(x1, y1, mod1)
+            res = opt_fn(x2, y2, mod2)
+        self.assertEqual(ref, res)
+        self.assertEqual(x1, x2)
+        self.assertEqual(y1, y2)
+
     def test_rng_op(self):
         @torch.compile()
         def foo():

torch/_functorch/aot_autograd.py (13 additions, 5 deletions)

@@ -1028,18 +1028,20 @@ def _try_get_metadata_from_dynamo(
     seen_sources = set()

     aot_autograd_arg_pos_to_source = []
+    static_input_indices = []
     # Collect the new inputs lifted by aotdispatch
-    for name in param_keys:
+    for i, name in enumerate(param_keys):
         assert name in param_name_to_source, f"{name} not found."
         source = param_name_to_source[name]
         assert source not in seen_sources, source
         seen_sources.add(source)
         aot_autograd_arg_pos_to_source.append(source)

+        static_input_indices.append(i)
+
     # Collect the dynamo graph inputs
     # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
     # matched tensors back into the Fx graph, this might not be necessary.
-    static_input_indices = []
     for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
         assert hasattr(node, "_dynamo_source")
         source = node._dynamo_source
@@ -1048,16 +1050,22 @@ def _try_get_metadata_from_dynamo(
         aot_autograd_arg_pos_to_source.append(source)
         source_name = source.name() if source else str(source)

+        # input[i] in dynamo is now:
+        # input[i + len(extra_params)] in AOT,
+        # where extra_params are the params/buffers that dynamo baked into the
+        # OutputGraph
+        actual_pos = pos + len(param_keys)
+
         if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
             "_dynamo_static_input_type", None
         ):
             static_inputs_log.debug(
-                "Adding static input pos %s for source %s", pos, source_name
+                "Adding static input pos %s for source %s", actual_pos, source_name
             )
-            static_input_indices.append(pos)
+            static_input_indices.append(actual_pos)
         else:
             static_inputs_log.debug(
-                "Non-static input pos %s for source %s", pos, source_name
+                "Non-static input pos %s for source %s", actual_pos, source_name
             )

     assert full_args_num == len(aot_autograd_arg_pos_to_source)
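The remapping above is easiest to see with concrete numbers. A hypothetical illustration (values invented for this sketch, not taken from the commit):

    # AOT's calling convention is [*lifted_params, *dynamo_graph_inputs].
    param_keys = ["l.weight", "l.bias"]  # lifted params: AOT positions 0, 1
    dynamo_static_positions = [1]        # placeholder marked static in dynamo

    # Every lifted param is a static input:
    static_input_indices = list(range(len(param_keys)))  # [0, 1]

    # A dynamo placeholder at pos p sits at p + len(param_keys) in AOT:
    static_input_indices += [p + len(param_keys) for p in dynamo_static_positions]

    print(static_input_indices)  # [0, 1, 3] -- dynamo pos 1 became AOT pos 3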

torch/_inductor/compile_fx.py (3 additions, 4 deletions)

@@ -212,7 +212,7 @@ def get_static_input_idxs(num_fixed: int) -> list[int]:
     if not context or not context.fw_metadata:
         return fixed

-    return fixed + context.fw_metadata.static_input_indices
+    return context.fw_metadata.static_input_indices


 def record_original_output_strides(gm: GraphModule) -> None:
@@ -1745,7 +1745,6 @@ def fw_compiler_freezing(
     )

     aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices]
-    num_fixed = len(preserved_arg_indices) - num_example_inputs

     fake_mode = detect_fake_mode(aot_example_inputs)

@@ -1756,7 +1755,7 @@ def fw_compiler_freezing(
         idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
     ]

-    static_input_idxs = list(range(num_fixed))
+    static_input_idxs = []
     # constant params will be real tensors, not fake
     tracing_context = torch._guards.TracingContext.try_get()
     unwrapped_args_offsets = [0]
@@ -1788,7 +1787,7 @@ def fw_compiler_freezing(
                 tracing_context.params_flat[i] = None

     if tracing_context.fw_metadata:
-        static_input_idxs += tracing_context.fw_metadata.static_input_indices
+        static_input_idxs = tracing_context.fw_metadata.static_input_indices

     with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
         optimized_function = inner_compile(
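Since fw_metadata.static_input_indices now already contains the lifted-parameter positions (see the aot_autograd.py change above), keeping the old concatenation would double-count them. A small before/after with hypothetical values:

    fixed = [0, 1]                  # lifted params, always static
    fw_static = [0, 1, 3]           # now includes those positions itself

    old_result = fixed + fw_static  # [0, 1, 0, 1, 3] -- duplicated entries
    new_result = fw_static          # [0, 1, 3]       -- single source of truth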

torch/_inductor/freezing.py (9 additions, 2 deletions)

@@ -52,14 +52,21 @@ def replace_params_with_constants(
         in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
     ]

+    static_indices_new = []
+    static_indices_offset = 0
     for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
         if i in mutated_inps or i in aliased_input_args:
             preserved_arg_indices.append(i)
-            continue
-        replace_node_with_constant(gm, node, real_input)
+            if i in fw_metadata.static_input_indices:
+                new_static_index = i - static_indices_offset
+                static_indices_new.append(new_static_index)
+        else:
+            replace_node_with_constant(gm, node, real_input)
+            static_indices_offset += 1
     # add on non param inputs
     preserved_arg_indices.extend(range(len(flat_params), len(params)))
     # is this necessary ?
+    fw_metadata.static_input_indices = static_indices_new
     gm.recompile()
     return preserved_arg_indices
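The renumbering in replace_params_with_constants can be checked with a toy example. A hypothetical sketch of the same bookkeeping: when freezing folds an input into the graph as a constant, every surviving input after it shifts left by one, so a surviving static index shrinks by the number of folded inputs before it.

    static_input_indices = [0, 2, 3]
    preserved = {0, 3}   # inputs 1 and 2 get folded into graph constants

    static_indices_new = []
    offset = 0           # how many inputs have been folded away so far
    for i in range(4):
        if i in preserved:
            if i in static_input_indices:
                static_indices_new.append(i - offset)
        else:
            offset += 1  # folded input: later indices shift left

    print(static_indices_new)  # [0, 1] -- old 0 stays, old 3 becomes 1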