DMP doesn't broadcast DataParallel ShardingType embedding table from the process with rank 0 to all other processes #1739

Open · tiankongdeguiji opened this issue Mar 1, 2024 · 13 comments

@tiankongdeguiji

DMP should broadcast the parameters of DataParallel ShardingType embedding tables from the rank-0 process to all other processes in the group, to make sure that all model replicas start from the exact same state.

However, DMP currently adds all sharded_parameter_names to the params_and_buffers_to_ignore list for Distributed Data Parallel (DDP). As a result, DMP skips the synchronization of the DataParallel ShardingType embedding table parameters during initialization, so model replicas may start from different states, which can lead to inconsistent training outcomes and compromise model convergence.

class DefaultDataParallelWrapper(DataParallelWrapper):
    ...

    def wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
    ) -> None:
        if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance(
            dmp._dmp_wrapped_module, FullyShardedDataParallel
        ):
            return
        sharded_parameter_names = set(
            DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
        )
        self._ddp_wrap(dmp, env, device, sharded_parameter_names)

    ...

class DistributedModelParallel(nn.Module, FusedOptimizerModule):
    @staticmethod
    def _sharded_parameter_names(module: nn.Module, prefix: str = "") -> Iterator[str]:
        module = get_unwrapped_module(module)
        if isinstance(module, ShardedModule):
            yield from module.sharded_parameter_names(prefix)
        else:
            for name, child in module.named_children():
                yield from DistributedModelParallel._sharded_parameter_names(
                    child, append_prefix(prefix, name)
                )
@IvanKobzarev (Contributor)

@tiankongdeguiji
DDP tensors are replicated in the state_dict that should be loaded on each rank.
So no comm_ops are necessary to broadcast them.

@tiankongdeguiji (Author)

@IvanKobzarev
In TorchRec DMP, all parameters of a ShardedModule (including those with ShardingType == DataParallel) are added to the params_and_buffers_to_ignore list for DDP. This prevents DDP from broadcasting these parameters. However, for correct behavior, the parameters of a ShardedModule with ShardingType equal to DataParallel do need to be broadcast.

// torchrec/distributed/model_parallel.py
class DefaultDataParallelWrapper(DataParallelWrapper):
    ...

    def _ddp_wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
        ddp_ignore_param_names: Set[str],
    ) -> None:
        pg = env.process_group
        if pg is None:
            raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv")
        all_parameter_names = set(dict(dmp.named_parameters()).keys())
        if len(all_parameter_names - ddp_ignore_param_names) == 0:
            return
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=dmp._dmp_wrapped_module,
            params_and_buffers_to_ignore=ddp_ignore_param_names,
        )
        # initialize DDP
        dmp._dmp_wrapped_module = cast(
            nn.Module,
            DistributedDataParallel(
                module=dmp._dmp_wrapped_module.to(device),
                device_ids=None if device.type == "cpu" else [device],
                process_group=pg,
                gradient_as_bucket_view=True,
                broadcast_buffers=False,
                static_graph=self._static_graph,
                find_unused_parameters=self._find_unused_parameters,
                bucket_cap_mb=self._bucket_cap_mb,
            ),
        )
        if self._allreduce_comm_precision == "fp16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.fp16_compress_hook
            )
        elif self._allreduce_comm_precision == "bf16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.bf16_compress_hook
            )

    def wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
    ) -> None:
        if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance(
            dmp._dmp_wrapped_module, FullyShardedDataParallel
        ):
            return
        sharded_parameter_names = set(
            DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
        )
        self._ddp_wrap(dmp, env, device, sharded_parameter_names)
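
The divergence this causes can be reproduced outside TorchRec with a minimal sketch (assumptions: an initialized two-rank process group launched via torchrun, and a hypothetical toy module; this is not TorchRec code). Parameters placed on DDP's ignore list are neither broadcast from rank 0 at wrap time nor all-reduced during backward, which is exactly what happens to the DataParallel tables here:

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group("gloo")  # assumes torchrun with 2 processes
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))

# Mirror what DMP does for sharded_parameter_names: put "0.weight" on DDP's ignore list.
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
    module=toy, params_and_buffers_to_ignore={"0.weight"}
)
ddp_toy = DistributedDataParallel(toy)  # "0.weight" keeps its per-rank random init

# The ignored parameter will generally differ across ranks.
gathered = [torch.zeros_like(toy[0].weight) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, toy[0].weight.detach())
if dist.get_rank() == 0:
    print("0.weight identical across ranks:", gathered[0].equal(gathered[1]))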

@henrylhtsang (Contributor)

@tiankongdeguiji

I suspect this isn't really a problem. I tested it with the NCCL model_parallel test test_sharding_dp by printing out the state_dict, and found the parameters to be the same on every rank.

I suspect it is working due to this hack
https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embeddingbag.py#L499-L500

@colin2328 (Contributor)

@henrylhtsang yes, that is where DDP modules are set up (using actual DDP) to make these data_parallel tables call all_reduce to get the correct gradients.
Why do you call this part a hack?

@tiankongdeguiji , is your concern that all_reduce won't be called during training, or that restoring from a checkpoint will be incorrect?

@tiankongdeguiji (Author)

@henrylhtsang

In the test_sharding_dp function, the state_dict of global_model is copied to local_model, and the global_model is initialized using torch.manual_seed(0). This ensures that all model replicas start from identical initial parameters.

https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/test_utils/test_sharding.py#L376

However, in typical training this state_dict copy is not performed at startup, so the DataParallel ShardingType embedding table parameters in different replicas would start from differing states.

@henrylhtsang (Contributor) commented Mar 18, 2024

@tiankongdeguiji can you try to inspect the state dict of the local model right after local_model = DistributedModelParallel(...), i.e. before the copy_state_dict?

When I ran it, it showed those parameters are the same.

update: I just tested it again. Before DMP, the table weights are on the meta device. After DMP, they are on CUDA and are the same.

@tiankongdeguiji (Author)

@henrylhtsang

Can you run torchrun --master_addr=localhost --master_port=54926 --nnodes=1 --nproc-per-node=2 --node_rank=0 debug_dp_shard.py in an environment with torchrec==0.6.0+cu121, torch==2.2.0+cu121, and fbgemm-gpu==0.6.0+cu121?

debug_dp_shard.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.DATA_PARALLEL,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()
print("world_size:", world_size)

model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.DATA_PARALLEL)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])
print(f"rank:{rank},sharding plan: {plan}")

batch_size = 64
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat(
        [
            torch.randint(0, 4096, (batch_size * 2,)),
            torch.randint(0, 1023, (batch_size * 2,)),
        ]
    ),
    lengths=torch.ones(batch_size * (small_table_cnt + large_table_cnt), dtype=torch.int32),
).to(device=device)
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()

dist.barrier()
for k, v in sharded_model.named_parameters():
    if 'ebc' in k:
        t_list = [torch.zeros_like(v) for _ in range(world_size)]
        dist.all_gather(t_list, v)
        if rank == 0:
            print(k, t_list[0].equal(t_list[1]))

It prints the following log, which indicates the ebc parameters are not the same across ranks:

ebc.embedding_bags.large_table_0.weight False
ebc.embedding_bags.large_table_1.weight False
ebc.embedding_bags.small_table_0.weight False
ebc.embedding_bags.small_table_1.weight False

@henrylhtsang (Contributor) commented Mar 22, 2024

@tiankongdeguiji Okay I think you are 100% right. Sorry I didn't understand your point on the torch seed part.

I looked into it. A few points:

  1. The problem isn't with training. The parameters are already different at initialization.
  2. The "hack" is working perfectly. If you print self._lookups[index].state_dict() right after that point, the weights are the same on every rank.
  3. I think the problem is self._initialize_torch_state(), which initializes the tables separately on each rank. I tested putting torch.manual_seed(0) before the self._initialize_torch_state() line, and your code then prints 4 Trues (see the sketch at the end of this comment). Please test this and report back to see if it works.

Followups:

  1. Fix the problem.
  2. Remove torch.manual_seed in tests

cc @PaulZhang12 @colin2328
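
A user-level sketch of the point-3 suggestion (an assumption about where seeding helps, not the actual torchrec change; it reuses model, plan, sharders, and device from the repro script above): seeding the RNG identically on every rank right before constructing DMP should make the per-rank initialization inside _initialize_torch_state() produce identical DATA_PARALLEL tables, provided that initialization consumes the RNG the same way on every rank.

import torch
from torchrec.distributed.model_parallel import DistributedModelParallel

torch.manual_seed(0)  # same seed on every rank before DMP initializes the DP tables
sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)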

@tiankongdeguiji (Author)

@henrylhtsang Yes, it works. However, I think it's better to broadcast these parameters from rank 0 to the other ranks, as DDP does.

@henrylhtsang (Contributor)

@tiankongdeguiji fyi I raised the issue to the team already. Probably will wait a bit.

On the other hand, any suggestions on how to fix this in a nice way? Maybe remove the names from _sharded_parameter_names?

@tiankongdeguiji (Author)

@henrylhtsang We cannot remove the names from _sharded_parameter_names, because the dist.Reducer inside DDP cannot manage the parameters of the DataParallel ShardingType embedding tables. At present, I invoke dist._broadcast_coalesced on these parameters after DMP is constructed.
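
A minimal sketch of that workaround (it reuses sharded_model from the repro script above and uses the public dist.broadcast instead of the private dist._broadcast_coalesced; the "ebc" name filter only works because every ebc table in that repro is DATA_PARALLEL sharded):

import torch
import torch.distributed as dist

# After DMP is built, broadcast the DATA_PARALLEL table weights from rank 0
# so every replica starts training from the same values.
with torch.no_grad():
    for name, param in sharded_model.named_parameters():
        if "ebc" in name:  # the DATA_PARALLEL tables in this particular repro
            dist.broadcast(param, src=0)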

@henrylhtsang (Contributor)

@tiankongdeguiji fyi landed the fix cc482f8

Basically, every time we call reset_parameters, we will also broadcast the re-initialized DP tables from rank 0 to all other ranks.

Be sure not to forget to apply the optimizer to the DP tables, though.
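
A hedged illustration of the described behavior (not the actual code from cc482f8; the module and names below are hypothetical): a DATA_PARALLEL table whose reset_parameters broadcasts the freshly re-initialized weights from rank 0, so every call leaves all ranks in the same state.

import torch
import torch.distributed as dist
from torch import nn

class DataParallelTable(nn.Module):
    """Hypothetical stand-in for a DATA_PARALLEL embedding table."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.uniform_(self.weight, -0.01, 0.01)  # per-rank re-initialization
        if dist.is_available() and dist.is_initialized():
            # Keep replicas identical: rank 0's re-initialized weights win.
            with torch.no_grad():
                dist.broadcast(self.weight, src=0)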
