4-bit QLoRA via bitsandbytes (4-bit base model + LoRA) (huggingface#23479)

* Added lion and paged optimizers and made original tests pass.

* Added tests for paged and lion optimizers.

* Added and fixed optimizer tests.

* Style and quality checks.

* Initial draft. Some tests fail.

* Fixed dtype bug.

* Fixed bug caused by torch_dtype='auto'.

* All test green for 8-bit and 4-bit layers.

* Added fix for fp32 layer norms and bf16 compute in LLaMA.

* Initial draft. Some tests fail.

* Fixed dtype bug.

* Fixed bug caused by torch_dtype='auto'.

* All test green for 8-bit and 4-bit layers.

* Added lion and paged optimizers and made original tests pass.

* Added tests for paged and lion optimizers.

* Added and fixed optimizer tests.

* Style and quality checks.

* Fixing issues for PR huggingface#23479.

* Added fix for fp32 layer norms and bf16 compute in LLaMA.

* Reverted variable name change.

* Initial draft. Some tests fail.

* Fixed dtype bug.

* Fixed bug caused by torch_dtype='auto'.

* All test green for 8-bit and 4-bit layers.

* Added lion and paged optimizers and made original tests pass.

* Added tests for paged and lion optimizers.

* Added and fixed optimizer tests.

* Style and quality checks.

* Added missing tests.

* Fixup changes.

* Added fixup changes.

* Missed some variables to rename.

* revert trainer tests

* revert test trainer

* another revert

* fix tests and safety checkers

* protect import

* simplify a bit

* Update src/transformers/trainer.py

* few fixes

* add warning

* replace with `load_in_kbit = load_in_4bit or load_in_8bit`

* fix test

* fix tests

* this time fix tests

* safety checker

* add docs

* revert torch_dtype

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* multiple fixes

* update docs

* version checks and multiple fixes

* replace `is_loaded_in_kbit`

* replace `load_in_kbit`

* change methods names

* better checks

* oops

* oops

* address final comments

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
4 people authored and novice03 committed Jun 23, 2023
1 parent a683a39 commit 9324110
Showing 10 changed files with 864 additions and 76 deletions.
85 changes: 84 additions & 1 deletion docs/source/en/main_classes/quantization.mdx
@@ -19,8 +19,45 @@ This is supported by most GPU hardware since the `0.37.0` release of `bitsandbytes`.

Learn more about the quantization method in the [LLM.int8()](https://arxiv.org/abs/2208.07339) paper, or the [blogpost](https://huggingface.co/blog/hf-bitsandbytes-integration) about the collaboration.

Since the `0.39.0` release of `bitsandbytes`, you can load any model that supports `device_map` with 4-bit quantization, leveraging the FP4 data type.

Here are the things you can do with the `bitsandbytes` integration:

### FP4 quantization

#### Requirements

Make sure that you have installed the requirements below before running any of the code snippets in this section.

- Latest `bitsandbytes` library
`pip install bitsandbytes>=0.39.0`

- Install latest `accelerate` from source
`pip install git+https://github.com/huggingface/accelerate.git`

- Install latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

#### Load a large model in 4bit

By passing `load_in_4bit=True` when calling the `.from_pretrained` method, you can roughly divide your memory use by 4.

```python
# pip install transformers accelerate bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-1b7"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
```

<Tip warning={true}>

Note that once a model has been loaded in 4-bit, it is currently not possible to push the quantized weights to the Hub. Note also that you cannot train the 4-bit weights themselves, as this is not supported yet. However, you can use 4-bit models to train extra parameters; this is covered in the next section.

</Tip>
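As an illustration of that pattern (training extra parameters on top of a frozen 4-bit base model), here is a minimal sketch using the `peft` library. The `peft` imports (`LoraConfig`, `get_peft_model`, `prepare_model_for_kbit_training`) and the LoRA hyperparameters are assumptions for illustration and are not part of this commit.

```python
# Minimal QLoRA-style sketch: 4-bit base model + trainable LoRA adapters.
# Assumes `pip install peft` with a release that exposes prepare_model_for_kbit_training.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "bigscience/bloom-1b7"

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)

# Prepare the frozen quantized base model, then attach small trainable LoRA adapters.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()  # only the LoRA adapters are trainable
```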

### Load a large model in 8bit

You can roughly halve a model's memory requirements by passing the `load_in_8bit=True` argument when calling the `.from_pretrained` method.
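For example, a minimal sketch along the same lines as the 4-bit example above (the checkpoint choice is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-1b7"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are quantized to 8-bit on the fly while loading onto the available GPU(s).
model_8bit = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
```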
@@ -48,10 +85,56 @@ With this integration we were able to load large models on smaller devices and r

<Tip warning={true}>

- Note that once a model has been loaded in 8-bit it is currently not possible to push the quantized weights on the Hub. Note also that you cannot train 8-bit weights as this is not supported yet. However you can use 8-bit models to train extra parameters, this will be covered in the next section.
+ Note that once a model has been loaded in 8-bit it is currently not possible to push the quantized weights on the Hub except if you use the latest `transformers` and `bitsandbytes`. Note also that you cannot train 8-bit weights as this is not supported yet. However you can use 8-bit models to train extra parameters, this will be covered in the next section.

</Tip>

#### Advanced use cases

Here we will cover some advanced use cases of FP4 quantization.

##### Change the compute dtype

The compute dtype sets the dtype that is used during computation. For example, hidden states can be stored in `float32` while the computation is performed in `bfloat16` for speedups. By default, the compute dtype is set to `float32`.

```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
```
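The config is then passed to `from_pretrained` in the same way as shown in the sections below; a minimal sketch, continuing from the `quantization_config` defined above and reusing the checkpoint from the 4-bit example:

```python
from transformers import AutoModelForCausalLM

# Computation runs in bfloat16, as set by bnb_4bit_compute_dtype above.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7", device_map="auto", quantization_config=quantization_config
)
```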

##### Using NF4 (Normal Float 4) data type

You can also use the NF4 data type, which is a new 4-bit data type adapted for weights that have been initialized from a normal distribution. To use it, run:

```python
from transformers import BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)

model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
```

##### Use nested quantization for more memory efficient inference

We also advise users to use the nested quantization technique. This saves more memory at no additional performance cost: from our empirical observations, it enables fine-tuning a llama-13b model on an NVIDIA T4 with 16GB of memory, using a sequence length of 1024, a batch size of 1, and 4 gradient accumulation steps.

```python
from transformers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)

model_double_quant = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=double_quant_config)
```
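These options can also be combined; below is a sketch of one such configuration mixing NF4, nested quantization, and a `bfloat16` compute dtype (the particular combination is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

combined_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # normal-float 4-bit data type
    bnb_4bit_use_double_quant=True,         # nested (double) quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype for speedups
)

model_combined = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=combined_config)
```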


### Push quantized models on the 🤗 Hub

You can push a quantized model to the Hub by simply using the `push_to_hub` method. This will first push the quantization configuration file, then push the quantized model weights.
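A minimal sketch of that flow for an 8-bit model (the target repository name is a hypothetical placeholder, and you need to be logged in to the Hub):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

# Pushes the quantization config first, then the quantized weights.
model.push_to_hub("my-username/bloom-560m-8bit")  # hypothetical repository
tokenizer.push_to_hub("my-username/bloom-560m-8bit")
```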
54 changes: 54 additions & 0 deletions docs/source/en/perf_infer_gpu_one.mdx
@@ -34,6 +34,60 @@ model.save_pretrained("saved_model")

As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).

## `bitsandbytes` integration for FP4 mixed-precision inference

You can install `bitsandbytes` and benefit from easy model compression on GPUs. With FP4 quantization you can expect to reduce the model size by up to 8x compared to its native full-precision version. Check out how to get started below.

<Tip>

Note that this feature can also be used in a multi GPU setup.

</Tip>

### Requirements

- Latest `bitsandbytes` library
`pip install bitsandbytes>=0.39.0`

- Install latest `accelerate` from source
`pip install git+https://github.com/huggingface/accelerate.git`

- Install latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

### Running FP4 models - single GPU setup - Quickstart

You can quickly run an FP4 model on a single GPU with the following code:

```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
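Once loaded, the model can be used for generation as usual; a short sketch reusing `model_name` and `model_4bit` from the snippet above (the prompt is arbitrary):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("Hello, my llama is cute", return_tensors="pt").to(model_4bit.device)
outputs = model_4bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```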

### Running FP4 models - multi GPU setup

The way to load your FP4 model on multiple GPUs is the same as in the single GPU setup:
```py
model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
You can, however, control the GPU RAM allocated on each GPU using `accelerate`, via the `max_memory` argument:

```py
max_memory_mapping = {0: "600MB", 1: "1GB"}
model_name = "bigscience/bloom-3b"
model_4bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
)
```
In this example, the first GPU will use 600MB of memory and the second 1GB.
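One way to check the memory actually used by a loaded model is the `get_memory_footprint` method of `PreTrainedModel`; a minimal sketch, reusing `model_4bit` from the snippet above:

```py
# Reports the size of the model's parameters and buffers, in bytes.
print(f"Memory footprint: {model_4bit.get_memory_footprint() / 1024**3:.2f} GB")
```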

### Advanced usage

For more advanced usage of this method, please have a look at the [quantization](main_classes/quantization) documentation page.

## `bitsandbytes` integration for Int8 mixed-precision matrix decomposition

<Tip>
