🌐 AI搜索 & 代理 主页
Skip to content

Commit db311a5

Browse files
committed
Update on "[Inductor] Node Level provenance tracking"
- use GraphTransformObserver + replace_node hooks to track node sources when they are replaced - add pre_grad_graph tracking to tlparse - add the node provenance information to post_grad_graph tlparse. This is for the frontend to create a mapping between pre_grad and post_grad graph. See an example frontend (this is just a prototype) here: https://drive.google.com/file/d/1cMHH_0y4FJUSS9tATwGQvA72O0Lth8eh/view?usp=sharing - change "action" of NodeSource from a single action to a list of actions. https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?tab=t.0 The front-end code that takes in the tlparse result is in https://github.com/yushangdi/compiler_explorer. Differential Revision: [D65006709](https://our.internmc.facebook.com/intern/diff/D65006709/) cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
2 parents eb3ad2e + d99ada1 commit db311a5

File tree

3 files changed

+141
-39
lines changed

3 files changed

+141
-39
lines changed

test/fx/test_fx_xform_observer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: fx"]
22

3+
import copy
34
import os
45
import tempfile
56

@@ -92,6 +93,11 @@ def check_node_source(node_source, node_name, target, id, pass_name, action):
9293
self.assertTrue("relu" in ob.created_nodes)
9394
self.assertTrue("neg" in ob.erased_nodes)
9495

96+
self.assertEqual(len(traced._replace_hooks), 0)
97+
self.assertEqual(len(traced._create_node_hooks), 0)
98+
self.assertEqual(len(traced._erase_node_hooks), 0)
99+
self.assertEqual(len(traced._deepcopy_hooks), 0)
100+
95101
for node in traced.graph.nodes:
96102
if node.name == "relu":
97103
from_node = node.meta["from_node"]
@@ -131,3 +137,50 @@ def check_node_source(node_source, node_name, target, id, pass_name, action):
131137
"replace_neg_with_relu",
132138
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
133139
)
140+
141+
class SimpleLinearModel(torch.nn.Module):
142+
def forward(self, x):
143+
return torch.neg(x)
144+
145+
model = SimpleLinearModel()
146+
gm = torch.export.export(model, (torch.rand(10),)).module()
147+
148+
with GraphTransformObserver(gm, "test"):
149+
add_node = gm.graph.call_function(torch.ops.aten.add.default, (1, 1))
150+
neg_node = next(
151+
iter([node for node in gm.graph.nodes if node.name == "neg"])
152+
)
153+
neg_node.replace_all_uses_with(replace_with=add_node)
154+
155+
from_node = add_node.meta["from_node"]
156+
self.assertTrue(len(from_node) == 1)
157+
check_node_source(
158+
from_node[0],
159+
"neg",
160+
str(torch.ops.aten.neg.default),
161+
id(gm.graph),
162+
"test",
163+
[NodeSourceAction.REPLACE, NodeSourceAction.CREATE],
164+
)
165+
166+
def test_graph_transform_observer_deepcopy(self):
167+
class SimpleLinearModel(torch.nn.Module):
168+
def forward(self, x):
169+
return torch.neg(x)
170+
171+
model = SimpleLinearModel()
172+
gm = torch.export.export(model, (torch.rand(10),)).module()
173+
174+
with GraphTransformObserver(gm, "test"):
175+
gm2 = copy.deepcopy(gm)
176+
177+
nodes = [node.name for node in gm.graph.nodes]
178+
nodes2 = [node.name for node in gm2.graph.nodes]
179+
self.assertEqual(nodes, nodes2)
180+
181+
# deepcopied graph modules should not have hooks after exiting
182+
# the context
183+
self.assertEqual(len(gm2._replace_hooks), 0)
184+
self.assertEqual(len(gm2._create_node_hooks), 0)
185+
self.assertEqual(len(gm2._erase_node_hooks), 0)
186+
self.assertEqual(len(gm2._deepcopy_hooks), 0)

torch/fx/graph_module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@ def __init__(
531531
self._replace_hooks: List[Callable] = []
532532
self._create_node_hooks: List[Callable] = []
533533
self._erase_node_hooks: List[Callable] = []
534+
# Used to remove hooks from deepcopied graph modules within a context manager.
535+
self._deepcopy_hooks: List[Callable] = []
534536

535537
# TorchScript breaks trying to compile the graph setter because of the
536538
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@@ -888,6 +890,7 @@ def __deepcopy__(self, memo):
888890
"_replace_hooks",
889891
"_create_node_hooks",
890892
"_erase_node_hooks",
893+
"_deepcopy_hooks",
891894
]
892895
for attr in extra_preserved_attrs:
893896
if attr in self.__dict__:
@@ -896,6 +899,8 @@ def __deepcopy__(self, memo):
896899
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
897900
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
898901
setattr(res, attr_name, attr)
902+
for hook in self._deepcopy_hooks:
903+
hook(res)
899904
return res
900905

901906
def __copy__(self):
@@ -1002,6 +1007,22 @@ def _unregister_erase_node_hook(self, f):
10021007
assert callable(f), "erase_node hook must be a callable."
10031008
self._erase_node_hooks.remove(f)
10041009

1010+
def _register_deepcopy_hook(self, f):
1011+
"""
1012+
Takes a callable which will be called when we deepcopy this graph module. The
1013+
callable takes the resulting deepcopied graph module.
1014+
"""
1015+
assert callable(f), "deepcopy hook must be a callable."
1016+
self._deepcopy_hooks.append(f)
1017+
1018+
def _unregister_deepcopy_hook(self, f):
1019+
"""
1020+
Takes a callable which was previously registered to be called after deepcopy.
1021+
This function will unregister that callable so it is no longer invoked on deepcopy.
1022+
"""
1023+
assert callable(f), "deepcopy hook must be a callable."
1024+
self._deepcopy_hooks.remove(f)
1025+
10051026

10061027
# workarounds for issues in __torch_function__
10071028

torch/fx/passes/graph_transform_observer.py

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def __init__(
3939
self.erased_nodes = set()
4040
self.created_nodes = set()
4141
self.name_to_node = {}
42+
# record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context
43+
self.copied_gms = []
44+
45+
self._node_creation_hook = self.get_node_creation_hook()
46+
self._node_erase_hook = self.get_node_erase_hook()
47+
self._node_replace_hook = self.get_node_replace_hook()
48+
self._deepcopy_hook = self.get_deepcopy_hook()
4249

4350
# If log_url is None, we don't log anything
4451
if log_url is None:
@@ -88,23 +95,27 @@ def _check_disable_pass(self):
8895
)
8996

9097
def __enter__(self):
91-
self.gm._register_create_node_hook(self.on_node_creation)
92-
self.gm._register_erase_node_hook(self.on_node_erase)
93-
self.gm._register_replace_node_hook(self.on_node_replace)
98+
self.gm._register_create_node_hook(self._node_creation_hook)
99+
self.gm._register_erase_node_hook(self._node_erase_hook)
100+
self.gm._register_replace_node_hook(self._node_replace_hook)
101+
self.gm._register_deepcopy_hook(self._deepcopy_hook)
94102

95103
self.erased_nodes.clear()
96104
self.created_nodes.clear()
97105
self.name_to_node.clear()
106+
self.copied_gms.clear()
98107

99108
for node in self.gm.graph.nodes:
100109
self.name_to_node[node.name] = node
101110

102111
return self
103112

104113
def __exit__(self, type, value, tb):
105-
self.gm._unregister_create_node_hook(self.on_node_creation)
106-
self.gm._unregister_erase_node_hook(self.on_node_erase)
107-
self.gm._unregister_replace_node_hook(self.on_node_replace)
114+
for gm in self.copied_gms + [self.gm]:
115+
gm._unregister_create_node_hook(self._node_creation_hook)
116+
gm._unregister_erase_node_hook(self._node_erase_hook)
117+
gm._unregister_replace_node_hook(self._node_replace_hook)
118+
gm._unregister_deepcopy_hook(self._deepcopy_hook)
108119

109120
if self.log_url is None or self.gm is None:
110121
return
@@ -140,39 +151,56 @@ def __exit__(self, type, value, tb):
140151
)
141152
)
142153

143-
def on_node_creation(self, node):
144-
self.created_nodes.add(node.name)
145-
self.name_to_node[node.name] = node
146-
source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
147-
if "from_node" not in node.meta:
148-
node.meta["from_node"] = [source]
149-
else:
150-
node.meta["from_node"].append(source)
151-
152-
def on_node_erase(self, node):
153-
self.erased_nodes.add(node.name)
154-
self.name_to_node.pop(node.name, None)
155-
156-
def on_node_replace(self, old: Node, new: str, user: Node):
157-
# Update node meta when replacing old node with new node
158-
new_node = self.name_to_node.get(new, None)
159-
160-
action = [NodeSourceAction.REPLACE]
161-
if new_node.name in self.created_nodes:
162-
action.append(NodeSourceAction.CREATE)
163-
164-
def created_this_pass(source):
165-
return source.passname == self.passname and source.action == [
166-
NodeSourceAction.CREATE
154+
def get_node_creation_hook(self):
155+
# We have to return a function instead of using a class method directly
156+
# to avoid max recursion issue when deepcopy a graph module within the context manager.
157+
def on_node_creation(node):
158+
self.created_nodes.add(node.name)
159+
self.name_to_node[node.name] = node
160+
source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
161+
if "from_node" not in node.meta:
162+
node.meta["from_node"] = [source]
163+
else:
164+
node.meta["from_node"].append(source)
165+
166+
return on_node_creation
167+
168+
def get_node_erase_hook(self):
169+
def on_node_erase(node):
170+
self.erased_nodes.add(node.name)
171+
self.name_to_node.pop(node.name, None)
172+
173+
return on_node_erase
174+
175+
def get_node_replace_hook(self):
176+
def on_node_replace(old: Node, new: str, user: Node):
177+
# Update node meta when replacing old node with new node
178+
new_node = self.name_to_node.get(new, None)
179+
180+
action = [NodeSourceAction.REPLACE]
181+
if new_node.name in self.created_nodes:
182+
action.append(NodeSourceAction.CREATE)
183+
184+
def created_this_pass(source):
185+
return source.pass_name == self.passname and source.action == [
186+
NodeSourceAction.CREATE
187+
]
188+
189+
# remove redundant source added on node creation
190+
new_from_node = new_node.meta.get("from_node", [])
191+
new_from_node = [
192+
source for source in new_from_node if not created_this_pass(source)
167193
]
168194

169-
# remove redundant source added on node creation
170-
new_from_node = new_node.meta.get("from_node", [])
171-
new_from_node = [
172-
source for source in new_from_node if not created_this_pass(source)
173-
]
195+
# add new source
196+
new_node_source = NodeSource(old, self.passname, action)
197+
new_from_node.append(new_node_source)
198+
new_node.meta["from_node"] = new_from_node
199+
200+
return on_node_replace
201+
202+
def get_deepcopy_hook(self):
203+
def on_deepcopy(gm):
204+
self.copied_gms.append(gm)
174205

175-
# add new source
176-
new_node_source = NodeSource(old, self.passname, action)
177-
new_from_node.append(new_node_source)
178-
new_node.meta["from_node"] = new_from_node
206+
return on_deepcopy

0 commit comments

Comments
 (0)