DMP doesn't broadcast DataParallel ShardingType embedding table from the process with rank 0 to all other processes #1739
@IvanKobzarev
// torchrec/distributed/model_parallel.py
class DefaultDataParallelWrapper(DataParallelWrapper):
    ...
    def _ddp_wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
        ddp_ignore_param_names: Set[str],
    ) -> None:
        pg = env.process_group
        if pg is None:
            raise RuntimeError("Can only init DDP for ProcessGroup-based ShardingEnv")
        all_parameter_names = set(dict(dmp.named_parameters()).keys())
        if len(all_parameter_names - ddp_ignore_param_names) == 0:
            return
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=dmp._dmp_wrapped_module,
            params_and_buffers_to_ignore=ddp_ignore_param_names,
        )
        # initialize DDP
        dmp._dmp_wrapped_module = cast(
            nn.Module,
            DistributedDataParallel(
                module=dmp._dmp_wrapped_module.to(device),
                device_ids=None if device.type == "cpu" else [device],
                process_group=pg,
                gradient_as_bucket_view=True,
                broadcast_buffers=False,
                static_graph=self._static_graph,
                find_unused_parameters=self._find_unused_parameters,
                bucket_cap_mb=self._bucket_cap_mb,
            ),
        )
        if self._allreduce_comm_precision == "fp16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.fp16_compress_hook
            )
        elif self._allreduce_comm_precision == "bf16":
            dmp._dmp_wrapped_module.register_comm_hook(
                None, ddp_default_hooks.bf16_compress_hook
            )

    def wrap(
        self,
        dmp: "DistributedModelParallel",
        env: ShardingEnv,
        device: torch.device,
    ) -> None:
        if isinstance(dmp._dmp_wrapped_module, DistributedDataParallel) or isinstance(
            dmp._dmp_wrapped_module, FullyShardedDataParallel
        ):
            return
        sharded_parameter_names = set(
            DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
        )
        self._ddp_wrap(dmp, env, device, sharded_parameter_names)
I suspect this isn't really a problem. I tested it with the NCCL model_parallel test_sharding_dp by printing the state_dict out, and found the state dicts to be the same across ranks. I suspect it is working due to this hack.
@henrylhtsang yes, that is where the DDP modules are set up (using actual DDP) so that these data_parallel tables call all_reduce and get the correct gradients. @tiankongdeguiji, is your concern that all_reduce won't be called during training, or that restoring from a checkpoint will be incorrect?
In https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/test_utils/test_sharding.py#L376 ... However, it is important to note that without performing this ...
@tiankongdeguiji can you try to inspect the state dict of the local model right after DMP wrapping? When I ran it, it showed those parameters are the same. Update: I just tested it again. Before DMP, the table weights are on the meta device. After DMP, they are on cuda and are the same.
Can you run debug_dp_shard.py?

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.DATA_PARALLEL,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()
print("world_size:", world_size)

model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.DATA_PARALLEL)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])
print(f"rank:{rank},sharding plan: {plan}")

batch_size = 64
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat([
        torch.randint(0, 4096, (batch_size * 2,)),
        torch.randint(0, 1023, (batch_size * 2,)),
    ]),
    lengths=torch.ones(batch_size * (small_table_cnt + large_table_cnt), dtype=torch.int32),
).to(device=device)

losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()

dist.barrier()
for k, v in sharded_model.named_parameters():
    if 'ebc' in k:
        t_list = [torch.zeros_like(v) for _ in range(world_size)]
        dist.all_gather(t_list, v)
        if rank == 0:
            print(k, t_list[0].equal(t_list[1]))

It prints output indicating that the DATA_PARALLEL embedding table parameters are not identical across ranks.
@tiankongdeguiji Okay I think you are 100% right. Sorry I didn't understand your point on the torch seed part. I looked into it. A few points:
Followups:
@henrylhtsang Yes, it works. However, I think it's better to broadcast these parameters during DMP initialization.
@tiankongdeguiji fyi I raised the issue to the team already. Probably will wait a bit. On the other hand, any suggestions on how to fix this in a nice way? Maybe remove the names from _sharded_parameter_names?
@henrylhtsang We are unable to remove the names from _sharded_parameter_names, because the dist.Reducer inside DDP cannot manage the parameters of a DataParallel ShardingType embedding table. At present, I invoke dist._broadcast_coalesced for these parameters after DMP initialization.
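For reference, a minimal sketch of that workaround, using the public dist.broadcast API in place of dist._broadcast_coalesced. It assumes sharded_model is the DistributedModelParallel instance from the script above, and the "ebc" name filter is specific to that example model:

import torch
import torch.distributed as dist

# Broadcast rank 0's DATA_PARALLEL table parameters to all other ranks so that
# every replica starts from the same state.
with torch.no_grad():
    for name, param in sharded_model.named_parameters():
        if "ebc" in name:  # the embedding tables that DDP was told to ignore
            dist.broadcast(param, src=0)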
@tiankongdeguiji fyi, landed the fix: cc482f8. Basically, every time we call reset_parameters, we will also broadcast the re-initialized DP tables from rank 0 to all other ranks. Also, be sure not to forget to apply an optimizer to the DP tables.
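An illustrative-only sketch of the pattern that fix describes (the helper name and signature below are hypothetical, not torchrec's actual implementation): re-initialize the DATA_PARALLEL table locally, then broadcast rank 0's values so every replica agrees.

from typing import Optional

import torch
import torch.distributed as dist
from torch import nn


def reset_and_broadcast_dp_table(
    weight: torch.Tensor, pg: Optional[dist.ProcessGroup] = None
) -> None:
    # Hypothetical helper: local re-initialization followed by a broadcast so
    # all ranks end up with rank 0's freshly initialized values.
    with torch.no_grad():
        nn.init.uniform_(weight, -0.01, 0.01)
        dist.broadcast(weight, src=0, group=pg)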
DMP should broadcast the DataParallel ShardingType embedding table parameters from the process with rank 0 to all other processes in the group, so that all model replicas start from exactly the same state.
However, DMP currently adds all sharded_parameter_names to the params_and_buffers_to_ignore list for Distributed Data Parallel (DDP). As a result, DMP skips the synchronization of the DataParallel ShardingType embedding table parameters during initialization, so model replicas may start from different states, which can lead to inconsistent training results and may compromise model convergence.
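A minimal sketch (not from this issue) of the DDP behavior described above, assuming torch.distributed is already initialized with world_size > 1: parameters registered via _set_params_and_buffers_to_ignore_for_model are excluded from DDP's initial module-state broadcast, so each rank keeps its own local initialization unless it is broadcast explicitly.

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.table = nn.Embedding(8, 4)  # stand-in for a DATA_PARALLEL table
        self.dense = nn.Linear(4, 1)     # stays under normal DDP management

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.dense(self.table(idx)).sum()


model = Toy()
# Same mechanism DMP uses for _sharded_parameter_names in _ddp_wrap above.
DDP._set_params_and_buffers_to_ignore_for_model(
    module=model, params_and_buffers_to_ignore={"table.weight"}
)
ddp_model = DDP(model, broadcast_buffers=False)

# Without an explicit broadcast, the ignored table typically differs across
# ranks (unless every rank happens to use the same RNG seed).
gathered = [torch.zeros_like(model.table.weight) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, model.table.weight.detach())
if dist.get_rank() == 0:
    print("table identical across ranks:", gathered[0].equal(gathered[1]))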