
Conversation

@BoyuanFeng
Contributor

@BoyuanFeng BoyuanFeng commented Oct 27, 2025

Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs:

```python
def get_scheduler_node_symbol_uses(
    node: BaseSchedulerNode,
) -> OrderedSet[sympy.Symbol]:
    """
    Gets symbols used in node.
    """
    if isinstance(node, FusedSchedulerNode):
        return OrderedSet().union(
            *(get_scheduler_node_symbol_uses(snode) for snode in node.snodes)
        )
    assert node.node is not None
    free_symbol_uses = node.node.get_free_symbol_uses()
    free_symbol_uses.update(
        *(get_layout_symints(ir_node) for ir_node in node.node.get_outputs())
    )
    return free_symbol_uses
```

I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried an aten fallback for torchtitan, which results in 10k+ aten nodes. By the time the scheduler reaches the 600th node, a single call to `get_free_symbol_uses()` takes seconds.

Why? Because `get_free_symbol_uses()` may recursively call other `get_free_symbol_uses()` implementations, so the same work is repeated many times.

pytorch/torch/_inductor/ir.py, lines 4541 to 4543 at ee7434b:

```python
result = self.layout.get_free_symbol_uses(
    unbacked_only
) | self.data.get_free_symbol_uses(unbacked_only)
```
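
To see why this can blow up, here is a toy sketch (hypothetical classes, not Inductor's actual IR): if every node's `get_free_symbol_uses()` recursively unions its input's result, then asking each of n chained nodes for its symbols re-walks all of its ancestors, roughly O(n^2) work in total.

```python
# Toy sketch (hypothetical ToyNode class, not Inductor's real IR) showing why
# uncached recursive get_free_symbol_uses() grows quadratically on a chain.
import sympy


class ToyNode:
    def __init__(self, sym: sympy.Symbol, inp: "ToyNode | None" = None):
        self.sym = sym
        self.inp = inp

    def get_free_symbol_uses(self) -> set[sympy.Symbol]:
        # Each call re-walks the entire chain of inputs.
        result = {self.sym}
        if self.inp is not None:
            result |= self.inp.get_free_symbol_uses()
        return result


# Build a chain of n nodes; asking every node for its symbols visits
# 1 + 2 + ... + n nodes, i.e. O(n^2) total work without any caching.
n = 600
nodes: list[ToyNode] = []
prev = None
for i in range(n):
    prev = ToyNode(sympy.Symbol(f"s{i}"), prev)
    nodes.append(prev)

total_visits = sum(len(node.get_free_symbol_uses()) for node in nodes)
print(total_visits)  # n*(n+1)/2 = 180300 for n=600
```

This toy is much simpler than real Inductor IR (which also recurses through layouts, views, and reads), but it is in the same spirit as the slowdown described above: the per-node cost keeps growing as the graph gets larger.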

This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.
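
A minimal sketch of the kind of caching described here, assuming a hypothetical per-instance decorator keyed on the `unbacked_only` flag (the actual PR caches inside Inductor's IR classes; the names below are illustrative only):

```python
# Hedged sketch: per-instance memoization for a get_free_symbol_uses-style
# method. Names and wiring are illustrative, not the exact helpers in the PR.
import functools
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def cache_free_symbol_uses(fn: Callable[..., T]) -> Callable[..., T]:
    """Memoize fn(self, unbacked_only) per instance and per flag value."""
    cache_attr = f"_cache_{fn.__name__}"

    @functools.wraps(fn)
    def wrapper(self: Any, unbacked_only: bool = False) -> T:
        # Store results on the instance so the cache dies with the IR node.
        cache = getattr(self, cache_attr, None)
        if cache is None:
            cache = {}
            setattr(self, cache_attr, cache)
        if unbacked_only not in cache:
            cache[unbacked_only] = fn(self, unbacked_only)
        return cache[unbacked_only]

    return wrapper
```

With something like this applied to each `get_free_symbol_uses` implementation, the recursive union in ir.py is computed once per (node, unbacked_only) pair instead of once per caller. Caching is only sound if a node's free symbols cannot change after the first call, which is exactly what the review discussion below is about.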

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@BoyuanFeng added the ciflow/trunk, topic: not user facing, and module: inductor labels on Oct 27, 2025
@pytorch-bot

pytorch-bot bot commented Oct 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166338

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0d670c8 with merge base 365ed62:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison eellison requested a review from laithsakka October 27, 2025 21:50
@laithsakka
Contributor

laithsakka commented Oct 28, 2025

Seems reasonable. Are inductor nodes immutable? @eellison
If not, I wonder if we can do this optimization in a safer way: only cache within a context where I know the nodes are not changing. That would depend on the torchtitan case.

```python
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
    key = (id(self), args, tuple(sorted(kwargs.items())))
    if key not in cache:
        cache[key] = fn(self, *args, **kwargs)
```
Contributor

Looking at how cache_on_self was implemented, I think we should do something similar here to further improve the performance.

Contributor

@eellison eellison left a comment

Right, a flexible layout might change strides, although it's unlikely that would induce a new symbol use. Would it be safer to only cache if the layout is fixed?


```python
@offset.setter
def offset(self, value: Expr) -> None:
    self.assert_free_symbol_uses_unchanged("offset", value)
```
Contributor Author

This errors if free symbols are added or deleted after initialization.
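
A minimal sketch of what such a guard could look like, written here as a standalone helper for illustration (the real `assert_free_symbol_uses_unchanged` in the PR is a method and may differ in shape):

```python
# Hedged sketch of a guard in the spirit of assert_free_symbol_uses_unchanged:
# refuse a mutation if it would add or delete free symbols after the cached
# get_free_symbol_uses() result was computed. Signature is illustrative only.
from sympy import Expr


def assert_free_symbol_uses_unchanged(
    field: str, old_value: Expr, new_value: Expr
) -> None:
    # sympy expressions expose .free_symbols; plain ints have none.
    old_syms = getattr(old_value, "free_symbols", set())
    new_syms = getattr(new_value, "free_symbols", set())
    if old_syms != new_syms:
        raise AssertionError(
            f"free symbols of {field} changed after initialization "
            f"({old_syms} -> {new_syms}); this would invalidate the cached "
            "get_free_symbol_uses() result"
        )
```

A setter like the `offset` one quoted above would then compare the current value against the incoming one before assigning it, so a stale cache can never go unnoticed.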

Contributor

@eellison eellison left a comment

Sorry, one last question: at the point we call it, all these nodes should have a fixed layout. Should we just cache only in the fixed-layout case? I think that would be a bit simpler.

@desertfire
Contributor

> Sorry, one last question: at the point we call it, all these nodes should have a fixed layout. Should we just cache only in the fixed-layout case? I think that would be a bit simpler.

That is simpler, but it probably has performance implications. @BoyuanFeng, I wonder how much of a performance difference it would make.
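
For reference, the more conservative variant being discussed would look roughly like this sketch. `FixedLayout` refers to the class in torch/_inductor/ir.py; everything else (function name, cache attribute) is my assumption, not code from the PR.

```python
# Hedged sketch of the "only cache when the layout is fixed" alternative.
# FixedLayout is a real torch/_inductor/ir.py class; the wiring around it
# is illustrative only.
from torch._inductor import ir


def get_free_symbol_uses_maybe_cached(node, unbacked_only: bool = False):
    # Flexible layouts may still change strides, so skip the cache there.
    if not isinstance(node.layout, ir.FixedLayout):
        return node.get_free_symbol_uses(unbacked_only)

    cache = getattr(node, "_free_symbol_uses_cache", None)
    if cache is None:
        cache = {}
        node._free_symbol_uses_cache = cache
    if unbacked_only not in cache:
        cache[unbacked_only] = node.get_free_symbol_uses(unbacked_only)
    return cache[unbacked_only]
```

The trade-off @desertfire raises is that any node whose layout is still flexible at this point gets no caching at all, so every query on it re-runs the full recursion; how much that costs depends on how many such nodes survive to the partitioning step.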

@atalman
Contributor

atalman commented Oct 31, 2025

@pytorchmergebot revert -c nosignal -m "Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values GH job link HUD commit link"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@BoyuanFeng your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 31, 2025
This reverts commit a6b1ef1.

Reverted #166338 on behalf of https://github.com/atalman due to Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values [GH job link](https://github.com/pytorch/pytorch/actions/runs/18961173726/job/54149112920) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/a6b1ef17173f56ba93ac97ff4384fa4060b5e41e) ([comment](#166338 (comment)))
@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels on Oct 31, 2025
@BoyuanFeng
Contributor Author

@atalman the failure is not related to this PR. I also cannot repro it locally. Let me rebase and retry CI.


@BoyuanFeng
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

BoyuanFeng added a commit that referenced this pull request Oct 31, 2025
BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
BoyuanFeng added a commit that referenced this pull request Oct 31, 2025
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
@BoyuanFeng
Contributor Author

@pytorchbot cherry-pick --onto release/2.9 --fixes "Inductor partition compilation infinite hang issue introduced in 2.9.0 breaking torchtitan" -c fixnewfeature

@pytorchbot
Collaborator

Cherry picking #166338

Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x dfebdcab86acbaa0eaa996b47595e5f27a66492e returned non-zero exit code 1

Auto-merging test/inductor/test_torchinductor.py
Auto-merging torch/_inductor/ir.py
CONFLICT (content): Merge conflict in torch/_inductor/ir.py
Auto-merging torch/_inductor/utils.py
CONFLICT (content): Merge conflict in torch/_inductor/utils.py
error: could not apply dfebdcab86a... [GraphPartition] cache get_free_symbol_uses (#166338)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

Lucaskabela pushed a commit that referenced this pull request Nov 4, 2025
Lucaskabela pushed a commit that referenced this pull request Nov 4, 2025
Lucaskabela pushed a commit that referenced this pull request Nov 4, 2025
atalman pushed a commit that referenced this pull request Nov 6, 2025
@github-actions github-actions bot deleted the bf/partition-cache-free-symbols branch December 5, 2025 02:18