119 changes: 119 additions & 0 deletions test/inductor/test_cudagraph_trees.py
@@ -974,6 +974,125 @@ def f(x, flag):
            num_partitions = get_num_partitions(code)
            self.assertEqual(num_partitions, 1)

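    # graph_partition lets Inductor split the graph around cudagraph-unsafe ops;
    # implicit_fallbacks lets the custom "silly::attention" op run as a fallback kernel.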
@torch._inductor.config.patch("graph_partition", True)
@torch._inductor.config.patch("implicit_fallbacks", True)
def test_graph_partition_with_memory_plan_reuse(self):
BATCH_SIZE = 16
MLP_SIZE = 128
HIDDEN_SIZE = 128
RANDOM_SEED = 0

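        # Custom attention op tagged cudagraph_unsafe: with graph_partition=True,
        # Inductor partitions the graph here so this op runs outside the CUDA graphs.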
        @torch.library.custom_op(
            "silly::attention",
            mutates_args=["out"],
            tags=(torch._C.Tag.cudagraph_unsafe,),
        )
        def attention(
            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
        ) -> None:
            out.copy_(q + k + v)

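        # Fake (meta) implementation: the op only mutates `out` in place,
        # so there is nothing to return for tracing.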
        @attention.register_fake
        def _(q, k, v, out):
            return None

        class ParentModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x

        class Attention(torch.nn.Module):
            def __init__(self, mlp_size: int, hidden_size: int) -> None:
                super().__init__()
                self.pre_attn = torch.nn.Linear(mlp_size, hidden_size, bias=False)
                self.post_attn = torch.nn.Linear(hidden_size, mlp_size, bias=False)
                self.rms_norm_weight = torch.nn.Parameter(torch.ones(hidden_size))

            def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
                x_f32 = x.float()
                return (
                    x_f32
                    * torch.rsqrt(
                        torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6
                    )
                    * self.rms_norm_weight
                ).to(x.dtype)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.pre_attn(x)
                x = self.rms_norm_ref(x)
                attn_output = torch.empty_like(x)
                torch.ops.silly.attention(x, x, x, attn_output)
                x = attn_output
                x = self.rms_norm_ref(x)
                x = self.post_attn(x)
                return x

        class CompiledAttention(torch.nn.Module):
            def __init__(
                self,
                *,
                mlp_size: int,
                hidden_size: int,
            ) -> None:
                super().__init__()
                self.attn = Attention(mlp_size, hidden_size)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.attn(x)

        class CompiledAttentionTwo(CompiledAttention):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.attn(x) + x

        class SimpleModelWithTwoGraphs(ParentModel):
            def __init__(
                self,
                *,
                mlp_size: int,
                hidden_size: int,
            ) -> None:
                super().__init__()
                self.attn_one = CompiledAttention(
                    mlp_size=mlp_size,
                    hidden_size=hidden_size,
                )
                self.attn_two = CompiledAttentionTwo(
                    mlp_size=mlp_size,
                    hidden_size=hidden_size,
                )

                self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                bsz = x.shape[0]
                # CUDAGraph expects same tensor addresses for each run
                self.hidden_states[:bsz].copy_(x)
                x = self.attn_one(self.hidden_states[:bsz])
                self.hidden_states[:bsz].copy_(x)
                x = self.attn_two(self.hidden_states[:bsz])
                return x

        eager_model = (
            SimpleModelWithTwoGraphs(
                mlp_size=MLP_SIZE,
                hidden_size=HIDDEN_SIZE,
            )
            .eval()
            .cuda()
        )

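        # "reduce-overhead" enables CUDA graphs; the partitions around the
        # cudagraph-unsafe attention op are each captured and replayed separately.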
        compiled_model = torch.compile(eager_model, mode="reduce-overhead")

        inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

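        # Early iterations warm up and record the CUDA graphs; later iterations
        # replay them. Compiled outputs must match eager on every run.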
        for _ in range(3):
            eager_out = eager_model(inputs)
            compiled_out = compiled_model(inputs)
            self.assertEqual(eager_out, compiled_out)

@torch._inductor.config.patch("graph_partition", True)
@torch._inductor.config.patch("triton.cudagraph_trees", False)
def test_graph_partition_gc(self):
3 changes: 2 additions & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -1808,7 +1808,8 @@ def memory_plan(self):
        self.lines = MemoryPlanner(self).plan(self.lines)

    def memory_plan_reuse(self):
-        out_names = V.graph.get_output_names()
+        outputs = self.get_graph_outputs()
+        out_names = V.graph._get_output_names(outputs)

        while (
            self.lines
7 changes: 5 additions & 2 deletions torch/_inductor/graph.py
@@ -2479,11 +2479,11 @@ def _compile_to_module_lines(

        return mod

-    def get_output_names(self) -> list[str]:
+    def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
        names = []
        shape_counter = itertools.count(0)
        none_counter = itertools.count(0)
-        for node in self.graph_outputs:
+        for node in graph_outputs:
            if isinstance(node, ir.NoneAsConstantBuffer):
                names.append(f"{self.name}_none{next(none_counter)}")
            elif isinstance(node, ir.ShapeAsConstantBuffer):
@@ -2492,6 +2492,9 @@ def get_output_names(self):
                names.append(node.get_name())
        return names

+    def get_output_names(self) -> list[str]:
+        return self._get_output_names(self.graph_outputs)
+
    def is_unspec_arg(self, name: str) -> bool:
        # dynamo wraps unspec variable as 0d CPU tensor,
        # need to convert to scalar during codegen (triton only)