[Distributed Checkpoint] When loading FSDP sharded checkpoints, each rank needs all the checkpoint files #125740
Comments
cc @mvpatel2000 @wanchaol
cc @LucasLLC
Thanks for creating the issue @bigning! One solution is to consider making the de-duplication of scalar tensors (or potentially even non-DTensors) optional here.
I'm not sure I can prioritize this immediately, but would welcome a PR if anyone is interested.
@LucasLLC, why not just revert to the torch 2.2.2 behavior?
@LucasLLC, is it possible to include the fix in 2.3.1? Then we wouldn't need this patch in our Composer repo.
To resolve pytorch#125740, save each tensor on the lowest rank.
Hey @bigning, as for reverting to the 2.2.2 behavior: I don't think we have time to prioritize this before 2.3.1. :/
According to the description of PR #116469, you only tested saving performance, but the change really hurts loading. After it, in the worst case for sharded FSDP state-dict loading, every node needs to download all of the checkpoint files. That's an N-times regression, where N is the number of nodes. I submitted a simple PR, #126569; could you review it?
I think it's a regression in sharded checkpoint loading; Torch 2.2 doesn't have this problem.
🐛 Describe the bug
Hi team, when loading an FSDP sharded checkpoint without changing the world size (no resharding), each rank needs checkpoint files saved by other ranks. This happens when there are scalar tensors in the optimizer state.
Below is a simple two-GPU reproduction:
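The original reproduction script is not included in the captured issue text. The following is a minimal sketch of what such a repro could look like, assuming the torch.distributed.checkpoint (DCP) save/load APIs and the get_state_dict helper from PyTorch 2.3; the checkpoint path and the toy model are placeholders, not taken from the issue.

```python
# Hedged sketch of a two-GPU reproduction; run with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

CKPT_DIR = "/tmp/fsdp_dcp_repro"  # hypothetical path


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = FSDP(torch.nn.Linear(16, 16).cuda())
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Take one step so the optimizer has state, including scalar tensors such as Adam's "step".
    model(torch.randn(4, 16, device="cuda")).sum().backward()
    optim.step()

    # Gather sharded model and optimizer state dicts and save them with DCP.
    model_sd, optim_sd = get_state_dict(model, optim)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)

    # Inspect which .distcp file each saved item was placed in.
    if rank == 0:
        metadata = dcp.FileSystemReader(CKPT_DIR).read_metadata()
        for index, info in metadata.storage_data.items():
            print(index.fqn, "->", info.relative_path)

    # Load back with the same world size (no resharding).
    dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

On a setup like the one described in the issue, the printed storage-placement mapping shows some scalar optimizer entries landing only in the file written by rank 1 (__1_0.distcp), even though rank 0's load plan also needs them.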
It printed that some of rank 0's read_items are in __1_0.distcp (output screenshot omitted). So some of the scalar tensors needed by rank 0 are only in the __1_0.distcp file, which is saved by rank 1. Imagine a multi-node training run: when loading the checkpoint, the ranks on node 0 need files saved by other nodes, which are not available on node 0. Or, with remote checkpoints stored in the cloud, each node needs to download all of the sharded checkpoint files. I think this is because, when deduping save plans (here), the planner tries to balance the storage needed on each rank, so the replicated scalar tensors are saved into different files. As a result, when loading the checkpoint, even without resharding, each rank still needs almost all of the files to load its tensors. Can we save those replicated tensors to rank 0 only?
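To illustrate the idea behind "save each tensor on the lowest rank" (the approach described in PR #126569 above), here is a small standalone sketch. It is not PyTorch's actual dedup_save_plans implementation; the function name and the item representation are hypothetical and only show the placement policy.

```python
# Illustrative sketch (not PyTorch's internal code) of de-duplicating replicated
# items by assigning each one to the lowest rank that proposed to save it,
# instead of spreading duplicates across ranks to balance storage size.
from collections import defaultdict


def dedup_to_lowest_rank(all_plans):
    """all_plans: list of per-rank plans, each a list of hashable item keys."""
    # Map each item to every rank that proposed to save it.
    ranks_for_item = defaultdict(list)
    for rank, plan in enumerate(all_plans):
        for item in plan:
            ranks_for_item[item].append(rank)

    # Keep each replicated item only on the lowest rank that has it.
    deduped = [[] for _ in all_plans]
    for item, ranks in ranks_for_item.items():
        deduped[min(ranks)].append(item)
    return deduped


# Example: both ranks propose the replicated scalar "step", plus their own shards.
print(dedup_to_lowest_rank([["step", "shard0"], ["step", "shard1"]]))
# -> [['step', 'shard0'], ['shard1']]
```

With lowest-rank placement, the replicated items all end up in rank 0's file, so when loading with an unchanged world size each rank needs at most its own shard file plus rank 0's file, rather than (in the worst case) every rank's file; the trade-off is slightly less balanced file sizes.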
Versions
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31
Python version: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k