
NaN during llama3 finetuning #427

Open
mano3-1 opened this issue May 5, 2024 · 18 comments
Labels
currently fixing Am fixing now!

Comments

mano3-1 commented May 5, 2024

Hi,

I'm currently fine-tuning llama3-instruct-8b on a custom dataset using unsloth's FastLanguageModel, with Hugging Face's SFTTrainer for training. Surprisingly, the gradient norm and evaluation loss become NaN after a few steps. I've seen a blog post from Unsloth mentioning that NaNs may appear due to a bug, but it also says the bug has since been fixed by Hugging Face and Unsloth (here, under the llama3-Quirks section). So I not only updated unsloth and Hugging Face but also added the "pad_token" mentioned in the blog. Despite these attempts, the NaN problem persists. Is there something else I'm missing? Can someone help me out?

Here's the code snippet of how I'm loading the model:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = args.max_seq_length,
    dtype = compute_dtype,
    load_in_4bit = args.use_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id  # updating model config to match
tokenizer.padding_side = 'right'
model = FastLanguageModel.get_peft_model(
    model,
    r = args.lora_r, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = lora_modules,
    lora_alpha = args.lora_alpha,
    lora_dropout = args.lora_dropout, # Supports any, but = 0 is optimized
    bias = args.lora_bias,    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
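
As a quick sanity check (a rough sketch, assuming the model and tokenizer objects above): since the blog ties the NaNs to untrained reserved tokens, it's worth confirming that the chosen pad token maps to a finite, non-zero embedding row and differs from the EOS token:

import torch

# Rough sketch: verify the pad token's embedding row is finite and non-zero,
# since the llama-3 NaN quirk is tied to untrained reserved tokens.
embedding = model.get_input_embeddings().weight
pad_row = embedding[tokenizer.pad_token_id]
print("pad_token_id:", tokenizer.pad_token_id)
print("embedding row norm:", pad_row.float().norm().item())
print("contains NaN/Inf:", torch.isnan(pad_row).any().item(),
      torch.isinf(pad_row).any().item())
assert tokenizer.pad_token_id != tokenizer.eos_token_id  # pad must differ from eos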

Following is the training code:

  training_arguments = TrainingArguments(
      output_dir=output_dir,
      num_train_epochs=args.epochs,
      per_device_train_batch_size=args.per_device_train_batch_size,
      per_device_eval_batch_size=args.per_device_eval_batch_size,
      gradient_accumulation_steps=args.gradient_accumulation_steps,
      optim=args.optimizer,
      save_steps=args.save_steps,
      logging_steps=args.logging_steps,
      learning_rate=args.learning_rate,
      weight_decay=args.weight_decay,
      fp16=fp16,
      bf16=bf16,
      max_grad_norm=args.max_grad_norm,
      max_steps=args.max_steps,
      warmup_ratio=args.warmup_ratio,
      # group_by_length=args.group_by_length,
      lr_scheduler_type=args.lr_scheduler_type,
      logging_strategy="steps",
      report_to="tensorboard",
      evaluation_strategy="steps",
      # ddp_find_unused_parameters=False,
  )
  trainer = SFTTrainer(
      model=model,
      train_dataset=train_dataset,
      eval_dataset=eval_dataset,
      dataset_text_field="chats",
      max_seq_length=args.max_seq_length,
      tokenizer=tokenizer,
      args=training_arguments,
      packing=packing
  )
mano3-1 changed the title from "NaN during finetuning" to "NaN during llama3 finetuning" on May 5, 2024
danielhanchen (Contributor) commented

Are you training on embed_tokens and lm_head?
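
(For reference, a minimal sketch of what training those would roughly look like with get_peft_model; the two extra target_modules entries are the relevant part, and they are only needed if the embedding matrices themselves should learn:)

# Rough sketch (not a drop-in fix): adding the embedding matrices as
# trainable targets alongside the usual attention/MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "embed_tokens", "lm_head"],  # the two extra entries
    lora_alpha = 16,
    use_gradient_checkpointing = "unsloth",
)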

mano3-1 (Author) commented May 6, 2024

Hi @danielhanchen,

Thank you for your response. I'm unsure about the inner workings of get_peft_model in Unsloth, but assuming it functions like other PEFT methods, it should freeze the base model, including the embedding matrix, correct? Consequently, I believe my script is only training the LoRA parameters. I tried Unsloth's fix_untrained_tokens, but it didn't work for me. Additionally, I noticed that Unsloth's blog mentions the llama-3-8b base model, whereas I'm using llama-3-8b-instruct. The instruct model's reserved tokens shouldn't cause any issues since they were fine-tuned (unlike the base model's), right?
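
(One quick way to confirm this — a generic PEFT-style check, not unsloth-specific — is to list which parameters actually have gradients enabled:)

# List trainable tensors; with a frozen base model only the LoRA
# adapter weights (lora_A / lora_B) should show up here.
trainable = [(n, tuple(p.shape)) for n, p in model.named_parameters()
             if p.requires_grad]
print(len(trainable), "trainable tensors")
for name, shape in trainable[:10]:
    print(name, shape)
# Any embed_tokens or lm_head entry here would mean the embeddings train too.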

lapp0 commented May 6, 2024

@mano3-1 what does the traceback say if you run

import torch

# detect_anomaly makes the failing backward op raise with a full traceback
with torch.autograd.detect_anomaly():
    trainer.train()
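
If that isn't conclusive, you could also register backward hooks to print the first module whose incoming gradients go non-finite (a generic PyTorch sketch, not unsloth-specific):

import torch

def make_nan_hook(name):
    def hook(module, grad_input, grad_output):
        # grad_output holds the gradients flowing into this module's backward
        for g in grad_output:
            if g is not None and not torch.isfinite(g).all():
                print(f"non-finite gradient entering: {name}")
    return hook

for name, module in model.named_modules():
    module.register_full_backward_hook(make_nan_hook(name))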

mano3-1 (Author) commented May 7, 2024

Hi @lapp0,
Here is the traceback:

Traceback (most recent call last):
  File "/home/ubuntu/LLMOps/train/train.py", line 501, in <module>
    main()
  File "/home/ubuntu/LLMOps/train/train.py", line 497, in main
    training_function(args)
  File "/home/ubuntu/LLMOps/train/train.py", line 445, in training_function
    trainer.train()
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "<string>", line 361, in _fast_inner_training_loop
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/transformers/trainer.py", line 3147, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/accelerate/accelerator.py", line 2013, in backward
    loss.backward(**kwargs)
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/LLMOps/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'Fast_CrossEntropyLossBackward' returned nan values in its 0th output.

lapp0 commented May 8, 2024

I'm running into issues with back-propagation in unsloth as well, albeit with a custom loss function and Mistral instead of llama-3. It works fine with AutoModelForCausalLM & get_peft_model, but with unsloth I get:

RuntimeError: Function 'LoRA_MLPBackward' returned nan values in its 0th output.

  File "<string>", line 361, in _fast_inner_training_loop
  File "/opt/conda/lib/python3.10/site-packages/trl/trainer/policy_trainer_base.py", line 549, in training_step
    return super().training_step(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3147, in training_step
    self.accelerator.backward(loss)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 2013, in backward
    loss.backward(**kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 142, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/unsloth/models/_utils.py", line 348, in backward
    torch.autograd.backward(output, dY)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'LoRA_MLPBackward' returned nan values in its 0th output.

I'd be interested in the cause of your issue; perhaps it's the same as mine. If I figure anything out with mine, I'll let you know.

mano3-1 (Author) commented May 8, 2024

Hi @lapp0
It seems we're both facing a similar issue. I tried removing unsloth from my code and training with plain Hugging Face utilities, and that went well. But I'd really like to keep unsloth in the loop, because the memory savings are significant. Do you think this is on unsloth's side, or something arising from our scripts?

lapp0 commented May 8, 2024

I'm not sure. Your backward pass fails at a different layer of the model than mine, but the only thing our scripts have in common is unsloth.

How about some debug details?

  1. Could you please share a full reproduction script that Daniel and I could run locally? This includes the whole source file along with your run command.

  2. What is the output of pip3 freeze?

mano3-1 (Author) commented May 8, 2024

Here is the pip freeze:
requirements.txt

Here is the full training script: link

This is how I trigger the training scripts:
python train.py --max_seq_length 4000 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --sm_train_dir "/opt/ml/processing/train" --sm_validation_dir "/opt/ml/processing/test" --hf_token <yourtoken> --run_experiment False --lora_r 32 --lora_alpha 8 --unsloth True --logging_steps 8 --save_steps 8

You may set hf_token to the string "None" if you're loading unsloth models, I guess.

lapp0 commented May 8, 2024

requirements.txt isn't the same as pip freeze. pip3 freeze lists the pinned versions of all installed packages.

danielhanchen (Contributor) commented

Oh no, sorry guys - I will take a look.

lapp0 commented May 8, 2024

Thanks @danielhanchen

Here is my reproduction script as well, run on a 4090 with CUDA 12.1. @mano3-1 has a standard SFT script, so his is probably worth looking at first.

"""
INSTALL DEPS:
pip install torch==2.3.0
pip install transformers tensorboardX bitsandbytes peft accelerate flash_attn --upgrade
pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"
pip install "git+https://github.com/lapp0/trl.git@ppov2"
pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
pip install torch==2.3.0  # ensure correct torch still used
"""
import multiprocessing

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    DataCollatorWithPadding,
    BitsAndBytesConfig
)
import torch
from trl.trainer.ppov2_trainer import PPOConfig, PPOTrainer, PolicyAndValueWrapper
from peft import get_peft_model, LoraConfig


base_model_uri = "HuggingFaceH4/mistral-7b-sft-beta"
reward_model_uri = "weqweasdas/RM-Mistral-7B"

################
# Model & Tokenizer
################
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    base_model_uri,
    padding_side="left",
    trust_remote_code=True,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
    reward_model_uri,
    num_labels=1,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)


value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained(
    reward_model_uri,
    num_labels=1,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)
value_model = get_peft_model(
    value_model,
    LoraConfig(
        r=16,
        lora_alpha=64,
        lora_dropout=0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
)


from unsloth import FastLanguageModel
base_policy, _ = FastLanguageModel.from_pretrained(
    model_name=base_model_uri,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
base_policy = FastLanguageModel.get_peft_model(
    base_policy,
    r=16,
    lora_alpha=64,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    max_seq_length=2048
)
"""
# Creating base_policy like this works, unsloth doesn't
from transformers import AutoModelForCausalLM
base_policy = AutoModelForCausalLM.from_pretrained(
    base_model_uri,
    num_labels=1,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    attn_implementation="flash_attention_2",
)
lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
base_policy = get_peft_model(base_policy, lora_config)
"""

# trl.trainer.peft_module_casting_to_bf16(base_model)


base_model = PolicyAndValueWrapper(base_policy, value_model)


################
# Dataset
################
raw_datasets = load_dataset("HuggingFaceH4/ultrachat_200k")
train_dataset = raw_datasets["train_sft"]
eval_dataset = raw_datasets["test_sft"]


def prepare_dataset(dataset, tokenizer):
    """pre-tokenize the dataset before training; only collate during training"""

    def tokenize(element):
        input_ids = tokenizer.apply_chat_template(
            element["messages"][:1],
            padding=False,
            add_generation_prompt=True,
        )
        return {"input_ids": input_ids, "lengths": len(input_ids)}

    return dataset.map(
        tokenize,
        remove_columns=dataset.column_names,
        num_proc=multiprocessing.cpu_count(),
        load_from_cache_file=False,
    )


train_dataset = prepare_dataset(train_dataset, tokenizer).filter(lambda x: x["lengths"] <= 1024)
eval_dataset = prepare_dataset(eval_dataset, tokenizer).filter(lambda x: x["lengths"] <= 1024)

collator = DataCollatorWithPadding(tokenizer)

###############
# Training
################
config = PPOConfig(
    output_dir="./ppov2_experiment_v2",
    report_to="tensorboard",
    update_generation_steps=16,
    gradient_accumulation_steps=8,
    per_device_train_batch_size=2,
    push_to_hub=True,
    hub_model_id="lapp0/ppov2_experiment_v2",
    logging_steps=1,
    learning_rate=3e-6,
    save_steps=4,
    non_eos_penalty=True,
    response_length=128,
    optim="paged_adamw_8bit",
    bf16=True,
    fp16=False,
    truncate_token="eos",
    gradient_checkpointing=True,
    # gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = PPOTrainer(
    model=base_model,
    args=config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    reward_model=reward_model,
    data_collator=collator,
    tokenizer=tokenizer,
)
with torch.autograd.detect_anomaly():
    trainer.train()

trainer.generate_completions()

pip3 freeze:

accelerate==0.30.0
aiohttp==3.9.5
aiosignal==1.3.1
anaconda-anon-usage @ file:///croot/anaconda-anon-usage_1710965072196/work
anyio==4.3.0
archspec @ file:///croot/archspec_1709217642129/work
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
astunparse==1.6.3
async-lru==2.0.4
async-timeout==4.0.3
attrs @ file:///croot/attrs_1695717823297/work
Babel==2.15.0
bash_kernel==0.9.3
beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work
bitsandbytes==0.43.1
bleach==6.1.0
boltons @ file:///croot/boltons_1677628692245/work
Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work
certifi @ file:///croot/certifi_1707229174982/work/certifi
cffi @ file:///croot/cffi_1700254295673/work
chardet @ file:///home/builder/ci_310/chardet_1640804867535/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click @ file:///croot/click_1698129812380/work
comm==0.2.2
conda @ file:///croot/conda_1689269889729/work
conda-build @ file:///croot/conda-build_1710789183177/work
conda-content-trust @ file:///croot/conda-content-trust_1693490622020/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1691418897561/work/src
conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work
conda_index @ file:///croot/conda-index_1706633791028/work
conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work
cryptography @ file:///croot/cryptography_1710350347627/work
datasets==2.19.1
debugpy==1.8.1
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
defusedxml==0.7.1
dill==0.3.8
distro @ file:///croot/distro_1701455004953/work
dnspython==2.6.1
docstring_parser==0.16
einops==0.8.0
exceptiongroup @ file:///croot/exceptiongroup_1706031385326/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
expecttest==0.2.1
fastjsonschema==2.19.1
filelock @ file:///croot/filelock_1700591183607/work
flash-attn==2.5.8
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
hypothesis==6.100.1
idna @ file:///croot/idna_1666125576474/work
iniconfig==2.0.0
ipykernel==6.29.4
ipython @ file:///croot/ipython_1704833016303/work
ipywidgets==8.1.2
isoduration==20.11.0
jedi @ file:///tmp/build/80754af9/jedi_1644315229345/work
Jinja2 @ file:///croot/jinja2_1706733616596/work
json5==0.9.25
jsonpatch @ file:///croot/jsonpatch_1710807507480/work
jsonpointer==2.1
jsonschema @ file:///croot/jsonschema_1699041609003/work
jsonschema-specifications @ file:///croot/jsonschema-specifications_1699032386549/work
jupyter==1.0.0
jupyter-archive==3.4.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-http-over-ws==0.0.8
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
lark==1.1.9
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
libmambapy @ file:///croot/mamba-split_1712091911343/work/libmambapy
markdown-it-py==3.0.0
MarkupSafe @ file:///croot/markupsafe_1704205993651/work
matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work
mdurl==0.1.2
menuinst @ file:///croot/menuinst_1706732933928/work
mistune==3.0.2
mkl-fft @ file:///croot/mkl_fft_1695058164594/work
mkl-random @ file:///croot/mkl_random_1695059800811/work
mkl-service==2.4.0
more-itertools @ file:///croot/more-itertools_1700662129964/work
mpmath @ file:///croot/mpmath_1690848262763/work
multidict==6.0.5
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nbzip==0.1.0
nest-asyncio==1.6.0
networkx @ file:///croot/networkx_1690561992265/work
ninja==1.11.1.1
notebook==7.1.3
notebook_shim==0.2.4
numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
optree==0.11.0
overrides==7.7.0
packaging @ file:///croot/packaging_1710807400464/work
pandas==2.2.2
pandocfilters==1.5.1
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
peft==0.10.0
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pillow @ file:///croot/pillow_1707233021655/work
pkginfo @ file:///croot/pkginfo_1679431160147/work
platformdirs @ file:///croot/platformdirs_1692205439124/work
pluggy==1.5.0
prometheus_client==0.20.0
prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work
protobuf==3.20.3
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycosat @ file:///croot/pycosat_1696536503704/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments @ file:///croot/pygments_1684279966437/work
pyOpenSSL @ file:///croot/pyopenssl_1708380408460/work
PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
pytest==8.2.0
python-dateutil==2.9.0.post0
python-etcd==0.4.5
python-json-logger==2.0.7
pytz @ file:///croot/pytz_1695131579487/work
PyYAML @ file:///croot/pyyaml_1698096049011/work
pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
referencing @ file:///croot/referencing_1699012038513/work
regex==2024.4.28
requests @ file:///croot/requests_1707355572290/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py @ file:///croot/rpds-py_1698945930462/work
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
safetensors==0.4.3
Send2Trash==1.8.3
sentencepiece==0.2.0
shtab==1.7.1
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve @ file:///croot/soupsieve_1696347547217/work
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
sympy @ file:///croot/sympy_1701397643339/work
tensorboardX==2.6.2.2
terminado==0.18.1
tinycss2==1.3.0
tokenizers==0.19.1
tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work
toolz @ file:///croot/toolz_1667464077321/work
torch==2.3.0+cu121
torchaudio==2.3.0
torchelastic==0.2.2
torchvision==0.18.0
tornado==6.4
tqdm @ file:///croot/tqdm_1679561862951/work
traitlets @ file:///croot/traitlets_1671143879854/work
transformers==4.40.2
triton==2.3.0
trl @ git+https://github.com/lapp0/trl.git@649aff0d142987b9e6a9ecea7ece562074d3f7c6
truststore @ file:///croot/truststore_1695244293384/work
types-dataclasses==0.6.6
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tyro==0.8.3
tzdata==2024.1
unsloth @ git+https://github.com/unslothai/unsloth.git@a93a885c286934c9c7467324054ca3f9d526a2bd
uri-template==1.3.0
urllib3 @ file:///croot/urllib3_1707770551213/work
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.10
xformers==0.0.26.post1
xxhash==3.4.1
yarl==1.9.4
zstandard @ file:///croot/zstandard_1677013143055/work

mano3-1 (Author) commented May 9, 2024

Hi @lapp0,
Although I named it requirements.txt, I generated it by running pip freeze. Kindly check the file; you'll find the versions of all the libraries.

lapp0 commented May 9, 2024

Sorry about my confusion @mano3-1

I reviewed and compared our installed packages. Nothing noteworthy in the shared dependencies, other than that the issue is perhaps related to the use of xformers. I'll experiment with this later. For reference, here are our shared packages:

markdown-it-py==3.0.0
nvidia-curand-cu12==10.3.2.106
beautifulsoup4
parso
nvidia-cudnn-cu12==8.9.2.26
tokenizers==0.19.1
unsloth
Pygments
ptyprocess
nvidia-cuda-runtime-cu12==12.1.105
typing_extensions==4.11.0
matplotlib-inline
decorator
pure-eval
nvidia-cusparse-cu12==12.1.0.106
PyYAML
tomli
sentencepiece==0.2.0
six
async-timeout==4.0.3
prompt-toolkit
jsonschema
soupsieve
referencing
nvidia-cufft-cu12==11.0.2.54
PySocks
traitlets
mdurl==0.1.2
fsspec==2024.3.1
Brotli
xxhash==3.4.1
nvidia-cublas-cu12==12.1.3.1
tyro==0.8.3
platformdirs
packaging
pycparser
cffi
protobuf==3.20.3
pyarrow-hotfix==0.6
nvidia-cusolver-cu12==11.4.5.107
wcwidth
nvidia-cuda-nvrtc-cu12==12.1.105
asttokens
jedi
safetensors==0.4.3
exceptiongroup
aiosignal==1.3.1
nvidia-nvjitlink-cu12==12.4.127
nvidia-cuda-cupti-cu12==12.1.105
bitsandbytes==0.43.1
rpds-py
jsonschema-specifications
pexpect
nvidia-nvtx-cu12==12.1.105
peft==0.10.0

danielhanchen (Contributor) commented

Thanks for the code repro - will test this out - sorry about the issue again!

danielhanchen added the "currently fixing" label on May 9, 2024
DementedWeasel1971 commented

Also facing the same issue, while using Colab and the standard notebook in the unsloth folder. Thought I'd add that.

mano3-1 (Author) commented May 14, 2024

Hey,
I'm curious whether anyone has figured out a fix for this?

danielhanchen (Contributor) commented

Sorry guys, just started debugging this.
I also updated Unsloth, so it might be better now (hopefully).
For local installations, please update Unsloth via

pip uninstall unsloth -y
pip install --upgrade --force-reinstall --no-cache-dir git+https://github.com/unslothai/unsloth.git

For Colab / Kaggle should be fine with a restart

@DementedWeasel1971 When you said the Colab notebook we provided broke, could you point to exactly which one? Thanks.

@mano3-1 Extremely weird actually - I reran Colab with Instruct and it seems fine - would you be able to run just the conversational notebook for Llama-3 here: https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing

@lapp0 I'm currently running your PPO example here: https://colab.research.google.com/drive/1fgJv0eKlRKexOl2RqcxoiZ-HhGrdNWQW?usp=sharing (will wait for it to complete)

lapp0 commented May 15, 2024

Thanks so much for looking into it! Unfortunately I'm still getting NaN on the first training step:

{'loss': 1.9125, 'grad_norm': nan, 'learning_rate': 2.9999207167208437e-06, 'objective/kl': 0.0, 'objective/entropy': 99.8125, 'objective/non_score_reward': 0.0, 'objective/rlhf_reward': -0.58380126953125, 'objective/scores': -0.58380126953125, 'policy/approxkl_avg': 0.0, 'policy/clipfrac_avg': 0.0, 'loss/policy_avg': -6.116561479529992e-09, 'loss/value_avg': 19.12525177001953, 'val/clipfrac_avg': 0.0011788890697062016, 'val/num_eos_tokens': 0.5, 'timer/training_step': 2.293384313583374, 'epoch': 0.0}

Please let me know if there's any other debug details that would help.

Also, FYI: to speed up debugging you can set update_generation_steps=1.

Edit:

I pushed a bad commit to my branch; I've reverted the broken change. Should be good to try again with the head of https://github.com/lapp0/trl.git@ppov2.
