Enable Ascend NPU finetuning and inference
statelesshz committed Apr 19, 2024
1 parent 6e52cb9 commit 87fdb65
Showing 6 changed files with 90 additions and 30 deletions.
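The recurring change across these six files is a three-way accelerator dispatch: prefer the Ascend NPU if present, then Intel XPU, then CUDA, using the availability helpers that transformers already exposes. The sketch below condenses that pattern; it is illustrative only, not the literal repo code, and assumes the torch_npu plugin is installed so that the torch.npu namespace exists.

import torch
from transformers import is_torch_npu_available, is_torch_xpu_available

def pick_device() -> torch.device:
    # Same precedence the commit applies throughout the recipes: NPU, then XPU, then CUDA.
    if is_torch_npu_available():
        return torch.device("npu")
    if is_torch_xpu_available():
        return torch.device("xpu")
    return torch.device("cuda")

def seed_everything(seed: int) -> None:
    # Mirrors the seed-setting blocks added to the inference examples.
    if is_torch_npu_available():
        torch.npu.manual_seed(seed)
    elif is_torch_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

# Tokenized batches are moved the same way, e.g.:
#   batch = {k: v.to(pick_device()) for k, v in batch.items()}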
16 changes: 13 additions & 3 deletions recipes/code_llama/code_completion_example.py
@@ -9,7 +9,7 @@
import time

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.safety_utils import get_safety_checker
from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -48,7 +48,12 @@ def main(
sys.exit(1)

# Set the seeds for reproducibility
torch.cuda.manual_seed(seed)
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

model = load_model(model_name, quantization, use_fast_kernels)
@@ -80,8 +85,13 @@ def main(
sys.exit(1) # Exit the program with an error status

batch = tokenizer(user_prompt, return_tensors="pt")
if is_torch_npu_available():
batch = {k: v.to("npu") for k, v in batch.items()}
elif is_torch_xpu_available():
batch = {k: v.to("xpu") for k, v in batch.items()}
else:
batch = {k: v.to("cuda") for k, v in batch.items()}

batch = {k: v.to("cuda") for k, v in batch.items()}
start = time.perf_counter()
with torch.no_grad():
outputs = model.generate(
@@ -8,12 +8,12 @@
import sys

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.chat_utils import read_dialogs_from_file
from llama_recipes.inference.model_utils import load_model, load_peft_model
from llama_recipes.inference.safety_utils import get_safety_checker
from accelerate.utils import is_xpu_available


def main(
model_name,
@@ -56,7 +56,9 @@ def main(


# Set the seeds for reproducibility
if is_xpu_available():
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
torch.cuda.manual_seed(seed)
@@ -99,7 +101,9 @@ def main(
sys.exit(1) # Exit the program with an error status
tokens= torch.tensor(chat).long()
tokens= tokens.unsqueeze(0)
if is_xpu_available():
if is_torch_npu_available():
tokens= tokens.to("npu:0")
elif is_torch_xpu_available():
tokens= tokens.to("xpu:0")
else:
tokens= tokens.to("cuda:0")
13 changes: 8 additions & 5 deletions recipes/inference/local_inference/inference.py
@@ -10,12 +10,11 @@
import gradio as gr

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, is_torch_npu_available, is_torch_xpu_available

from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
from llama_recipes.inference.model_utils import load_model, load_peft_model

from accelerate.utils import is_xpu_available

def main(
model_name,
@@ -64,9 +63,11 @@ def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,)
sys.exit(1) # Exit the program with an error status

# Set the seeds for reproducibility
if is_xpu_available():
if is_torch_npu_available():
torch.npu.manual_seed(seed)
elif is_torch_xpu_available():
torch.xpu.manual_seed(seed)
else:
elif torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

@@ -80,7 +81,9 @@ def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,)
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
if is_xpu_available():
if is_torch_npu_available():
batch = {k: v.to("npu") for k, v in batch.items()}
elif is_torch_xpu_available():
batch = {k: v.to("xpu") for k, v in batch.items()}
else:
batch = {k: v.to("cuda") for k, v in batch.items()}
20 changes: 15 additions & 5 deletions src/llama_recipes/finetuning.py
@@ -20,6 +20,8 @@
AutoTokenizer,
LlamaForCausalLM,
LlamaConfig,
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

@@ -47,7 +49,6 @@
print_model_size,
get_policies,
)
from accelerate.utils import is_xpu_available

def setup_wandb(train_config, fsdp_config, **kwargs):
try:
@@ -85,7 +86,9 @@ def main(**kwargs):
world_size = int(os.environ["WORLD_SIZE"])

if torch.distributed.is_initialized():
if is_xpu_available():
if is_torch_npu_available():
torch.npu.set_device(local_rank)
elif is_torch_xpu_available():
torch.xpu.set_device(local_rank)
elif torch.cuda.is_available():
torch.cuda.set_device(local_rank)
@@ -173,10 +176,15 @@ def main(**kwargs):
my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

device_id = 0
if is_xpu_available():
if is_torch_npu_available():
device_id = torch.npu.current_device()
device = torch.device("npu")
elif is_torch_xpu_available():
device_id = torch.xpu.current_device()
device = torch.device("xpu")
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()
device = torch.device("cuda")

model = FSDP(
model,
@@ -188,13 +196,15 @@
device_id=device_id,
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
param_init_fn=lambda module: module.to_empty(device=device, recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
if fsdp_config.fsdp_activation_checkpointing:
apply_fsdp_checkpointing(model)
elif not train_config.quantization and not train_config.enable_fsdp:
if is_xpu_available():
if is_torch_npu_available():
model.to("npu:0")
elif is_torch_xpu_available():
model.to("xpu:0")
elif torch.cuda.is_available():
model.to("cuda")
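Condensed, the FSDP change above detects the accelerator once, hands its index to FSDP as device_id, and reuses the matching torch.device when materializing meta-device parameters on non-zero ranks instead of hard-coding "cuda". A rough sketch of that wiring, with the other FSDP arguments from finetuning.py omitted and torch_npu assumed to be installed:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import is_torch_npu_available, is_torch_xpu_available

def wrap_with_fsdp(model: torch.nn.Module, rank: int, low_cpu_fsdp: bool) -> FSDP:
    # Detect the accelerator once: NPU first, then XPU, then CUDA.
    if is_torch_npu_available():
        device_id, device = torch.npu.current_device(), torch.device("npu")
    elif is_torch_xpu_available():
        device_id, device = torch.xpu.current_device(), torch.device("xpu")
    else:
        device_id, device = torch.cuda.current_device(), torch.device("cuda")

    return FSDP(
        model,
        device_id=device_id,  # shards live on the detected accelerator
        # With low_cpu_fsdp, ranks other than 0 materialize meta-device
        # parameters directly on that accelerator.
        param_init_fn=(lambda module: module.to_empty(device=device, recurse=False))
        if low_cpu_fsdp and rank != 0 else None,
    )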
24 changes: 21 additions & 3 deletions src/llama_recipes/utils/memory_utils.py
@@ -6,15 +6,19 @@
import threading

import torch
from accelerate.utils import is_xpu_available
from accelerate.utils import is_npu_available, is_xpu_available

def byte2gb(x):
return int(x / 2**30)
# This context manager is used to track the peak memory usage of the process
class MemoryTrace:
def __enter__(self):
gc.collect()
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
torch.npu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.npu.memory_allocated())
elif is_xpu_available():
torch.xpu.empty_cache()
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = byte2gb(torch.xpu.memory_allocated())
@@ -50,7 +54,19 @@ def __exit__(self, *exc):
self.peak_monitoring = False

gc.collect()
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
self.end = byte2gb(torch.npu.memory_allocated())
self.peak = byte2gb(torch.npu.max_memory_allocated())
npu_info = torch.npu.memory_stats()
self.peak_active_gb = byte2gb(npu_info["active_bytes.all.peak"])
self.malloc_retries = npu_info.get("num_alloc_retries", 0)
self.m_ooms = npu_info.get("num_ooms", 0)
self.used = byte2gb(self.end - self.begin)
self.peaked = byte2gb(self.peak - self.begin)
self.max_reserved = byte2gb(torch.npu.max_memory_reserved())
elif is_xpu_available():
torch.xpu.empty_cache()
self.end = byte2gb(torch.xpu.memory_allocated())
self.peak = byte2gb(torch.xpu.max_memory_allocated())
@@ -82,6 +98,8 @@ def __exit__(self, *exc):

def print_stats(self):
device_str = None
if is_npu_available():
device_str = "NPU"
elif is_xpu_available():
device_str = "XPU"
elif torch.cuda.is_available():
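MemoryTrace wraps a training or evaluation epoch as a context manager; its statistics are filled in on exit, and with this change they come from the NPU, XPU, or CUDA allocator as appropriate. A hypothetical usage sketch (run_one_epoch is a placeholder for the actual workload):

from llama_recipes.utils.memory_utils import MemoryTrace

with MemoryTrace() as memtrace:   # now backed by torch.npu, torch.xpu, or torch.cuda stats
    run_one_epoch()               # placeholder workload

print(f"peak allocated: {memtrace.peak} GB, peak active: {memtrace.peak_active_gb} GB")
print(f"allocated during block: {memtrace.used} GB, max reserved: {memtrace.max_reserved} GB")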
35 changes: 25 additions & 10 deletions src/llama_recipes/utils/train_utils.py
@@ -23,7 +23,7 @@
from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available
from accelerate.utils import is_npu_available, is_xpu_available, is_ccl_available

def set_tokenizer_params(tokenizer: LlamaTokenizer):
tokenizer.pad_token_id = 0
@@ -102,13 +102,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
break
for key in batch.keys():
if train_config.enable_fsdp:
if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to(torch.device(f"npu:{local_rank}"))
elif is_xpu_available():
batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
else:
batch[key] = batch[key].to(local_rank)
else:

if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to('npu:0')
elif is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
@@ -162,8 +166,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

epoch_end_time = time.perf_counter()-epoch_start_time
epoch_times.append(epoch_end_time)
# Reducing total_loss across all devices if there's more than one CUDA device
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
# Reducing total_loss across all devices if there's more than one CUDA/NPU/XPU device
if is_npu_available() and (torch.npu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
elif is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -305,7 +311,9 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
else:
if is_xpu_available():
if is_npu_available():
batch[key] = batch[key].to('npu:0')
elif is_xpu_available():
batch[key] = batch[key].to('xpu:0')
else:
batch[key] = batch[key].to('cuda:0')
@@ -325,8 +333,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
)

# If there's more than one CUDA device, reduce evaluation loss across all devices
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
# If there's more than one CUDA/NPU/XPU device, reduce evaluation loss across all devices
if is_npu_available() and (torch.npu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
elif is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
@@ -370,6 +380,8 @@ def setup():
if is_ccl_available():
# distributed training on xpus
dist.init_process_group("ccl")
elif is_npu_available():
dist.init_process_group("hccl")
else:
dist.init_process_group("nccl")

@@ -395,7 +407,9 @@ def clear_gpu_cache(rank=None):
"""Clear the GPU cache for all ranks"""
if rank == 0:
print(f"Clearing GPU cache for all ranks")
if is_xpu_available():
if is_npu_available():
torch.npu.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()
@@ -438,7 +452,8 @@ def get_policies(cfg, rank):
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
) or
(is_xpu_available()))
(is_xpu_available())
or is_npu_available())


mixed_precision_policy = None

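Before running the updated recipes on Ascend hardware, it can be worth confirming that the NPU backend is actually visible. This check is not part of the commit; it assumes the torch_npu plugin is installed alongside a matching CANN toolkit.

import torch
import torch_npu  # Ascend plugin; registers the torch.npu namespace
from transformers import is_torch_npu_available

print("torch.npu.is_available():", torch.npu.is_available())
print("transformers is_torch_npu_available():", is_torch_npu_available())
print("NPU device count:", torch.npu.device_count())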