Skip to content

Commit

Permalink
Add RWKV v5.1 and v5.2 support (#145)
Browse files Browse the repository at this point in the history
* Reformat CMakeLists and docs

* Add RWKV v5.1 and v5.2 support

* Remove sequence_length-based heuristic for allocating ggml context

* Set GGML_CUDA_MMV_Y to 2 by default

* Clarify comment; update ggml

* Make some operations inplace

* Make some operations inplace

* Add late_abort option for tests

* Increase thresholds

* Increase thresholds

* Update README.md
  • Loading branch information
saharNooby committed Nov 12, 2023
1 parent 22a2778 commit 20a8549
Show file tree
Hide file tree
Showing 38 changed files with 743 additions and 219 deletions.
20 changes: 16 additions & 4 deletions CMakeLists.txt
Expand Up @@ -117,6 +117,13 @@ if (RWKV_CUBLAS)

add_compile_definitions(GGML_USE_CUBLAS)

# By default, GGML_CUDA_MMV_Y is set to 1. This value leads to CUDA error on my machine:
# CUDA error 9 at ...\rwkv.cpp\ggml\src\ggml-cuda.cu:6107: invalid configuration argument
# The error appears when the head matrix of v5 3B and v5 7B models is offloaded. I guess the matrix is so large that block_num_y becomes too big.
# Changing it to 2 makes it work. I did not see any performance impact when measuring v5 3B & v5 7B. Hopefully, this will not break other use-cases.
# TODO Re-check after updating ggml whether this is needed
add_compile_definitions(GGML_CUDA_MMV_Y=2)

if (RWKV_STATIC)
set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
Expand Down Expand Up @@ -191,6 +198,7 @@ if (RWKV_HIPBLAS)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()

if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
Expand All @@ -202,18 +210,22 @@ if (RWKV_HIPBLAS)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h )
add_library(ggml-rocm OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)

if (BUILD_SHARED_LIBS)
set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

target_include_directories(ggml-rocm PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
set_source_files_properties(${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)

if (RWKV_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()

set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
Expand Down Expand Up @@ -406,7 +418,7 @@ target_compile_features(ggml PUBLIC c_std_11) # Don't bump
if (MSVC)
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
else()
if(WIN32 AND RWKV_HIPBLAS)
if (WIN32 AND RWKV_HIPBLAS)
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
else()
target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
Expand Down
53 changes: 30 additions & 23 deletions README.md
Expand Up @@ -6,17 +6,19 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT

This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.

[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.
[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.

Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.

⚠️ **Python API was restructured on 2023-09-20**, you may need to change paths/package names in your code when updating `rwkv.cpp`.
Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

## Quality and performance

If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads.
In general, `RWKV v5` models are 2 times slower than `RWKV v4` models, and require from 1.5 times (sequence length = 1) to 6 times (sequence length = 64) more memory.

Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads. The models are `RWKV v4 Pile 169M`, `RWKV v4 Pile 1.5B`.

| Format | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) |
|-----------|-------------------|--------------------|----------------------|
Expand All @@ -30,33 +32,38 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with

### With cuBLAS

Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token in ms shown.
Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. The model is `RWKV-4-Pile-169M`, 12 layers were offloaded to GPU.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-4-Pile-169M` | 12 | `Q4_0` | 7.9 | 6.2 | 6.9 | 8.6 | 20 |
| `RWKV-4-Pile-169M` | 12 | `Q4_1` | 7.8 | 6.7 | 6.9 | 8.6 | 21 |
| `RWKV-4-Pile-169M` | 12 | `Q5_1` | 8.1 | 6.7 | 6.9 | 9.0 | 22 |
Latency per token in ms shown.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_0` | 59 | 51 | 50 | 54 | 94 |
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_1` | 59 | 51 | 49 | 54 | 94 |
| `RWKV-4-Raven-7B-v11` | 32 | `Q5_1` | 77 | 69 | 67 | 72 | 101 |
| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `Q4_0` | 7.9 | 6.2 | 6.9 | 8.6 | 20 |
| `Q4_1` | 7.8 | 6.7 | 6.9 | 8.6 | 21 |
| `Q5_1` | 8.1 | 6.7 | 6.9 | 9.0 | 22 |

| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `Q4_0` | 59 | 51 | 50 | 54 | 94 |
| `Q4_1` | 59 | 51 | 49 | 54 | 94 |
| `Q5_1` | 77 | 69 | 67 | 72 | 101 |

Note: since cuBLAS is supported only for `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.

### With hipBLAS
Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. Latency per token in ms shown.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|------------------------------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `f16` | 94 | 91 | 94 | 106 | 944 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q4_0` | 83 | 77 | 75 | 110 | 1692 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q4_1` | 85 | 80 | 85 | 93 | 1691 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q5_1` | 83 | 78 | 83 | 90 | 1115 |
Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. The model is `RWKV-novel-4-World-7B-20230810-ctx128k`, 32 layers were offloaded to GPU.

Latency per token in ms shown.

| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `f16` | 94 | 91 | 94 | 106 | 944 |
| `Q4_0` | 83 | 77 | 75 | 110 | 1692 |
| `Q4_1` | 85 | 80 | 85 | 93 | 1691 |
| `Q5_1` | 83 | 78 | 83 | 90 | 1115 |

Note: hipBLAS is same as cuBLAS.They only support `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.
Note: same as cuBLAS, hipBLAS only supports `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.

## How to use

Expand Down
8 changes: 3 additions & 5 deletions docs/hipBLAS_on_Windows.md
Expand Up @@ -8,7 +8,6 @@ Skip this step if you already have Build Tools installed.

To install Build Tools, go to [Visual Studio Older Downloads](https://visualstudio.microsoft.com/vs/), download `Visual Studio 2022 and other Products` and run the installer.


## CMake

Skip this step if you already have CMake installed: running `cmake --version` should output `cmake version x.y.z`.
Expand All @@ -21,7 +20,7 @@ Skip this step if you already have Build Tools installed.

The [validation tools](https://rocm.docs.amd.com/en/latest/reference/validation_tools.html) not support on Windows. So you should confirm the Version of `ROCM` by yourself.

Fortunately `AMD` provides complete help documentation, you can use the help documentation to install [ROCM](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html)
Fortunately, `AMD` provides complete help documentation, you can use the help documentation to install [ROCM](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html)

>**If you encounter an error, if it is [AMD ROCm Windows Installation Error 215](https://github.com/RadeonOpenCompute/ROCm/issues/2363), don't worry about this error. ROCM has been installed correctly, but the vs studio plugin installation failed, we can ignore it.**
Expand All @@ -39,8 +38,7 @@ set CXX=C:\Program Files\AMD\ROCm\5.5\bin\clang++.exe

Skip this step if you already have Ninja installed: running `ninja --version` should output `1.11.1`.

Download latest `ninja-win.zip` from [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip.Then set as environment variables.
I unzipped it in `C:\Program Files\ninja`, so I set it like this:
Download latest `ninja-win.zip` from [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip. Then set as environment variables. I unzipped it in `C:\Program Files\ninja`, so I set it like this:

```Commandline
set ninja=C:\Program Files\ninja\ninja.exe
Expand All @@ -50,7 +48,7 @@ set ninja=C:\Program Files\ninja\ninja.exe
The thing different from the regular CPU build is `-DRWKV_HIPBLAS=ON` ,
`-G "Ninja"`, `-DCMAKE_C_COMPILER=clang`, `-DCMAKE_CXX_COMPILER=clang++`, `-DAMDGPU_TARGETS=gfx1100`

>**Notice** check the `clang` and `clang++` information:
>**Notice**: check the `clang` and `clang++` information:
```Commandline
clang --version
clang++ --version
Expand Down
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from d925ed to 4b20bb
33 changes: 27 additions & 6 deletions python/convert_pytorch_to_ggml.py
Expand Up @@ -32,6 +32,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
n_vocab: int = emb_weight.shape[0]
n_embed: int = emb_weight.shape[1]

is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict

if is_v5_2:
print('Detected RWKV v5.2')
elif is_v5_1_or_2:
print('Detected RWKV v5.1')
else:
print('Detected RWKV v4')

with open(dest_path, 'wb') as out_file:
is_FP16: bool = data_type == 'FP16' or data_type == 'float16'

Expand All @@ -50,16 +60,27 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
for k in state_dict.keys():
tensor: torch.Tensor = state_dict[k].float()

# Same processing as in "RWKV_in_150_lines.py"
if '.time_' in k:
# (1, 1, n_embed) -> (n_embed)
tensor = tensor.squeeze()

if '.time_decay' in k:
tensor = -torch.exp(tensor)
if is_v5_1_or_2:
if '.time_decay' in k:
if is_v5_2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
else:
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)

if '.time_first' in k:
tensor = torch.exp(tensor).reshape(-1, 1, 1)

if '.time_faaaa' in k:
tensor = tensor.unsqueeze(-1)
else:
if '.time_decay' in k:
tensor = -torch.exp(tensor)

# Keep 1-dim vectors in FP32
if is_FP16 and len(tensor.shape) > 1:
# Keep 1-dim vectors and small matrices in FP32
if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
tensor = tensor.half()

shape = tensor.shape
Expand Down
25 changes: 21 additions & 4 deletions python/merge_lora_into_ggml.py
Expand Up @@ -13,8 +13,9 @@
def parse_args():
parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
parser.add_argument('src_path', help='Path to source rwkv.cpp model')
parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
parser.add_argument('lora_alpha', type=int, help='Value of lora_alpha parameter used when training this LoRA checkpoint')
parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model')
return parser.parse_args()

Expand Down Expand Up @@ -44,6 +45,10 @@ def write_parameter(out_file, key: str, parameter: torch.Tensor) -> None:
def main() -> None:
args = parse_args()

arch_version: str = args.rwkv_arch_version

assert arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2', f'Invalid RWKV architecture version {arch_version}'

print(f'Reading {args.lora_path}')

lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
Expand Down Expand Up @@ -96,11 +101,23 @@ def main() -> None:

# Same processing as in convert_pytorch_to_ggml.py
if '.time_' in key:
# (1, 1, n_embed) -> (n_embed)
replacement = replacement.squeeze()

if '.time_decay' in key:
replacement = -torch.exp(replacement)
if arch_version == 'v5.1' or arch_version == 'v5.2':
if '.time_decay' in key:
if arch_version == 'v5.2':
replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)
else:
replacement = torch.exp(-torch.exp(replacement)).reshape(-1, 1, 1)

if '.time_first' in key:
replacement = torch.exp(replacement).reshape(-1, 1, 1)

if '.time_faaaa' in key:
replacement = replacement.unsqueeze(-1)
else:
if '.time_decay' in key:
replacement = -torch.exp(replacement)

if parameter.dtype == torch.float16:
replacement = replacement.half()
Expand Down
6 changes: 5 additions & 1 deletion rwkv.cpp
Expand Up @@ -121,7 +121,11 @@ size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
size_t rwkv_get_state_len(const struct rwkv_context * ctx) {
const struct rwkv_file_header & header = ctx->model->header;

return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
if (ctx->model->arch_version_major >= 5) {
return (size_t) header.n_embed * (2 + ctx->model->head_size) * (size_t) header.n_layer;
} else {
return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
}
}

// API function.
Expand Down
10 changes: 6 additions & 4 deletions rwkv_eval.inc
Expand Up @@ -176,16 +176,18 @@ bool rwkv_eval_sequence_in_chunks(

// API function.
void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
memset(state, 0, rwkv_get_state_len(ctx) * sizeof(float));

if (ctx->model->arch_version_major >= 5) {
return;
}

const struct rwkv_file_header & header = ctx->model->header;
const size_t layer_size = (size_t) header.n_embed * 5;
const size_t layer_zero = (size_t) header.n_embed * 4;
const size_t layers_size = (size_t) header.n_layer * layer_size;

for (size_t start = 0; start < layers_size; start += layer_size) {
for (size_t i = 0; i < layer_zero; i++) {
state[start + i] = 0.0F;
}

for (size_t i = layer_zero; i < layer_size; i++) {
state[start + i] = -1e30F;
}
Expand Down

0 comments on commit 20a8549

Please sign in to comment.