
Trainer doesn't handle torch.compiled QLoRA models correctly #29033

Closed
@readwriteexec


System Info

  • transformers version: 4.38.0.dev0
  • Platform: Linux-6.1.58+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): 2.15.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.1 (cpu)
  • Jax version: 0.4.23
  • JaxLib version: 0.4.23
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

!pip install bitsandbytes
!pip install git+https://github.com/huggingface/accelerate
!pip install git+https://github.com/huggingface/datasets
!pip install git+https://github.com/huggingface/peft
!pip install git+https://github.com/huggingface/transformers
!pip install git+https://github.com/huggingface/trl

import torch

import accelerate
import datasets
import peft
import transformers
import trl

import bitsandbytes

train_dataset = datasets.load_dataset('imdb', split='train')

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

lora_config = peft.LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = transformers.AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=bnb_config)
model = peft.prepare_model_for_kbit_training(model)
model = peft.get_peft_model(model, lora_config)

trainer = trl.SFTTrainer(
    # model=model,  # Does not raise a ValueError
    model=torch.compile(model),  # Raises a ValueError
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

Expected behavior

Best case: wrapping the model with torch.compile has no effect on whether an exception is raised.
Worst case: the exception that is raised states that torch.compile isn't supported.
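The "worst case" option could look roughly like this minimal, stdlib-only sketch. It rests on loud assumptions: `FakePeftModel` and `FakeCompiledWrapper` are hypothetical stand-ins, not real transformers, peft, or torch classes, and `validate_model` is a hypothetical helper, not the Trainer's code. The only behaviour borrowed from PyTorch is that a `torch.compile`-d module keeps the original model in `_orig_mod` and forwards unknown attribute lookups to it.

```python
# Hypothetical stand-ins that only mimic the behaviour relevant to this issue.

class FakePeftModel:
    """Mimics a PEFT model whose base model is 4-bit quantized."""
    is_quantized = True


class FakeCompiledWrapper:
    """Mimics a torch.compile wrapper: stores the original in `_orig_mod`."""

    def __init__(self, mod):
        self._orig_mod = mod

    def __getattr__(self, name):
        # Called only when normal lookup fails; forward to the wrapped model.
        return getattr(self._orig_mod, name)


def validate_model(model):
    """Hypothetical check that names torch.compile instead of a misleading error."""
    if hasattr(model, "_orig_mod"):
        raise ValueError(
            "torch.compile-wrapped models are not supported here; "
            "pass the uncompiled model instead."
        )
    if getattr(model, "is_quantized", False) and not isinstance(model, FakePeftModel):
        raise ValueError("You cannot perform fine-tuning on purely quantized models.")
    return model
```

Checking for the wrapper first means the user sees an error about torch.compile rather than the misleading "purely quantized model" message.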

Current behaviour:


/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in __init__(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
    431         # At this stage the model is already loaded
    432         if _is_quantized_and_base_model and not _is_peft_model(model):
--> 433             raise ValueError(
    434                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
    435                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"

ValueError: You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft for more details
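The traceback hints at why only the compiled model trips this check: `torch.compile` returns a wrapper module (the original is kept in `_orig_mod`) rather than the PEFT model itself. Assuming `_is_peft_model` is an `isinstance` check against PEFT's model class, and that the quantization flag is read through ordinary attribute lookup, the flag is forwarded through the wrapper while the type check fails. The stdlib-only simulation below uses hypothetical stand-in classes to show this. It also shows one possible "best case" fix, unwrapping `_orig_mod` before the type check, which is not necessarily how transformers resolved the issue.

```python
# Hypothetical stand-ins; not the real peft/torch classes.

class FakePeftModel:
    """Mimics a PEFT model on top of a quantized base model."""
    is_quantized = True


class FakeCompiledWrapper:
    """Mimics torch.compile's wrapper: original kept in `_orig_mod`."""

    def __init__(self, mod):
        self._orig_mod = mod

    def __getattr__(self, name):
        # Attribute lookups that fail on the wrapper are forwarded.
        return getattr(self._orig_mod, name)


def unwrap_compiled(model):
    """Follow `_orig_mod` links until reaching the uncompiled model."""
    while hasattr(model, "_orig_mod"):
        model = model._orig_mod
    return model


def is_peft_model(model):
    """Type check that tolerates compile wrappers (hypothetical fix)."""
    return isinstance(unwrap_compiled(model), FakePeftModel)


compiled = FakeCompiledWrapper(FakePeftModel())
print(getattr(compiled, "is_quantized", False))  # True: flag is forwarded
print(isinstance(compiled, FakePeftModel))       # False: wrapper type fails the check
print(is_peft_model(compiled))                   # True once unwrapped
```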

Metadata


Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
