Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]Enable Ascend NPU fintuning and inference #447

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 13 additions & 3 deletions recipes/code_llama/code_completion_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.safety_utils import get_safety_checker
from llama_recipes.inference.model_utils import load_model, load_peft_model
Expand Down Expand Up @@ -48,7 +48,12 @@ def main(
sys.exit(1)

# Set the seeds for reproducibility
torch.cuda.manual_seed(seed)
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

model = load_model(model_name, quantization, use_fast_kernels)
Expand Down Expand Up @@ -80,8 +85,13 @@ def main(
sys.exit(1) # Exit the program with an error status

batch = tokenizer(user_prompt, return_tensors="pt")
if is_torch_npu_available():
batch = {k: v.to("npu") for k, v in batch.items()}
elif is_torch_xpu_available():
batch = {k: v.to("xpu") for k, v in batch.items()}
else:
batch = {k: v.to("cuda") for k, v in batch.items()}

batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter()
with torch.no_grad():
outputs = model.generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import sys

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.chat_utils import read_dialogs_from_file
from llama_recipes.inference.model_utils import load_model, load_peft_model
from llama_recipes.inference.safety_utils import get_safety_checker
from accelerate.utils import is_xpu_available


def main(
model_name,
Expand Down Expand Up @@ -56,7 +56,9 @@ def main(


# Set the seeds for reproducibility
if is_xpu_available():
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
torch.cuda.manual_seed(seed)
Expand Down Expand Up @@ -99,7 +101,9 @@ def main(
sys.exit(1) # Exit the program with an error status
tokens= torch.tensor(chat).long()
tokens= tokens.unsqueeze(0)
if is_xpu_available():
if is_torch_npu_available():
tokens= tokens.to("npu:0")
elif is_torch_xpu_available():
tokens= tokens.to("xpu:0")
else:
tokens= tokens.to("cuda:0")
Expand Down
13 changes: 8 additions & 5 deletions recipes/inference/local_inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
import gradio as gr

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
from llama_recipes.inference.model_utils import load_model, load_peft_model

from accelerate.utils import is_xpu_available

def main(
model_name,
Expand Down Expand Up @@ -64,9 +63,11 @@ def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,)
sys.exit(1) # Exit the program with an error status

# Set the seeds for reproducibility
if is_xpu_available():
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
elif torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

Expand All @@ -80,7 +81,9 @@ def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,)
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
if is_xpu_available():
if is_torch_npu_available():
batch = {k: v.to("npu") for k, v in batch.items()}
elif is_torch_xpu_available():
batch = {k: v.to("xpu") for k, v in batch.items()}
else:
batch = {k: v.to("cuda") for k, v in batch.items()}
Expand Down
20 changes: 15 additions & 5 deletions src/llama_recipes/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
AutoTokenizer,
LlamaForCausalLM,
LlamaConfig,
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

Expand Down Expand Up @@ -47,7 +49,6 @@
print_model_size,
get_policies,
)
from accelerate.utils import is_xpu_available

def setup_wandb(train_config, fsdp_config, **kwargs):
try:
Expand Down Expand Up @@ -85,7 +86,9 @@ def main(**kwargs):
world_size = int(os.environ["WORLD_SIZE"])

if torch.distributed.is_initialized():
if is_xpu_available():
if is_torch_npu_available():
torch.npu.set_device(local_rank)
elif is_torch_xpu_available():
torch.xpu.set_device(local_rank)
elif torch.cuda.is_available():
torch.cuda.set_device(local_rank)
Expand Down Expand Up @@ -173,10 +176,15 @@ def main(**kwargs):
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

device_id = 0
if is_xpu_available():
if is_torch_npu_available():
device_id = torch.npu.current_device()
device = torch.device("npu")
elif is_torch_xpu_available():
device_id = torch.xpu.current_device()
device = torch.device("xpu")
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()
device = torch.device("cuda")

model = FSDP(
model,
Expand All @@ -188,13 +196,15 @@ def main(**kwargs):
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
param_init_fn=lambda module: module.to_empty(device=device, recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
if is_xpu_available():
if is_torch_npu_available():
model.to("npu")
elif is_torch_xpu_available():
model.to("xpu:0")
elif torch.cuda.is_available():
model.to("cuda")
Expand Down
24 changes: 21 additions & 3 deletions src/llama_recipes/utils/memory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
import threading

import torch
from accelerate.utils import is_xpu_available
from accelerate.utils import is_npu_available, is_xpu_available

def byte2gb(x):
return int(x / 2**30)
# This context manager is used to track the peak memory usage of the process
class MemoryTrace:
def __enter__(self):
gc.collect()
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
torch.npu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.xpu.memory_allocated())
elif is_xpu_available():
torch.xpu.empty_cache()
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.xpu.memory_allocated())
Expand Down Expand Up @@ -50,7 +54,19 @@ def __exit__(self, *exc):
self.peak_monitoring = False

gc.collect()
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
self.end = byte2gb(torch.npu.memory_allocated())
self.peak = byte2gb(torch.npu.max_memory_allocated())
npu_info = torch.npu.memory_stats()
self.peak_active_gb = byte2gb(npu_info["active_bytes.all.peak"])
self.malloc_retries = npu_info.get("num_alloc_retries", 0)
self.peak_active_gb = byte2gb(npu_info["active_bytes.all.peak"])
self.m_ooms = npu_info.get("num_ooms", 0)
self.used = byte2gb(self.end - self.begin)
self.peaked = byte2gb(self.peak - self.begin)
self.max_reserved = byte2gb(torch.npu.max_memory_reserved())
elif is_xpu_available():
torch.xpu.empty_cache()
self.end = byte2gb(torch.xpu.memory_allocated())
self.peak = byte2gb(torch.xpu.max_memory_allocated())
Expand Down Expand Up @@ -82,6 +98,8 @@ def __exit__(self, *exc):

def print_stats(self):
device_str = None
if is_npu_available():
device_str = "NPU"
if is_xpu_available():
device_str = "XPU"
elif torch.cuda.is_available():
Expand Down
50 changes: 37 additions & 13 deletions src/llama_recipes/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available
from accelerate.utils import is_npu_available, is_xpu_available, is_ccl_available

def set_tokenizer_params(tokenizer: LlamaTokenizer):
tokenizer.pad_token_id = 0
Expand Down Expand Up @@ -55,13 +55,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
if train_config.use_fp16 and train_config.enable_fsdp:
scaler = ShardedGradScaler()
elif train_config.use_fp16 and not train_config.enable_fsdp:
scaler = torch.cuda.amp.GradScaler()
if is_npu_available:
scaler = torch.npu.amp.GradScaler()
elif is_xpu_available:
scaler = torch.xpu.amp.GradScaler()
else:
scaler = torch.cuda.amp.GradScaler()
if train_config.enable_fsdp:
world_size = int(os.environ["WORLD_SIZE"])



autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
if is_npu_available():
autocast = torch.npu.amp.autocast if train_config.use_fp16 else nullcontext
elif is_xpu_available():
autocast = torch.xpu.amp.autocast if train_config.use_fp16 else nullcontext
else:
autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext

train_prep = []
train_loss = []
Expand Down Expand Up @@ -102,13 +111,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
break
for key in batch.keys():
if train_config.enable_fsdp:
if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to(torch.device(f"npu:{local_rank}"))
elif is_xpu_available():
batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
else:
batch[key] = batch[key].to(local_rank)
else:

if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to('npu:0')
elif is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
Expand Down Expand Up @@ -162,8 +175,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

epoch_end_time = time.perf_counter()-epoch_start_time
epoch_times.append(epoch_end_time)
# Reducing total_loss across all devices if there's more than one CUDA device
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
# Reducing total_loss across all devices if there's more than one CUDA/NPU/XPU device
if is_npu_available() and (torch.npu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
elif is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
Expand Down Expand Up @@ -305,7 +320,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to('npu:0')
elif is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
Expand All @@ -325,8 +342,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
)

# If there's more than one CUDA device, reduce evaluation loss across all devices
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
# If there's more than one CUDA/NPU/XPU device, reduce evaluation loss across all devices
if is_npu_available() and (torch.npu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
elif is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
Expand Down Expand Up @@ -370,6 +389,8 @@ def setup():
if is_ccl_available():
# distributed training on xpus
dist.init_process_group("ccl")
elif is_npu_available():
dist.init_process_group("hccl")
else:
dist.init_process_group("nccl")

Expand All @@ -395,7 +416,9 @@ def clear_gpu_cache(rank=None):
"""Clear the GPU cache for all ranks"""
if rank == 0:
print(f"Clearing GPU cache for all ranks")
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
elif is_xpu_available():
torch.xpu_empty_cache()
else:
torch.cuda.empty_cache()
Expand Down Expand Up @@ -438,7 +461,8 @@ def get_policies(cfg, rank):
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
) or
(is_xpu_available()))
(is_xpu_available())
or is_npu_available())


mixed_precision_policy = None
Expand Down