[Distributed Checkpoint] When loading FSDP sharded checkpointing each rank needs all the checkpointing files #125740

Open
bigning opened this issue May 8, 2024 · 8 comments · May be fixed by #126569

bigning commented May 8, 2024

🐛 Describe the bug

Hi team, when loading an FSDP sharded checkpoint without changing the world size (no resharding), each rank needs to read checkpoint files saved by other ranks. This happens when there are scalar tensors in the optimizer state.

Below is a simple two-GPU repro:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed import checkpoint as dist_cp
from torch.distributed.checkpoint.planner import LoadItemType
from torch.distributed.checkpoint.state_dict import get_state_dict, StateDictOptions
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.optim import SGD

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_0 = nn.Linear(3, 3)
        self.linear_1 = nn.Linear(3, 3)
    def forward(self, x):
        return torch.sum(self.linear_1(self.linear_0(x)))

class MyReader(dist_cp.FileSystemReader):
    def read_data(self, plan, planner):
        need_other_rank_file = False
        for read_item in plan.items:
            relative_file_path = self.storage_data[read_item.storage_index].relative_path
            rank = torch.distributed.get_rank()
            if read_item.type == LoadItemType.TENSOR:
                if (rank == 0 and "__1_0" in relative_file_path) or (rank == 1 and "__0_0" in relative_file_path):
                    print(f"bigning debug rank: {torch.distributed.get_rank()}, path: {relative_file_path}, {read_item}")
                    need_other_rank_file = True
        if need_other_rank_file:
            pass
            #raise RuntimeError("Why rank 0 needs '__1_0.distcp' ?")
        return super().read_data(plan, planner)

class MySGD(SGD):
    def __init__(
        self,
        params,
        lr,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
    ):
        super().__init__(
            params=params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'step' not in state:
                    state['step'] = torch.zeros((), dtype=torch.float, device=p.device)
                state['step'] += 1
        super().step(closure)
    
def main(rank, world_size):
    ## INITIALIZE DIST
    # Running on one node so master_addr is just local host
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "28000"
    # All ranks simultaneously init the process group together.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    save_path = "./checkpoint"

    model = MyModel().to(f"cuda:{rank}")
    optimizer = MySGD(model.parameters(), lr=0.01)
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
    )
    optimizer.zero_grad()

    x = torch.rand((2, 3), device=f"cuda:{rank}")
    y = fsdp_model(x)
    y.backward()
    optimizer.step()

    # save model
    model_state_dict, opt_state_dict = get_state_dict(
        fsdp_model, 
        optimizer, 
        options=StateDictOptions(full_state_dict=False),
    )
    state_dict = {
        "model": model_state_dict,
        "optimizer": opt_state_dict,
    }
    dist_cp.save(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(save_path)
    )
    print(f"bigning debug saving done")

    # load
    dist_cp.load(
        state_dict=state_dict,
        storage_reader=MyReader(save_path),
    )

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2 
    mp.spawn(
        main,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

It printed that some of rank 0's read_items are in __1_0.distcp. Output screenshot:
[screenshot of the "bigning debug" prints showing rank 0 read_items resolving to __1_0.distcp]

So some of the scalar tensors needed by rank 0 are only in the __1_0.distcp file, which was saved by rank 1. Imagine multi-node training: when loading the checkpoint, ranks on node 0 need files saved by other nodes, which are not available on node 0. Or, with remote checkpoints saved in the cloud, each node needs to download all of the sharded checkpoint files.

I think it's because when deduping save plans here, it tries to balance the storage written by each rank, so those replicated scalar tensors get saved into different files. So when loading the checkpoint, even without resharding, each rank still needs almost all of the files to load its tensors. Can we save those replicated tensors to rank 0 only?
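
For illustration, here is a minimal sketch of what "keep each duplicated item on the lowest rank only" could look like (the helper name dedup_to_lowest_rank is made up, and this is not the actual PyTorch implementation, just the behavior I have in mind):

import dataclasses
from collections import defaultdict
from typing import Dict, List, Set

from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan

def dedup_to_lowest_rank(all_plans: List[SavePlan]) -> List[SavePlan]:
    # Map each write-item key to every plan (rank) that wants to write it.
    key_to_ranks: Dict[MetadataIndex, List[int]] = defaultdict(list)
    for rank, plan in enumerate(all_plans):
        for item in plan.items:
            key_to_ranks[item.index].append(rank)

    # For duplicated items, only the lowest rank keeps the write; the others drop it.
    to_drop: Dict[int, Set[MetadataIndex]] = defaultdict(set)
    for key, ranks in key_to_ranks.items():
        for rank in sorted(ranks)[1:]:
            to_drop[rank].add(key)

    return [
        dataclasses.replace(
            plan,
            items=[item for item in plan.items if item.index not in to_drop[rank]],
        )
        for rank, plan in enumerate(all_plans)
    ]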

Versions

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


bigning commented May 8, 2024

cc @mvpatel2000, @wanchaol

@bdhirsh added the oncall: distributed label on May 8, 2024

bigning commented May 9, 2024

cc @LucasLLC

@LucasLLC self-assigned this on May 10, 2024
@LucasLLC added the triaged label on May 10, 2024
LucasLLC commented:

Thanks for creating the issue @bigning !

One solution is to consider making the de-duplication of scalar tensors (or potentially even non-DTensors) optional here:

def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]:

I'm not sure I can prioritize this immediately, but would welcome a PR if anyone is interested.
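
A hypothetical sketch of what that opt-out could look like (the dedup_replicated_tensors flag and the wrapper below are illustrative only, not an existing API; it assumes the private helper torch.distributed.checkpoint._dedup_save_plans.dedup_save_plans is importable):

from typing import List

from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint.planner import SavePlan

def dedup_save_plans_optional(
    all_plans: List[SavePlan],
    dedup_replicated_tensors: bool = True,  # hypothetical flag
) -> List[SavePlan]:
    if dedup_replicated_tensors:
        # Current behavior: duplicates are deduped and load-balanced across ranks/files.
        return dedup_save_plans(all_plans)
    # Opt-out: leave every plan untouched, so replicated items (e.g. the scalar
    # optimizer 'step' tensors) are written by every rank and can be loaded from
    # that rank's own file, at the cost of some duplicated storage.
    return all_plans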


bigning commented May 10, 2024

@LucasLLC, why not just revert to the torch 2.2.2 dedup_tensors? https://github.com/pytorch/pytorch/blob/v2.2.2/torch/distributed/checkpoint/_dedup_tensors.py#L29-L58 . It still does the dedup, but puts every duplicated tensor on the lowest rank.


bigning commented May 15, 2024

@LucasLLC, is it possible to include the fix in 2.3.1? That way we wouldn't need to carry this patch in our Composer repo.

bigning added a commit to bigning/pytorch that referenced this issue May 17, 2024
To resolve pytorch#125740, save each tensor on the lowest rank.
@bigning linked a pull request on May 17, 2024 that will close this issue
LucasLLC commented:

Hey @bigning, reverting to dedup_tensors would remove the performance gains we get from load balancing tensors during serialization. Additionally, I think we're viewing this as an additional request instead of a bug.

I don't think we have time to prioritize this before 2.3.1. :/


bigning commented May 17, 2024

"performance gains we get from load balancing tensors during serialization"

According to the description of PR #116469, only the save performance was tested, but the change really hurts loading. After your change, in the worst case for sharded FSDP state-dict loading, every node needs to download all the checkpoint files. That's an N× regression, where N is the number of nodes.

I submitted a simple PR, #126569. Could you review it?


bigning commented May 17, 2024

"I think we're viewing this as an additional request instead of a bug."

I think it's a regression in sharded checkpoint loading; Torch 2.2 doesn't have this problem.
