Update _dedup_save_plans.py #126569

Open · bigning wants to merge 7 commits into main
Conversation

@bigning (Contributor) commented May 17, 2024

To resolve pytorch#125740, save each duplicated tensor on the lowest rank that holds it.
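For context, here is a toy, self-contained sketch of that idea (illustrative only; it is not the actual dedup_save_plans implementation): when a write item is duplicated across several ranks' save plans, keep it only on the lowest-rank plan, so a loader always knows which single file holds it.

# Toy sketch of lowest-rank dedup (not the real torch.distributed.checkpoint code).
from collections import defaultdict
from typing import Dict, List, Set

def dedup_to_lowest_rank(plans: List[List[str]]) -> List[List[str]]:
    """plans[rank] is the list of item keys that rank intends to save."""
    item_to_ranks: Dict[str, Set[int]] = defaultdict(set)
    for rank, items in enumerate(plans):
        for item in items:
            item_to_ranks[item].add(rank)

    deduped: List[List[str]] = [[] for _ in plans]
    for item, ranks in item_to_ranks.items():
        deduped[min(ranks)].append(item)  # the lowest rank keeps the duplicate
    return deduped

# Ranks 0-3 all hold a replicated item "step"; after dedup only rank 0 saves it,
# so a loader never has to fetch another rank's file just for "step".
print(dedup_to_lowest_rank([["step", "shard0"], ["step", "shard1"],
                            ["step", "shard2"], ["step", "shard3"]]))
# [['step', 'shard0'], ['shard1'], ['shard2'], ['shard3']]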
@pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels May 17, 2024
pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126569

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ 1 Pending, 1 Unrelated Failure

As of commit c102e7f with merge base 64c581a:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@LucasLLC (Contributor) left a comment

At minimum, we would need this to be optional, and it should not replace the current deduplication logic. The current logic is a storage optimization, and this change would cause a regression when calling dcp.save on models that are replica heavy.

I think in the originally linked issue we considered not applying this logic to scalars, as a fix for minimizing the number of files that need to be loaded.

@bigning (Contributor, Author) commented May 17, 2024

The current logic is a storage optimization, and this would cause a regression in terms of calling dcp.save on models which are replica heavy.

It's an "optimization" only in terms of balancing the save load across ranks, but it hurts loading performance in the multi-node case.

we considered not applying this logic to scalars as a fix

I don't think that works. The root cause is that duplicated tensors are saved in different files, whether they are scalar tensors or not.

@bigning requested a review from LucasLLC May 17, 2024 20:20
@bigning (Contributor, Author) commented May 17, 2024

@LucasLLC, I replied to your comment. Could you please take a look?

@LucasLLC (Contributor) commented May 20, 2024

@bigning, I believe this generally isn't seen as a large issue during loading, since the files are all expected to live in the same NFS directory. Additionally, I think we would prioritize saving latency over loading latency, at least in this case, since users typically save much more often. Could we change this PR to make the de-duplication optional (and true by default)?

# essentially ignores the storage size of anything that is not a tensor, since
# we don't know how much storage they represent
plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1
select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_idx)
A collaborator left a review comment suggesting a change:

Suggested change
- select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_idx)
+ select_plan_idx = min(plan_indices)
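(Since the key in the original line is the identity function, min(plan_indices, key=lambda plan_idx: plan_idx) and min(plan_indices) both return the lowest plan index, i.e. the lowest rank; the suggestion only drops the redundant key argument.)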

@bigning (Contributor, Author) commented May 20, 2024

@LucasLLC

make the de-duplication optional

If we skip the dedup, it fails here https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/default_planner.py#L386-L387.

files are all expected to live in the same NFS directory.

For NFS, the issue is not whether the file exists on the NFS. It's that each node needs to download Nx more files. With cloud storage, this introduces:

  1. Nx the cloud storage cost for downloading more data.
  2. Nx more time blocking training. Saving can be async, but loading a checkpoint usually blocks training.
  3. Downloading large files from the cloud is not stable, which is a real pain point of using cloud storage. The more files you download, the higher the chance that the checkpoint download fails.

Can I just add a save_to_same_rank param to dedup_save_plans and DefaultSavePlanner?
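As a hedged sketch of what such a flag could look like (the parameter name save_to_same_rank comes from this comment, and the toy representation reuses the earlier example; neither is the merged API):

# Sketch of the proposed configurable dedup behavior (illustrative only).
from collections import defaultdict
from typing import Dict, List, Set

def dedup_plans(
    plans: List[List[str]],
    item_sizes: Dict[str, int],
    save_to_same_rank: bool = False,
) -> List[List[str]]:
    item_to_ranks: Dict[str, Set[int]] = defaultdict(set)
    for rank, items in enumerate(plans):
        for item in items:
            item_to_ranks[item].add(rank)

    deduped: List[List[str]] = [[] for _ in plans]
    rank_to_size = [0] * len(plans)
    for item, ranks in item_to_ranks.items():
        if save_to_same_rank:
            # Duplicated items always land on the lowest participating rank,
            # so each node only needs to fetch a small, predictable set of files.
            owner = min(ranks)
        else:
            # Existing behavior: balance accumulated storage size across ranks.
            owner = min(ranks, key=lambda r: rank_to_size[r])
        deduped[owner].append(item)
        rank_to_size[owner] += item_sizes.get(item, 1)
    return deduped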

@drisspg added the triaged label May 20, 2024
@LucasLLC (Contributor) commented

Can I just add a save_to_same_rank param to dedup_save_plans and DefaultSavePlanner

@bigning I understand the pain points of having to download multiple files. I think we would accept a PR that makes this behavior configurable.

@bigning (Contributor, Author) commented May 21, 2024

makes this behavior configurable.

@LucasLLC, this makes sense. I added a param to dedup_save_plans and DefaultSavePlanner; could you take a look? Thanks!
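For readers following along, usage would look roughly like the sketch below; the flag name dedup_save_to_lowest_rank is an assumption based on this thread, so check the merged DefaultSavePlanner signature before relying on it.

# Rough usage sketch; the planner flag name is assumed, not confirmed by this thread.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

# In practice this would be the (sharded) model/optimizer state on each rank.
state_dict = {"step": torch.tensor(100), "weight": torch.randn(4, 4)}

planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)  # assumed flag name
dcp.save(state_dict, checkpoint_id="/tmp/ckpt", planner=planner)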

@LucasLLC (Contributor) commented

Thanks @bigning! This looks good to me. I will merge once tests pass.

@bigning (Contributor, Author) commented May 23, 2024

Thanks @LucasLLC, it seems two lints failed. I can't find a useful error message; do you know how to fix them or how to re-run?

@bigning (Contributor, Author) commented May 23, 2024

Thanks @LucasLLC, it seems two lints failed. I can't find a useful error message; do you know how to fix them or how to re-run?

Never mind, I just submitted another commit. Looks like all tests are green now.

@bigning (Contributor, Author) commented May 29, 2024

@LucasLLC, can you help merge?

@LucasLLC (Contributor) commented

@pytorchbot merge

pytorch-bot bot commented May 31, 2024

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@bigning (Contributor, Author) commented May 31, 2024

@pytorchbot merge

pytorch-bot bot commented May 31, 2024

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@LucasLLC added the topic: not user facing label May 31, 2024
@LucasLLC (Contributor) commented

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label May 31, 2024
@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-13-py3-arm64 / build

Details for Dev Infra team: raised by workflow job.

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), module: distributed_checkpoint, oncall: distributed (Add this issue/PR to distributed oncall triage queue), open source, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

[Distributed Checkpoint] When loading FSDP sharded checkpointing each rank needs all the checkpointing files
6 participants