Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RWKV v5.1 and v5.2 support #145

Merged
merged 11 commits into from
Nov 12, 2023
Merged
20 changes: 16 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ if (RWKV_CUBLAS)

add_compile_definitions(GGML_USE_CUBLAS)

# By default, GGML_CUDA_MMV_Y is set to 1. This value leads to CUDA error on my machine:
# CUDA error 9 at ...\rwkv.cpp\ggml\src\ggml-cuda.cu:6107: invalid configuration argument
# The error appears when the head matrix of v5 3B and v5 7B models is offloaded. I guess the matrix is so large that block_num_y becomes too big.
# Changing it to 2 makes it work. I did not see any performance impact when measuring v5 3B & v5 7B. Hopefully, this will not break other use-cases.
# TODO Re-check after updating ggml whether this is needed
add_compile_definitions(GGML_CUDA_MMV_Y=2)

if (RWKV_STATIC)
set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
Expand Down Expand Up @@ -191,6 +198,7 @@ if (RWKV_HIPBLAS)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()

if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
Expand All @@ -202,18 +210,22 @@ if (RWKV_HIPBLAS)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h )
add_library(ggml-rocm OBJECT
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)

if (BUILD_SHARED_LIBS)
set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

target_include_directories(ggml-rocm PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
set_source_files_properties(${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)

if (RWKV_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()

set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
Expand Down Expand Up @@ -406,7 +418,7 @@ target_compile_features(ggml PUBLIC c_std_11) # Don't bump
if (MSVC)
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
else()
if(WIN32 AND RWKV_HIPBLAS)
if (WIN32 AND RWKV_HIPBLAS)
target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
else()
target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
Expand Down
53 changes: 30 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT

This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.

[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.
[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.

Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.

⚠️ **Python API was restructured on 2023-09-20**, you may need to change paths/package names in your code when updating `rwkv.cpp`.
Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

## Quality and performance

If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads.
In general, `RWKV v5` models are 2 times slower than `RWKV v4` models, and require from 1.5 times (sequence length = 1) to 6 times (sequence length = 64) more memory.

Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads. The models are `RWKV v4 Pile 169M`, `RWKV v4 Pile 1.5B`.

| Format | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) |
|-----------|-------------------|--------------------|----------------------|
Expand All @@ -30,33 +32,38 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with

### With cuBLAS

Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token in ms shown.
Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. The model is `RWKV-4-Pile-169M`, 12 layers were offloaded to GPU.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-4-Pile-169M` | 12 | `Q4_0` | 7.9 | 6.2 | 6.9 | 8.6 | 20 |
| `RWKV-4-Pile-169M` | 12 | `Q4_1` | 7.8 | 6.7 | 6.9 | 8.6 | 21 |
| `RWKV-4-Pile-169M` | 12 | `Q5_1` | 8.1 | 6.7 | 6.9 | 9.0 | 22 |
Latency per token in ms shown.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_0` | 59 | 51 | 50 | 54 | 94 |
| `RWKV-4-Raven-7B-v11` | 32 | `Q4_1` | 59 | 51 | 49 | 54 | 94 |
| `RWKV-4-Raven-7B-v11` | 32 | `Q5_1` | 77 | 69 | 67 | 72 | 101 |
| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `Q4_0` | 7.9 | 6.2 | 6.9 | 8.6 | 20 |
| `Q4_1` | 7.8 | 6.7 | 6.9 | 8.6 | 21 |
| `Q5_1` | 8.1 | 6.7 | 6.9 | 9.0 | 22 |

| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `Q4_0` | 59 | 51 | 50 | 54 | 94 |
| `Q4_1` | 59 | 51 | 49 | 54 | 94 |
| `Q5_1` | 77 | 69 | 67 | 72 | 101 |

Note: since cuBLAS is supported only for `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.

### With hipBLAS
Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. Latency per token in ms shown.

| Model | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|------------------------------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `f16` | 94 | 91 | 94 | 106 | 944 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q4_0` | 83 | 77 | 75 | 110 | 1692 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q4_1` | 85 | 80 | 85 | 93 | 1691 |
| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32 | `Q5_1` | 83 | 78 | 83 | 90 | 1115 |
Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. The model is `RWKV-novel-4-World-7B-20230810-ctx128k`, 32 layers were offloaded to GPU.

Latency per token in ms shown.

| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
|--------|----------|-----------|-----------|-----------|------------|
| `f16` | 94 | 91 | 94 | 106 | 944 |
| `Q4_0` | 83 | 77 | 75 | 110 | 1692 |
| `Q4_1` | 85 | 80 | 85 | 93 | 1691 |
| `Q5_1` | 83 | 78 | 83 | 90 | 1115 |

Note: hipBLAS is same as cuBLAS.They only support `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.
Note: same as cuBLAS, hipBLAS only supports `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.

## How to use

Expand Down
8 changes: 3 additions & 5 deletions docs/hipBLAS_on_Windows.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ Skip this step if you already have Build Tools installed.

To install Build Tools, go to [Visual Studio Older Downloads](https://visualstudio.microsoft.com/vs/), download `Visual Studio 2022 and other Products` and run the installer.


## CMake

Skip this step if you already have CMake installed: running `cmake --version` should output `cmake version x.y.z`.
Expand All @@ -21,7 +20,7 @@ Skip this step if you already have Build Tools installed.

The [validation tools](https://rocm.docs.amd.com/en/latest/reference/validation_tools.html) not support on Windows. So you should confirm the Version of `ROCM` by yourself.

Fortunately `AMD` provides complete help documentation, you can use the help documentation to install [ROCM](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html)
Fortunately, `AMD` provides complete help documentation, you can use the help documentation to install [ROCM](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html)

>**If you encounter an error, if it is [AMD ROCm Windows Installation Error 215](https://github.com/RadeonOpenCompute/ROCm/issues/2363), don't worry about this error. ROCM has been installed correctly, but the vs studio plugin installation failed, we can ignore it.**

Expand All @@ -39,8 +38,7 @@ set CXX=C:\Program Files\AMD\ROCm\5.5\bin\clang++.exe

Skip this step if you already have Ninja installed: running `ninja --version` should output `1.11.1`.

Download latest `ninja-win.zip` from [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip.Then set as environment variables.
I unzipped it in `C:\Program Files\ninja`, so I set it like this:
Download latest `ninja-win.zip` from [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip. Then set as environment variables. I unzipped it in `C:\Program Files\ninja`, so I set it like this:

```Commandline
set ninja=C:\Program Files\ninja\ninja.exe
Expand All @@ -50,7 +48,7 @@ set ninja=C:\Program Files\ninja\ninja.exe
The thing different from the regular CPU build is `-DRWKV_HIPBLAS=ON` ,
`-G "Ninja"`, `-DCMAKE_C_COMPILER=clang`, `-DCMAKE_CXX_COMPILER=clang++`, `-DAMDGPU_TARGETS=gfx1100`

>**Notice** check the `clang` and `clang++` information:
>**Notice**: check the `clang` and `clang++` information:
```Commandline
clang --version
clang++ --version
Expand Down
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from d925ed to 4b20bb
33 changes: 27 additions & 6 deletions python/convert_pytorch_to_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
n_vocab: int = emb_weight.shape[0]
n_embed: int = emb_weight.shape[1]

is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict

if is_v5_2:
print('Detected RWKV v5.2')
elif is_v5_1_or_2:
print('Detected RWKV v5.1')
else:
print('Detected RWKV v4')

with open(dest_path, 'wb') as out_file:
is_FP16: bool = data_type == 'FP16' or data_type == 'float16'

Expand All @@ -50,16 +60,27 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
for k in state_dict.keys():
tensor: torch.Tensor = state_dict[k].float()

# Same processing as in "RWKV_in_150_lines.py"
if '.time_' in k:
# (1, 1, n_embed) -> (n_embed)
tensor = tensor.squeeze()

if '.time_decay' in k:
tensor = -torch.exp(tensor)
if is_v5_1_or_2:
if '.time_decay' in k:
if is_v5_2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
else:
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)

if '.time_first' in k:
tensor = torch.exp(tensor).reshape(-1, 1, 1)

if '.time_faaaa' in k:
tensor = tensor.unsqueeze(-1)
else:
if '.time_decay' in k:
tensor = -torch.exp(tensor)

# Keep 1-dim vectors in FP32
if is_FP16 and len(tensor.shape) > 1:
# Keep 1-dim vectors and small matrices in FP32
if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
tensor = tensor.half()

shape = tensor.shape
Expand Down
25 changes: 21 additions & 4 deletions python/merge_lora_into_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
def parse_args():
parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
parser.add_argument('src_path', help='Path to source rwkv.cpp model')
parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
parser.add_argument('lora_alpha', type=int, help='Value of lora_alpha parameter used when training this LoRA checkpoint')
parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model')
return parser.parse_args()

Expand Down Expand Up @@ -44,6 +45,10 @@ def write_parameter(out_file, key: str, parameter: torch.Tensor) -> None:
def main() -> None:
args = parse_args()

arch_version: str = args.rwkv_arch_version

assert arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2', f'Invalid RWKV architecture version {arch_version}'

print(f'Reading {args.lora_path}')

lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
Expand Down Expand Up @@ -96,11 +101,23 @@ def main() -> None:

# Same processing as in convert_pytorch_to_ggml.py
if '.time_' in key:
# (1, 1, n_embed) -> (n_embed)
replacement = replacement.squeeze()

if '.time_decay' in key:
replacement = -torch.exp(replacement)
if arch_version == 'v5.1' or arch_version == 'v5.2':
if '.time_decay' in key:
if arch_version == 'v5.2':
replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)
else:
replacement = torch.exp(-torch.exp(replacement)).reshape(-1, 1, 1)

if '.time_first' in key:
replacement = torch.exp(replacement).reshape(-1, 1, 1)

if '.time_faaaa' in key:
replacement = replacement.unsqueeze(-1)
else:
if '.time_decay' in key:
replacement = -torch.exp(replacement)

if parameter.dtype == torch.float16:
replacement = replacement.half()
Expand Down
6 changes: 5 additions & 1 deletion rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
size_t rwkv_get_state_len(const struct rwkv_context * ctx) {
const struct rwkv_file_header & header = ctx->model->header;

return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
if (ctx->model->arch_version_major >= 5) {
return (size_t) header.n_embed * (2 + ctx->model->head_size) * (size_t) header.n_layer;
} else {
return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
}
}

// API function.
Expand Down
10 changes: 6 additions & 4 deletions rwkv_eval.inc
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,18 @@ bool rwkv_eval_sequence_in_chunks(

// API function.
void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
memset(state, 0, rwkv_get_state_len(ctx) * sizeof(float));

if (ctx->model->arch_version_major >= 5) {
return;
}

const struct rwkv_file_header & header = ctx->model->header;
const size_t layer_size = (size_t) header.n_embed * 5;
const size_t layer_zero = (size_t) header.n_embed * 4;
const size_t layers_size = (size_t) header.n_layer * layer_size;

for (size_t start = 0; start < layers_size; start += layer_size) {
for (size_t i = 0; i < layer_zero; i++) {
state[start + i] = 0.0F;
}

for (size_t i = layer_zero; i < layer_size; i++) {
state[start + i] = -1e30F;
}
Expand Down