VeDeviceMesh for nD Parallelism

TLDR

[Figure: VeDeviceMesh overview (* is under development.)]

What is VeDeviceMesh?

VeDeviceMesh (veScale Device Mesh) is an advanced API built on top of PyTorch's native DeviceMesh. It extends DeviceMesh with effective nD parallelism strategies, checkpointing support, and easy-to-use APIs, guided by the following ideals:

  • “A DeviceMesh, but better”

  • One “Mesh” fits all: users don't need to meddle with DeviceMesh and ProcessGroups throughout training. Users can also reuse the same DeviceMesh to enable hybrid parallel training.

  • Easy to extend: to support more refined capabilities for upcoming parallelization methods, VeDeviceMesh provides mature APIs for adding new functionality without breaking communication semantics.

How does VeDeviceMesh work?

VeDeviceMesh wraps PyTorch's DeviceMesh with APIs that integrate seamlessly with veScale's DModule, DDP, DistributedOptimizer, Pipeline Parallel, and Checkpoint APIs.

VeDeviceMesh further implements advanced features surrounding DeviceMesh:

  • rank mapping between local rank and global rank or between strategy coordinates and global rank

  • submesh mapping between global mesh and submeshes or between local submesh and neighbor submeshes

  • [in future] fault tolerance with reconfigurable meshes
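The rank and submesh mappings listed above boil down to index arithmetic over the global mesh. The sketch below is not veScale's implementation: it assumes ranks are laid out in row-major order over the mesh (the layout PyTorch's `init_device_mesh` uses, via `torch.arange(world_size).view(mesh_shape)`), and `coord_to_rank`, `rank_to_coord`, `submesh_ranks`, and `neighbor_rank` are hypothetical helpers. The example mesh, shape (2, 2, 2) with dimensions ("PP", "DP", "TP"), is also an assumption for illustration.

```python
from itertools import product

# Hypothetical helpers, NOT the veScale API. They assume the global mesh lays
# out ranks in row-major (C) order, as torch.arange(world).view(shape) does.

def coord_to_rank(coord, mesh_shape):
    """Map strategy coordinates, e.g. (pp, dp, tp), to a global rank."""
    rank = 0
    for c, size in zip(coord, mesh_shape):
        if not 0 <= c < size:
            raise ValueError(f"coordinate {coord} out of range for {mesh_shape}")
        rank = rank * size + c
    return rank

def rank_to_coord(rank, mesh_shape):
    """Inverse of coord_to_rank: global rank -> strategy coordinates."""
    coord = []
    for size in reversed(mesh_shape):
        coord.append(rank % size)
        rank //= size
    return tuple(reversed(coord))

def submesh_ranks(mesh_shape, dim, index):
    """Global ranks of the submesh obtained by fixing mesh dimension `dim` to `index`."""
    ranges = [range(s) for s in mesh_shape]
    ranges[dim] = [index]
    return [coord_to_rank(c, mesh_shape) for c in product(*ranges)]

def neighbor_rank(rank, mesh_shape, dim, offset=1):
    """Rank with the same coordinates in the neighboring submesh along `dim`.

    Returns None when the neighbor would fall outside the mesh.
    """
    coord = list(rank_to_coord(rank, mesh_shape))
    coord[dim] += offset
    if not 0 <= coord[dim] < mesh_shape[dim]:
        return None
    return coord_to_rank(coord, mesh_shape)

shape = (2, 2, 2)  # assumed ("PP", "DP", "TP") mesh with 8 ranks
print(rank_to_coord(5, shape))         # (1, 0, 1)
print(submesh_ranks(shape, 0, 1))      # [4, 5, 6, 7]  (pipeline stage 1)
print(neighbor_rank(1, shape, dim=0))  # 5
```

Here `neighbor_rank(1, shape, dim=0)` finds the rank holding the same (DP, TP) coordinates in the next pipeline stage, which is the kind of lookup point-to-point communication between neighboring submeshes relies on.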

How to use VeDeviceMesh?

import torch

from vescale.devicemesh_api import VESCALE_DEVICE_MESH
from vescale.dmodule.api import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from ... import GPT

# create torch-native model as usual
model = GPT()

# initialize a VeDeviceMesh containing a global DeviceMesh of shape (2, 2),
# i.e. 4 ranks in total (e.g. launched with torchrun --nproc_per_node=4)
VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(2, 2), mesh_dim_names=("DP", "TP"))

# use VeDeviceMesh to obtain global DeviceMesh's tensor parallelism view
if VESCALE_DEVICE_MESH.get_strategy_size("TP") > 1:
    # wrap DModule (TP/SP); sharding_plan is user-defined
    model = parallelize_module(model, VESCALE_DEVICE_MESH["TP"], sharding_plan, ...)

# use VeDeviceMesh to obtain global DeviceMesh's data parallelism view
if VESCALE_DEVICE_MESH.get_strategy_size("DP") > 1:
    # wrap DDP module
    model = DDP(model, VESCALE_DEVICE_MESH["DP"], ...)

# use VeDeviceMesh to build distributed optimizer
if VESCALE_DEVICE_MESH.get_strategy_size("DP") > 1:
    optimizer = DistributedOptimizer(torch.optim.Adam, models=[model], ...)
else:
    # assumption: without data parallelism, a regular optimizer is used so that
    # `optimizer` is always defined before the training loop
    optimizer = torch.optim.Adam(model.parameters())

# train model (data_set is user-provided)
for X, Y in data_set:
    optimizer.zero_grad()
    loss = model(X, Y)
    loss.backward()
    optimizer.step()

  • APIs can be found in <repo>/vescale/devicemesh_api/api.py

  • More examples can be found under <repo>/test/parallel/devicemesh_api/*.py
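In the usage example, `VESCALE_DEVICE_MESH["TP"]` and `VESCALE_DEVICE_MESH["DP"]` each hand a rank a 1-D view of the global (2, 2) mesh. The pure-Python sketch below shows which ranks end up grouped together in each view, assuming the row-major rank layout that PyTorch's `init_device_mesh` uses. `dim_groups` and `group_of` are hypothetical helpers for illustration, not veScale APIs; they only mimic what slicing a mesh by dimension name produces.

```python
from itertools import product

# Hypothetical helpers (not veScale APIs). They assume ranks are laid out in
# row-major order over the mesh, as in PyTorch's init_device_mesh.

def _coord_to_rank(coord, mesh_shape):
    rank = 0
    for c, size in zip(coord, mesh_shape):
        rank = rank * size + c
    return rank

def dim_groups(mesh_shape, dim_names, name):
    """All rank groups formed by slicing the mesh along the dimension `name`."""
    dim = dim_names.index(name)
    other_dims = [d for d in range(len(mesh_shape)) if d != dim]
    groups = []
    # one group per combination of coordinates on the *other* dimensions
    for fixed in product(*(range(mesh_shape[d]) for d in other_dims)):
        coord = [0] * len(mesh_shape)
        for d, v in zip(other_dims, fixed):
            coord[d] = v
        group = []
        for i in range(mesh_shape[dim]):
            coord[dim] = i
            group.append(_coord_to_rank(coord, mesh_shape))
        groups.append(group)
    return groups

def group_of(rank, mesh_shape, dim_names, name):
    """The group along `name` that contains `rank` (its 1-D view of the mesh)."""
    return next(g for g in dim_groups(mesh_shape, dim_names, name) if rank in g)

shape, names = (2, 2), ("DP", "TP")
print(dim_groups(shape, names, "TP"))   # [[0, 1], [2, 3]]
print(dim_groups(shape, names, "DP"))   # [[0, 2], [1, 3]]
print(group_of(2, shape, names, "TP"))  # [2, 3]
print(group_of(2, shape, names, "DP"))  # [0, 2]
```

Under this layout, rank 2 runs tensor-parallel collectives with rank 3 and data-parallel gradient reduction with rank 0, and the size of each group (2 here) is what a strategy-size query such as `get_strategy_size("TP")` reports.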