Automatic module mapping using torch.fx #40

Open
xrsrke opened this issue Nov 22, 2023 · 3 comments
Labels: good first issue (Good for newcomers)

@xrsrke (Owner) commented Nov 22, 2023

Notes

  • Location: pipegoose.nn.parallel_mapping.ParallelMapping
  • module is a module yielded by model.named_modules()
  • model is, e.g., AutoModelForCausalLM.from_pretrained(...) or torch.nn.Transformer
  • Make it generalizable

APIs

model = AutoModel.from_pretrained(...)
ParallelMapping(model).is_mlp(name, module)

Write functions such that (a rough sketch follows after this list):

  • ParallelMapping(model).is_column_parallel(name, module) returns True if the module is the first linear layer of an MLP block; a query, key, or value linear (or a fused QKV linear) of an attention layer; or an input embedding. Otherwise it returns False.
  • ParallelMapping(model).is_row_parallel(name, module) returns True if the module is the second linear layer of an MLP block, or the output projection of an attention layer.
  • ParallelMapping(model).is_lm_head(name, module) returns True if the module is the language model head.
  • ParallelMapping(model).is_text_embedding(name, module) returns True if the module is the text embedding module.
  • ParallelMapping(model).is_mlp(name, module) returns True if the module is an MLP block.
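
As a rough sketch only, here is one way these methods could be backed by a purely name-based heuristic over model.named_modules(). The substring hints (query_key_value, dense_4h_to_h, etc.) are illustrative assumptions borrowed from BLOOM-style naming, not the final implementation; the torch.fx-based mapping this issue asks for would derive the same information from the traced graph instead of hard-coded names.

import torch.nn as nn

class ParallelMapping:
    # Illustrative, name-based heuristic. A torch.fx-based version would inspect
    # the traced graph (e.g. which linear feeds which) rather than match substrings.
    COLUMN_HINTS = ("dense_h_to_4h", "query_key_value", "q_proj", "k_proj", "v_proj", "fc1")
    ROW_HINTS = ("dense_4h_to_h", "o_proj", "self_attention.dense", "fc2")

    def __init__(self, model: nn.Module):
        self.model = model

    def is_column_parallel(self, name: str, module: nn.Module) -> bool:
        # First MLP linear, Q/K/V or fused QKV linear, or an input embedding.
        if isinstance(module, nn.Embedding):
            return True
        return isinstance(module, nn.Linear) and any(h in name for h in self.COLUMN_HINTS)

    def is_row_parallel(self, name: str, module: nn.Module) -> bool:
        # Second MLP linear or the attention output projection.
        return isinstance(module, nn.Linear) and any(h in name for h in self.ROW_HINTS)

    def is_lm_head(self, name: str, module: nn.Module) -> bool:
        return name.endswith("lm_head")

    def is_text_embedding(self, name: str, module: nn.Module) -> bool:
        return isinstance(module, nn.Embedding) and "embed" in name

    def is_mlp(self, name: str, module: nn.Module) -> bool:
        return name.endswith("mlp") or ".mlp." in name

The point of the sketch is only the shape of the API; removing the hard-coded hints is exactly what the automatic, torch.fx-driven mapping should achieve.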
xrsrke added the help wanted (Extra attention is needed) and good first issue (Good for newcomers) labels Nov 22, 2023
xrsrke added this to the v1 milestone Nov 22, 2023
xrsrke removed the help wanted (Extra attention is needed) label Nov 27, 2023
@giorgionicoli commented

I'm taking a stab at this issue. I'll first come up with a solution that works for at least bloom-560m, which is already mapped in ParallelMapping, so that we can check whether the automatic mapping works. Then we can work on making it more general.

xrsrke added the help wanted (Extra attention is needed) label Dec 10, 2023
@yugen-ok commented Dec 10, 2023

I'm working on this issue.

For testing, could I see some MWEs of what the inputs to these methods look like? That is, typical inputs, so I can make sure I get the right results.

@xrsrke (Owner, Author) commented Dec 10, 2023

@yugen-ok

from transformers import AutoModel
from pipegoose.nn.parallel_mapping import ParallelMapping

model = AutoModel.from_pretrained(...)
parallel_mapping = ParallelMapping(model)
ref_mapping = {...}  # expected mapping, written out by hand

mapping = {}
for name, module in model.named_modules():
    mapping[name] = parallel_mapping.is_row_parallel(name, module)

assert ref_mapping == mapping
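
For concreteness, an excerpt of what ref_mapping could look like for bigscience/bloom-560m when checking is_row_parallel. The module names follow Hugging Face's BloomModel naming; the expected values here are illustrative and would need to be hand-checked against the reference mapping, not taken as given.

# Hypothetical excerpt of ref_mapping for is_row_parallel on bigscience/bloom-560m.
ref_mapping = {
    "word_embeddings": False,
    "h.0.self_attention.query_key_value": False,  # fused QKV -> column parallel, not row
    "h.0.self_attention.dense": True,             # attention output projection -> row parallel
    "h.0.mlp.dense_h_to_4h": False,               # first MLP linear -> column parallel
    "h.0.mlp.dense_4h_to_h": True,                # second MLP linear -> row parallel
    # ... one entry per name yielded by model.named_modules()
}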

xrsrke removed the help wanted (Extra attention is needed) label Dec 11, 2023
yugen-ok mentioned this issue Dec 13, 2023