[GraphPartition] cache get_free_symbol_uses #166338
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166338. Note: links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0d670c8 with merge base 365ed62. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Seems reasonable. Are Inductor nodes immutable? @eellison
torch/_inductor/utils.py
Outdated
def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
    key = (id(self), args, tuple(sorted(kwargs.items())))
    if key not in cache:
        cache[key] = fn(self, *args, **kwargs)
Looking at how cache_on_self was implemented, I think we should do something similar here to further improve performance.
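For reference, a minimal sketch of that idea (this is not the actual torch._inductor.utils.cache_on_self implementation): store the memoized result on the instance itself instead of in a module-level dict keyed by id(self), so the cache is freed together with the object.

```python
# Hypothetical sketch of a cache_on_self-style decorator; names and details
# are illustrative assumptions, not the actual torch._inductor code.
import functools
from typing import Any, Callable, TypeVar

RV = TypeVar("RV")


def cache_on_self_sketch(fn: Callable[..., RV]) -> Callable[..., RV]:
    attr = f"__{fn.__name__}_cache"  # per-method cache attribute on the instance

    @functools.wraps(fn)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> RV:
        # Keep the cache on the object so it is garbage-collected with it,
        # rather than accumulating entries in a module-level dict.
        cache = getattr(self, attr, None)
        if cache is None:
            cache = {}
            setattr(self, attr, cache)
        key = (args, tuple(sorted(kwargs.items())))  # args must be hashable
        if key not in cache:
            cache[key] = fn(self, *args, **kwargs)
        return cache[key]

    return wrapper
```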
eellison left a comment:
Right, a flexible layout might change stride, although it's unlikely it would induce a new symbol use. Would it be safer to only cache if the layout is fixed?
@offset.setter
def offset(self, value: Expr) -> None:
    self.assert_free_symbol_uses_unchanged("offset", value)
Error if free symbols are added or deleted after initialization.
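As a rough illustration of that guard (a sketch only; the function name, signature, and reliance on sympy's free_symbols are assumptions, not the PR's exact code), it can compare the free symbols of the old and new expressions and raise if the set changes:

```python
# Hypothetical sketch of a "free symbols must not change after init" check.
import sympy


def check_free_symbols_unchanged(name: str, old: sympy.Expr, new: sympy.Expr) -> None:
    old_syms, new_syms = old.free_symbols, new.free_symbols
    if old_syms != new_syms:
        # Adding or deleting a free symbol would invalidate any cached
        # get_free_symbol_uses() result, so fail loudly instead.
        raise AssertionError(
            f"free symbols of {name!r} changed: {old_syms} -> {new_syms}"
        )


# Replacing s0 + 1 with s0 + 2 keeps the symbol set {s0}, so this passes;
# replacing it with an expression over s1 would raise.
s0, s1 = sympy.symbols("s0 s1")
check_free_symbols_unchanged("offset", s0 + 1, s0 + 2)
```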
eellison left a comment:
Sorry, one last question: at the point we call it, all these nodes should have fixed layout. Should we just cache only in the fixed-layout case? I think that would be a bit simpler.
That is simpler, but it will probably have performance implications. @BoyuanFeng, I wonder how much of a performance difference it would make.
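For reference, a minimal sketch of what caching only in the fixed-layout case could look like; the class and attribute names below are illustrative stand-ins, not the actual Inductor IR types:

```python
# Illustrative sketch: only memoize once the layout can no longer change,
# since a flexible layout may still be mutated (e.g. strides decided later).
from typing import Optional
import sympy


class FixedLayoutStub: ...
class FlexibleLayoutStub: ...


class NodeSketch:
    def __init__(self, layout: object) -> None:
        self.layout = layout
        self._free_symbol_cache: Optional[set[sympy.Symbol]] = None

    def _compute_free_symbol_uses(self) -> set[sympy.Symbol]:
        # Placeholder for the real recursive traversal over sizes/strides/offset.
        return set()

    def get_free_symbol_uses(self) -> set[sympy.Symbol]:
        if not isinstance(self.layout, FixedLayoutStub):
            # Layout may still change, so recompute rather than risk a stale cache.
            return self._compute_free_symbol_uses()
        if self._free_symbol_cache is None:
            self._free_symbol_cache = self._compute_free_symbol_uses()
        return self._free_symbol_cache
```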
@pytorchmergebot revert -c nosignal -m "Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values GH job link HUD commit link"
@pytorchbot successfully started a revert job. Check the current status here.
@BoyuanFeng your PR has been successfully reverted.
This reverts commit a6b1ef1. Reverted #166338 on behalf of https://github.com/atalman due to Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values [GH job link](https://github.com/pytorch/pytorch/actions/runs/18961173726/job/54149112920) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/a6b1ef17173f56ba93ac97ff4384fa4060b5e41e) ([comment](#166338 (comment)))
@atalman the failure is not related to this PR. I also cannot repro locally. Let me rebase and try CI again.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs. https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/scheduler.py#L4869-L4885

I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried an aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, it takes seconds to run `get_free_symbol_uses()` for a single node. Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which can happen many times. https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/ir.py#L4541-L4543

This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.

Pull Request resolved: #166338
Approved by: https://github.com/eellison
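To make the cost concrete, below is a toy, self-contained model (not Inductor code) of why the uncached recursive walk is expensive when nodes share inputs, and how memoization flattens it to roughly one computation per node:

```python
# Toy model (not Inductor code): a node's free-symbol set is the union of its
# own symbols and those of its inputs. Without a cache, shared ancestors are
# revisited over and over; with a cache, each node is computed once.
import sympy


class ToyNode:
    def __init__(self, own_symbols, inputs=()):
        self.own_symbols = set(own_symbols)
        self.inputs = list(inputs)
        self.visits = 0  # how many times get_free_symbol_uses() reached this node

    def get_free_symbol_uses(self, cache=None):
        self.visits += 1
        if cache is not None and id(self) in cache:
            return cache[id(self)]
        result = set(self.own_symbols)
        for inp in self.inputs:
            result |= inp.get_free_symbol_uses(cache)
        if cache is not None:
            cache[id(self)] = result
        return result


# Chain where every node depends on the two previous ones, so the uncached
# walk revisits ancestors a Fibonacci-like number of times.
syms = sympy.symbols("s0:20")
nodes = [ToyNode([syms[0]]), ToyNode([syms[1]])]
for i in range(2, 20):
    nodes.append(ToyNode([syms[i]], inputs=[nodes[-1], nodes[-2]]))

nodes[-1].get_free_symbol_uses()            # uncached traversal
uncached = sum(n.visits for n in nodes)
for n in nodes:
    n.visits = 0
nodes[-1].get_free_symbol_uses(cache={})    # memoized traversal
cached = sum(n.visits for n in nodes)
print(uncached, cached)  # thousands of visits vs. a few dozen
```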
@pytorchbot cherry-pick --onto release/2.9 --fixes "Inductor partition compilation infinite hang issue introduced in 2.9.0 breaking torchtitan" -c fixnewfeature |
Cherry picking #166338. Command details for Dev Infra team. Raised by workflow job.
(cherry picked from commit dfebdca)
Co-authored-by: Boyuan Feng <boyuan@meta.com>

Graph partition relies on
get_free_symbol_uses()to collect symbol inputs.pytorch/torch/_inductor/scheduler.py
Lines 4869 to 4885 in ee7434b
I empirically observed that
get_free_symbol_uses()becomes slower for larger graphs. Specifically, I tried to aten fallback for torchtitan which results in 10k+ aten nodes. When processing the 600-th node, it takes seconds toget_free_symbol_uses()for 1 node.Why? Because
get_free_symbol_uses()may recursively call anotherget_free_symbol_uses(), which could recursively run many times.pytorch/torch/_inductor/ir.py
Lines 4541 to 4543 in ee7434b
This PR fixes the issue by caching the results of
get_free_symbol_uses(). I validated on torchtitan that the issue is fixed.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben