Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update xpu related device setting #446

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/llama_recipes/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
)
from accelerate.utils import is_xpu_available

if is_xpu_available():
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch

def setup_wandb(train_config, fsdp_config, **kwargs):
try:
import wandb
Expand Down Expand Up @@ -174,7 +178,7 @@ def main(**kwargs):

device_id = 0
if is_xpu_available():
device_id = torch.xpu.current_device()
device_id = torch.device(f"xpu:{local_rank}")
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()

Expand All @@ -188,14 +192,15 @@ def main(**kwargs):
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
param_init_fn=lambda module: \
module.to_empty(device=torch.device("xpu") if is_xpu_available() else torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
if is_xpu_available():
model.to("xpu:0")
model.to("xpu")
elif torch.cuda.is_available():
model.to("cuda")

Expand Down
18 changes: 13 additions & 5 deletions src/llama_recipes/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
from transformers import LlamaTokenizer
import json


from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available

if is_xpu_available():
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch

def set_tokenizer_params(tokenizer: LlamaTokenizer):
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"
Expand Down Expand Up @@ -60,8 +63,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
world_size = int(os.environ["WORLD_SIZE"])



autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
if is_xpu_available():
autocast = torch.xpu.amp.autocast if train_config.use_fp16 else nullcontext
else:
autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext

train_prep = []
train_loss = []
Expand Down Expand Up @@ -303,7 +308,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
break
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
if is_xpu_available():
batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
else:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
batch[key] = batch[key].to('xpu:0')
Expand Down Expand Up @@ -396,7 +404,7 @@ def clear_gpu_cache(rank=None):
if rank == 0:
print(f"Clearing GPU cache for all ranks")
if is_xpu_available():
torch.xpu_empty_cache()
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()

Expand Down