@@ -1028,18 +1028,20 @@ def _try_get_metadata_from_dynamo(
10281028 seen_sources = set ()
10291029
10301030 aot_autograd_arg_pos_to_source = []
1031+ static_input_indices = []
10311032 # Collect the new inputs lifted by aotdispatch
1032- for name in param_keys :
1033+ for i , name in enumerate ( param_keys ) :
10331034 assert name in param_name_to_source , f"{ name } not found."
10341035 source = param_name_to_source [name ]
10351036 assert source not in seen_sources , source
10361037 seen_sources .add (source )
10371038 aot_autograd_arg_pos_to_source .append (source )
10381039
1040+ static_input_indices .append (i )
1041+
10391042 # Collect the dynamo graph inputs
10401043 # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID
10411044 # matched tensors back into the Fx graph, this might not be necessary.
1042- static_input_indices = []
10431045 for pos , node in enumerate (mod .graph .find_nodes (op = "placeholder" )):
10441046 assert hasattr (node , "_dynamo_source" )
10451047 source = node ._dynamo_source
@@ -1048,16 +1050,22 @@ def _try_get_metadata_from_dynamo(
10481050 aot_autograd_arg_pos_to_source .append (source )
10491051 source_name = source .name () if source else str (source )
10501052
1053+ # input[i] in dynamo is now:
1054+ # input[i + len(extra_params)] in AOT,
1055+ # where extra_params are the params/buffers that dynamo baked into the
1056+ # OutputGraph
1057+ actual_pos = pos + len (param_keys )
1058+
10511059 if "tensor_dict" in node .meta and node .meta ["tensor_dict" ].get (
10521060 "_dynamo_static_input_type" , None
10531061 ):
10541062 static_inputs_log .debug (
1055- "Adding static input pos %s for source %s" , pos , source_name
1063+ "Adding static input pos %s for source %s" , actual_pos , source_name
10561064 )
1057- static_input_indices .append (pos )
1065+ static_input_indices .append (actual_pos )
10581066 else :
10591067 static_inputs_log .debug (
1060- "Non-static input pos %s for source %s" , pos , source_name
1068+ "Non-static input pos %s for source %s" , actual_pos , source_name
10611069 )
10621070
10631071 assert full_args_num == len (aot_autograd_arg_pos_to_source )
0 commit comments