
Tracing through __getitem__ -> __len__ for ModuleList fails. #126445

Open
laithsakka opened this issue May 16, 2024 · 4 comments
Assignees
Labels
module: dynamo, oncall: pt2, triaged

Comments

@laithsakka
Contributor

laithsakka commented May 16, 2024

python test/dynamo/test_dynamic_shapes.py -k DynamicShapesNNModuleTests.test_modulelist_dynamic_shapes

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

@laithsakka
Contributor Author

@laithsakka laithsakka changed the title enable dynamo/test_dynamic_shapes.py tests with nn module inlining Tracing through get_item for ModuleList fails. May 16, 2024
@laithsakka laithsakka changed the title Tracing through get_item for ModuleList fails. Tracing through __getitem__ for ModuleList fails. May 16, 2024
@laithsakka
Contributor Author

laithsakka commented May 16, 2024

OK, I looked at this, and here is what's happening.
The module we are tracing is:

class ModuleList(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
            ]
        )

    def forward(self, x):
        for idx, layer in enumerate(self.layers[::-1]):
            # pass
            x = layer(x) * idx

        return x
1. When inlining is disabled and we trace through

       for idx, layer in enumerate(self.layers[::-1]):

   we need to call __getitem__, but this is special-handled in

       elif name == "__getitem__":

   so we do not actually trace through __getitem__ for ModuleList.

2. When inlining is enabled, we do trace through __getitem__, which ends up calling

       def __len__(self) -> int:

   and we fail to trace through the length call.
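To see how tracing __getitem__ drags in __len__, here is a minimal, torch-free sketch that mimics how ModuleList keeps its children in a string-keyed dict (the _modules attribute) and implements slicing. MiniModuleList and its method bodies are illustrative assumptions for this issue, not PyTorch's actual code:

```python
# Hedged sketch: a toy stand-in for torch.nn.ModuleList, showing why
# slicing (__getitem__ with a slice) ends up exercising __len__.

class MiniModuleList:
    def __init__(self, modules=()):
        self._modules = {}  # ModuleList stores children in a str-keyed dict
        for m in modules:
            # appending uses len(self), i.e. __len__, to pick the next key
            self._modules[str(len(self))] = m

    def __len__(self):
        return len(self._modules)  # the call Dynamo fails to trace through

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # slicing rebuilds a new list; __init__ re-enters __len__
            return MiniModuleList(list(self._modules.values())[idx])
        return self._modules[str(idx)]

ml = MiniModuleList(["linear0"])
rev = ml[::-1]          # same pattern as self.layers[::-1] in the repro
print(len(rev))         # 1
print(rev[0])           # linear0
```

The slice path rebuilds the container through __init__, which relies on __len__, so once Dynamo starts inlining __getitem__ it must also be able to trace the length call.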

@laithsakka laithsakka reopened this May 16, 2024
@laithsakka laithsakka changed the title Tracing through __getitem__ for ModuleList fails. Tracing through __getitem__ -> __len__ for ModuleList fails. May 16, 2024
@laithsakka
Contributor Author

Basically, we fail exactly here:

  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/functions.py", line 341, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 743, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 2447, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 2563, in inline_call_
    tracer.run()
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/builtin.py", line 948, in call_function
    return handler(tx, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/builtin.py", line 832, in builtin_dipatch
    rv = fn(tx, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/builtin.py", line 750, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/builtin.py", line 1347, in call_len
    return args[0].call_method(tx, "__len__", args[1:], kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/misc.py", line 694, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/variables/base.py", line 320, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/data/users/lsakka/pytorch/pytorch/torch/_dynamo/exc.py", line 216, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method GetAttrVariable(UnspecializedNNModuleVariable(ModuleList), _modules) __len__ () {}
V0516 16:12:16.275000 140569933899584 torch/_dynamo/symbolic_convert.py:2535] [0/0] INLINING <code object __len__ at 0x7fd8f216a4a0, file "/data/users/lsakka/pytorch/pytorch/torch/nn/modules/container.py", line 311>, inlined according trace_rules.lookup MOD_INLINELIST
V0516 16:12:16.275000 140569933899584 torch/_dynamo/symbolic_convert.py:769] [0/0] [__trace_source] TRACE starts_line /data/users/lsakka/pytorch/pytorch/torch/nn/modules/container.py:313 in __len__ (ModuleList.__len__) (inline depth: 5)
V0516 16:12:16.275000 140569933899584 torch/_dynamo/symbolic_convert.py:769] [0/0] [__trace_source]             return len(self._modules)
V0516 16:12:16.276000 140569933899584 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL len []
V0516 16:12:16.276000 140569933899584 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [BuiltinVariable()]
V0516 16:12:16.276000 140569933899584 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _modules [BuiltinVariable(), UnspecializedNNModuleVariable()]
V0516 16:12:16.277000 140569933899584 torch/_dynamo/symbolic_convert.py:792] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION 1 [BuiltinVariable(), GetAttrVariable()]

The call_method that raises is this one (torch/_dynamo/variables/base.py per the traceback above):
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant())
        unimplemented(f"call_method {self} {name} {args} {kwargs}")

@laithsakka
Contributor Author

I think the issue is that self._modules is a dictionary, but we represent it as a GetAttrVariable instead of a ConstDictVariable. I do not know, though, whether it is constant and whether we can use ConstDictVariable.

The type of _modules is:

    _modules: Dict[str, Module]  # type: ignore[assignment]
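A hedged, torch-free sketch of the direction this comment suggests: if _modules were modeled as a dict-aware variable (a ConstDictVariable, in Dynamo terms) whose items are known, the __len__ call could be answered as a constant, whereas an opaque attribute variable has no choice but to give up. All class and method names below are illustrative stand-ins, not Dynamo's real implementation:

```python
# Hedged sketch of the two ways Dynamo could model self._modules.

class ConstantVar:
    """Stand-in for ConstantVariable: wraps a known constant value."""
    def __init__(self, value):
        self.value = value

class DictVar:
    """Stand-in for ConstDictVariable: knows its items, so len() folds."""
    def __init__(self, items):
        self.items = dict(items)

    def call_method(self, name, args, kwargs):
        if name == "__len__" and not args and not kwargs:
            return ConstantVar(len(self.items))
        raise NotImplementedError(f"call_method {name}")

class GetAttrVar:
    """Stand-in for GetAttrVariable: opaque attribute access."""
    def call_method(self, name, args, kwargs):
        # mirrors the Unsupported("call_method ... __len__") path above
        raise NotImplementedError(f"call_method {name}")

modules = DictVar({"0": "Linear(10, 10)"})
print(modules.call_method("__len__", (), {}).value)  # 1
```

With the dict-aware representation, len(self._modules) inside ModuleList.__len__ becomes a constant at trace time; with the opaque one, we hit the unimplemented branch shown in the traceback.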

@laithsakka laithsakka removed their assignment May 16, 2024
@xmfan xmfan added the triaged label May 20, 2024

4 participants