4-bit QLoRA via bitsandbytes (4-bit base model + LoRA) #23479

Merged
66 commits merged on May 24, 2023

Commits (66)
596b1c0
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
e66d556
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
5cdc176
Added and fixed optimizer tests.
TimDettmers May 8, 2023
0773ae5
Style and quality checks.
TimDettmers May 8, 2023
24c49e5
Initial draft. Some tests fail.
TimDettmers May 10, 2023
68b8ba4
Merge remote-tracking branch 'origin/main' into bnb_4bit
TimDettmers May 10, 2023
03b4d78
Fixed dtype bug.
TimDettmers May 10, 2023
524be44
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
06cf851
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
42e1095
Merge remote-tracking branch 'forked/bnb_paged_optimizers' into forke…
TimDettmers May 11, 2023
2525aee
Added fix for fp32 layer norms and bf16 compute in LLaMA.
TimDettmers May 11, 2023
cb7e54a
Merge remote-tracking branch 'origin/main' into bnb_beta
TimDettmers May 19, 2023
90412ab
Initial draft. Some tests fail.
TimDettmers May 10, 2023
0e6015b
Fixed dtype bug.
TimDettmers May 10, 2023
866886c
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
4c5ebf1
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
170812b
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
6e0d3ac
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
1582692
Added and fixed optimizer tests.
TimDettmers May 8, 2023
1f25846
Style and quality checks.
TimDettmers May 8, 2023
56110ec
Fixing issues for PR #23479.
TimDettmers May 20, 2023
80396d0
Added fix for fp32 layer norms and bf16 compute in LLaMA.
TimDettmers May 11, 2023
d4b4e4d
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into b…
TimDettmers May 20, 2023
6263752
Reverted variable name change.
TimDettmers May 20, 2023
831fc4a
Initial draft. Some tests fail.
TimDettmers May 10, 2023
b42644a
Fixed dtype bug.
TimDettmers May 10, 2023
9cd4319
Fixed bug caused by torch_dtype='auto'.
TimDettmers May 10, 2023
d68e564
All test green for 8-bit and 4-bit layers.
TimDettmers May 11, 2023
e8dcb57
Added lion and paged optimizers and made original tests pass.
TimDettmers May 8, 2023
ad30995
Added tests for paged and lion optimizers.
TimDettmers May 8, 2023
f1b2ab6
Added and fixed optimizer tests.
TimDettmers May 8, 2023
8b2e43d
Style and quality checks.
TimDettmers May 8, 2023
84cd0b3
Added missing tests.
TimDettmers May 20, 2023
61d2993
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into f…
TimDettmers May 20, 2023
33dde75
Fixup changes.
TimDettmers May 20, 2023
1d830b5
Added fixup changes.
TimDettmers May 20, 2023
5c1a5e0
Merge branch 'bnb_beta' of github.com:timdettmers/transformers into b…
TimDettmers May 20, 2023
2f15b6e
Missed some variables to rename.
TimDettmers May 20, 2023
617b58c
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada May 22, 2023
ea7175d
revert trainer tests
younesbelkada May 22, 2023
aac113d
revert test trainer
younesbelkada May 22, 2023
e43237d
another revert
younesbelkada May 22, 2023
13c86fd
fix tests and safety checkers
younesbelkada May 22, 2023
c72f302
protect import
younesbelkada May 22, 2023
7b1b1e6
simplify a bit
younesbelkada May 22, 2023
cf393cf
Update src/transformers/trainer.py
younesbelkada May 22, 2023
f19d80c
few fixes
younesbelkada May 22, 2023
ba287ff
add warning
younesbelkada May 22, 2023
1030921
replace with `load_in_kbit = load_in_4bit or load_in_8bit`
younesbelkada May 22, 2023
1cae462
fix test
younesbelkada May 22, 2023
25f762e
fix tests
younesbelkada May 22, 2023
2f43dc1
this time fix tests
younesbelkada May 22, 2023
a63b649
safety checker
younesbelkada May 22, 2023
49501db
add docs
younesbelkada May 22, 2023
4642523
revert torch_dtype
younesbelkada May 22, 2023
a6ba77b
Apply suggestions from code review
younesbelkada May 22, 2023
27cdff6
multiple fixes
younesbelkada May 22, 2023
b2bc3ab
update docs
younesbelkada May 22, 2023
976f7d0
version checks and multiple fixes
younesbelkada May 22, 2023
9c4946e
replace `is_loaded_in_kbit`
younesbelkada May 22, 2023
6f4f4dc
replace `load_in_kbit`
younesbelkada May 22, 2023
5359b59
change methods names
younesbelkada May 22, 2023
0c0bb65
better checks
younesbelkada May 22, 2023
f4a2a0b
oops
younesbelkada May 22, 2023
13a2ad7
oops
younesbelkada May 22, 2023
0b05092
address final comments
younesbelkada May 22, 2023
85 changes: 84 additions & 1 deletion docs/source/en/main_classes/quantization.mdx
@@ -19,8 +19,45 @@ This is supported by most GPU hardware since the `0.37.0` release of `bitsandbytes`.

Learn more about the quantization method in the [LLM.int8()](https://arxiv.org/abs/2208.07339) paper, or the [blogpost](https://huggingface.co/blog/hf-bitsandbytes-integration) about the collaboration.

Since the `0.39.0` release of `bitsandbytes`, you can load any model that supports `device_map` in 4-bit quantization, leveraging the FP4 data type.

Here is what you can do with the `bitsandbytes` integration:

### FP4 quantization

#### Requirements

Make sure that you have installed the following requirements before running any of the code snippets in this section.

- Latest `bitsandbytes` library
`pip install "bitsandbytes>=0.39.0"`

- Install latest `accelerate` from source
`pip install git+https://github.com/huggingface/accelerate.git`

- Install latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

#### Load a large model in 4bit

By passing `load_in_4bit=True` when calling the `.from_pretrained` method, you can reduce your memory use by roughly a factor of 4.

```python
# pip install transformers accelerate bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-1b7"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
```
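
As a quick check of the memory savings, you can inspect the model's footprint, for example with `get_memory_footprint` (a minimal sketch; the exact number depends on the checkpoint and on which modules are kept in higher precision):

```python
# Report the approximate memory footprint of the quantized model, in GB.
print(f"Memory footprint: {model.get_memory_footprint() / 1024**3:.2f} GB")
```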

<Tip warning={true}>

Note that once a model has been loaded in 4-bit, it is currently not possible to push the quantized weights to the Hub. Note also that you cannot train 4-bit weights, as this is not supported yet. However, you can use 4-bit models to train extra parameters; this is covered in the next section.

</Tip>
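
As a pointer to what the next section covers, training extra parameters on top of a 4-bit base model (the QLoRA recipe this PR enables) is typically done with adapters such as LoRA via the `peft` library. Below is a minimal sketch, assuming `peft` is installed and reusing the `model` loaded above; the `target_modules` names are illustrative and depend on the architecture.

```python
# LoRA-on-4bit sketch using the peft library (assumed installed).
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query_key_value"],  # illustrative; depends on the model architecture
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # only the adapter weights are trainable
```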

### Load a large model in 8bit

You can roughly halve a model's memory requirements by passing the `load_in_8bit=True` argument when calling the `.from_pretrained` method.
@@ -48,10 +85,56 @@ With this integration we were able to load large models on smaller devices and r…

<Tip warning={true}>

Note that once a model has been loaded in 8-bit it is currently not possible to push the quantized weights to the Hub. Note also that you cannot train 8-bit weights, as this is not supported yet. However, you can use 8-bit models to train extra parameters; this is covered in the next section.
Note that once a model has been loaded in 8-bit it is currently not possible to push the quantized weights to the Hub unless you use the latest `transformers` and `bitsandbytes`. Note also that you cannot train 8-bit weights, as this is not supported yet. However, you can use 8-bit models to train extra parameters; this is covered in the next section.

</Tip>

#### Advanced use cases

This section covers some advanced use cases of FP4 quantization.

##### Change the compute dtype

The compute dtype sets the dtype used during computation. For example, hidden states can be stored in `float32` while the computation is performed in `bfloat16` for speedups. By default, the compute dtype is set to `float32`.

```python
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
```
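
The resulting configuration is passed to `from_pretrained` like any other `BitsAndBytesConfig` (a short sketch reusing the `model_id` defined above):

```python
from transformers import AutoModelForCausalLM

model_bf16_compute = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
```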

##### Using NF4 (Normal Float 4) data type

You can also use the NF4 (Normal Float 4) data type, a new 4-bit data type adapted for weights initialized from a normal distribution. To do so, run:

```python
from transformers import BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)

model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
```

##### Use nested quantization for more memory-efficient inference

We also advise users to use the nested quantization technique. This saves additional memory at no extra performance cost: from our empirical observations, it enables fine-tuning a llama-13b model on an NVIDIA T4 with 16GB of memory, using a sequence length of 1024, a batch size of 1, and 4 gradient accumulation steps.

```python
from transformers import BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)

model_double_quant = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=double_quant_config)
```
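
The options above can be combined. For example, one reasonable (but not prescribed) configuration uses NF4, nested quantization, and a `bfloat16` compute dtype at the same time:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Combine NF4, nested quantization, and bf16 compute in a single config.
combined_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_4bit = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=combined_config)
```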


### Push quantized models on the 🤗 Hub

You can push a quantized model to the Hub by simply using the `push_to_hub` method. This will first push the quantization configuration file, then push the quantized model weights.
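
For example, to push an 8-bit model (a minimal sketch; the repository name is a placeholder you should replace with your own):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

# "bloom-560m-8bit" is a placeholder repository name.
model.push_to_hub("bloom-560m-8bit")
tokenizer.push_to_hub("bloom-560m-8bit")
```
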
54 changes: 54 additions & 0 deletions docs/source/en/perf_infer_gpu_one.mdx
@@ -34,6 +34,60 @@ model.save_pretrained("saved_model")

As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).

## `bitsandbytes` integration for FP4 mixed-precision inference

You can install `bitsandbytes` and benefit from easy model compression on GPUs. With FP4 quantization you can expect to reduce the model size by up to 8x compared to its native full-precision version. Check out below how to get started.

<Tip>

Note that this feature can also be used in a multi GPU setup.

</Tip>

### Requirements

- Latest `bitsandbytes` library
`pip install "bitsandbytes>=0.39.0"`

- Install latest `accelerate` from source
`pip install git+https://github.com/huggingface/accelerate.git`

- Install latest `transformers` from source
`pip install git+https://github.com/huggingface/transformers.git`

### Running FP4 models - single GPU setup - Quickstart

You can quickly run an FP4 model on a single GPU with the following code:

```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
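
Once loaded, the model can be used for generation like any other `transformers` model. A minimal sketch (the prompt and generation settings are illustrative):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Generate a short continuation; the prompt and max_new_tokens are illustrative.
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model_4bit.device)
outputs = model_4bit.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```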

### Running FP4 models - multi GPU setup

The way to load your mixed 4-bit model across multiple GPUs is the same as in the single GPU setup (same command):
```py
model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
You can control how much GPU RAM to allocate on each GPU with `accelerate`, using the `max_memory` argument as follows:

```py
max_memory_mapping = {0: "600MB", 1: "1GB"}
model_name = "bigscience/bloom-3b"
model_4bit = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
)
```
In this example, the first GPU will use 600MB of memory and the second 1GB.

### Advanced usage

For more advanced usage of this method, please have a look at the [quantization](main_classes/quantization) documentation page.

## `bitsandbytes` integration for Int8 mixed-precision matrix decomposition

<Tip>