Remove activation checkpointing tag to get correct FQNs (#124698) #126559

Merged 1 commit into pytorch:release/2.3 on May 24, 2024

Conversation

@mvpatel2000 (Contributor) commented on May 17, 2024

Cherry-pick for release branch



cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC

Fixes pytorch#124546

When setting `use_orig_params = False` and using activation checkpointing, the FQN mapping retrieved by the `_get_fqns` function is incorrect because the prefix that is added to the name of each activation-checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case with the `_get_fqns` function that was not addressed by the previous commit pytorch#118119.
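
For context, here is a minimal sketch of a setup that exercises this path; it uses `get_model_state_dict`, which resolves names through `_get_fqns`. The tiny `Block` module and the single-rank gloo process group are illustrative assumptions, not the configuration from the linked issue:
```python
import os
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.checkpoint.state_dict import get_model_state_dict

# Single-rank process group so the sketch runs without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class Block(nn.Module):  # stand-in for a transformer block
    def __init__(self):
        super().__init__()
        self.norm_1 = nn.LayerNorm(16)
        self.ffn = nn.Linear(16, 16)

    def forward(self, x):
        return self.ffn(self.norm_1(x))

model = nn.Sequential(Block(), Block())
# Wrap each Block; this adds the `_checkpoint_wrapped_module` prefix to names.
apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, Block))
# With use_orig_params=False, parameters are stored in FSDP `_flat_param`s.
model = FSDP(model, use_orig_params=False)
# Resolves FQNs via `_get_fqns`; before the fix, the leftover checkpoint
# prefix could make this mapping wrong.
print(sorted(get_model_state_dict(model).keys()))
dist.destroy_process_group()
```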

Without the change, the list of object names for an activation checkpointed module with FSDP (and `use_orig_params=False`) can be something like:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
```
This incorrectly returns just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when all the FQNs of the parameters of the transformer block should be returned.

With the change, the list of object names will now have `_checkpoint_wrapped_module` removed:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param']
```
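
In spirit, the fix filters the activation-checkpoint wrapper name out of the object-name path before FQNs are assembled, the same way the FSDP wrapper name is handled. A minimal sketch, assuming a hypothetical helper `strip_checkpoint_prefix` (not the exact code added in this PR):
```python
# Illustrative only: drop activation-checkpointing wrapper entries from a
# split dotted path so FQN resolution sees the unwrapped module names.
_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"

def strip_checkpoint_prefix(obj_names):
    """Remove `_checkpoint_wrapped_module` entries from an object-name list."""
    return [name for name in obj_names if name != _CHECKPOINT_WRAPPED_MODULE]

names = ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
         '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
assert strip_checkpoint_prefix(names) == [
    'model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
    '_fsdp_wrapped_module', '_flat_param',
]
```
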
The FQNs are then correctly retrieved and returned in `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are:
```
{'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias',
'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight',
'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight',
'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight',
'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias',
'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'}
```
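
For a sense of the expansion that follows: the single `_flat_param` entry stands in for every original parameter it flattens, so the set above is essentially the block's module prefix joined with each per-parameter name. Here `expand_flat_param` is a hypothetical stand-in for the FSDP bookkeeping that `_get_fqns` consults, not a real API:
```python
# Hypothetical sketch: expand a cleaned path ending in '_flat_param' into the
# FQNs of the original parameters it shards. Names mirror the example above.
def expand_flat_param(prefix, param_names):
    return {f"{prefix}.{name}" for name in param_names}

block_params = [
    'attn.Wqkv.weight', 'attn.Wqkv.bias',
    'attn.out_proj.weight', 'attn.out_proj.bias',
    'norm_1.weight', 'norm_1.bias',
    'norm_2.weight', 'norm_2.bias',
    'ffn.up_proj.weight', 'ffn.up_proj.bias',
    'ffn.down_proj.weight', 'ffn.down_proj.bias',
]
fqns = expand_flat_param('model.transformer.blocks.0', block_params)
assert 'model.transformer.blocks.0.attn.Wqkv.weight' in fqns
```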

Pull Request resolved: pytorch#124698
Approved by: https://github.com/Skylion007
pytorch-bot added the module: distributed_checkpoint and oncall: distributed labels on May 17, 2024

pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126559

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e4fe159 with merge base 86a2d67:

FLAKY - One job failed, but this was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorchmergebot pushed a commit that referenced this pull request on May 23, 2024:

[DCP][AC] Add test for apply AC with FSDP1 (#126935)

Adding test for this cherry pick. #126559

Pull Request resolved: #126935
Approved by: https://github.com/fegin
wz337 added three commits to wz337/pytorch that referenced this pull request on May 23, 2024
atalman merged commit 19058a6 into pytorch:release/2.3 on May 24, 2024 (97 of 98 checks passed)
atalman pushed a commit that referenced this pull request on May 24, 2024:

[DCP][AC] Add test for apply AC with FSDP1 (#126935) (#126992)

Adding test for this cherry pick. #126559

Pull Request resolved: #126935
Approved by: https://github.com/fegin