diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1030529..5edab37 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -117,6 +117,13 @@ if (RWKV_CUBLAS)
 
         add_compile_definitions(GGML_USE_CUBLAS)
 
+        # By default, GGML_CUDA_MMV_Y is set to 1. This value leads to a CUDA error on my machine:
+        # CUDA error 9 at ...\rwkv.cpp\ggml\src\ggml-cuda.cu:6107: invalid configuration argument
+        # The error appears when the head matrix of v5 3B and v5 7B models is offloaded. I guess the matrix is so large that block_num_y becomes too big.
+        # Changing it to 2 makes it work. I did not see any performance impact when measuring v5 3B & v5 7B. Hopefully, this will not break other use-cases.
+        # TODO: Re-check whether this is still needed after updating ggml
+        add_compile_definitions(GGML_CUDA_MMV_Y=2)
+
         if (RWKV_STATIC)
             set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
         else()
@@ -191,6 +198,7 @@ if (RWKV_HIPBLAS)
     if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
         message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
     endif()
+
     if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
         message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
     endif()
@@ -202,18 +210,22 @@ if (RWKV_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
-        add_library(ggml-rocm OBJECT
-            ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
-            ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h )
+
+        add_library(ggml-rocm OBJECT
+            ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu
+            ${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.h)
+
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
         endif()
+
         target_include_directories(ggml-rocm PUBLIC ${CMAKE_SOURCE_DIR}/ggml/include/ggml)
         set_source_files_properties(${CMAKE_SOURCE_DIR}/ggml/src/ggml-cuda.cu PROPERTIES LANGUAGE CXX)
         target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
         if (RWKV_STATIC)
             message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
         endif()
+
         set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} ggml-rocm)
     else()
         message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
@@ -406,7 +418,7 @@ target_compile_features(ggml PUBLIC c_std_11) # Don't bump
 if (MSVC)
     target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
 else()
-    if(WIN32 AND RWKV_HIPBLAS)
+    if (WIN32 AND RWKV_HIPBLAS)
         target_link_libraries(ggml PUBLIC ${RWKV_EXTRA_LIBS} Threads::Threads)
     else()
         target_link_libraries(ggml PUBLIC m ${RWKV_EXTRA_LIBS} Threads::Threads)
diff --git a/README.md b/README.md
index 42ec5dc..3945ab8 100644
--- a/README.md
+++ b/README.md
@@ -6,17 +6,19 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT
 
 This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.
 
-[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts.
+[RWKV](https://arxiv.org/abs/2305.13048) is a large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.
 
-Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
+[RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to the RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.
 
-⚠️ **Python API was restructured on 2023-09-20**, you may need to change paths/package names in your code when updating `rwkv.cpp`.
+Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
 
 ## Quality and performance
 
 If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.
 
-Below table is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads.
+In general, `RWKV v5` models are about 2 times slower than `RWKV v4` models, and require from 1.5 times (sequence length = 1) to 6 times (sequence length = 64) more memory. A sketch of the corresponding state size is given after the cuBLAS tables below.
+
+The table below is for reference only. Measurements were made on 4C/8T x86 CPU with AVX2, 4 threads. The models are `RWKV v4 Pile 169M` and `RWKV v4 Pile 1.5B`.
 
 | Format    | Perplexity (169M) | Latency, ms (1.5B) | File size, GB (1.5B) |
 |-----------|-------------------|--------------------|----------------------|
@@ -30,33 +32,38 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
 
 ### With cuBLAS
 
-Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token in ms shown.
+Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. The first table below is for `RWKV-4-Pile-169M` with 12 layers offloaded to GPU; the second is for `RWKV-4-Raven-7B-v11` with 32 layers offloaded.
 
-| Model                 | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
-|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
-| `RWKV-4-Pile-169M`    | 12            | `Q4_0` | 7.9      | 6.2       | 6.9       | 8.6       | 20         |
-| `RWKV-4-Pile-169M`    | 12            | `Q4_1` | 7.8      | 6.7       | 6.9       | 8.6       | 21         |
-| `RWKV-4-Pile-169M`    | 12            | `Q5_1` | 8.1      | 6.7       | 6.9       | 9.0       | 22         |
+Latency per token is shown in ms.
 
-| Model                 | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
-|-----------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
-| `RWKV-4-Raven-7B-v11` | 32            | `Q4_0` | 59       | 51        | 50        | 54        | 94         |
-| `RWKV-4-Raven-7B-v11` | 32            | `Q4_1` | 59       | 51        | 49        | 54        | 94         |
-| `RWKV-4-Raven-7B-v11` | 32            | `Q5_1` | 77       | 69        | 67        | 72        | 101        |
+| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
+|--------|----------|-----------|-----------|-----------|------------|
+| `Q4_0` | 7.9      | 6.2       | 6.9       | 8.6       | 20         |
+| `Q4_1` | 7.8      | 6.7       | 6.9       | 8.6       | 21         |
+| `Q5_1` | 8.1      | 6.7       | 6.9       | 9.0       | 22         |
+
+| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
+|--------|----------|-----------|-----------|-----------|------------|
+| `Q4_0` | 59       | 51        | 50        | 54        | 94         |
+| `Q4_1` | 59       | 51        | 49        | 54        | 94         |
+| `Q5_1` | 77       | 69        | 67        | 72        | 101        |
 
 Note: since cuBLAS is supported only for `ggml_mul_mat()`, we still need to use some CPU resources to execute the remaining operations.
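To make the state-size part of the memory claim above concrete, here is a minimal sketch of the formulas used by `rwkv_get_state_len()` in this PR. The 3B-class dimensions in the example (`n_layer = 32`, `n_embed = 2560`, `head_size = 64`) are assumptions for illustration, not measurements from this PR.

```python
def state_size_floats(n_layer: int, n_embed: int, head_size: int, v5: bool) -> int:
    # v4 keeps 5 vectors per layer (ffn_xx, att_xx, att_aa, att_bb, att_pp);
    # v5 keeps ffn_xx and att_xx plus head_size rows of per-head attention state.
    vectors_per_layer = (2 + head_size) if v5 else 5
    return n_layer * n_embed * vectors_per_layer

# Hypothetical 3B-class dimensions:
assert state_size_floats(32, 2560, 64, v5=False) == 409_600    # ~1.6 MB in FP32
assert state_size_floats(32, 2560, 64, v5=True) == 5_406_720   # ~21.6 MB in FP32
```

Note that the state is only one contributor to total memory use, which is why the overall factor quoted above stays in the 1.5x-6x range.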
### With hipBLAS
 
-Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. Latency per token in ms shown.
-
-| Model                                    | Layers on GPU | Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
-|------------------------------------------|---------------|--------|----------|-----------|-----------|-----------|------------|
-| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32            | `f16`  | 94       | 91        | 94        | 106       | 944        |
-| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32            | `Q4_0` | 83       | 77        | 75        | 110       | 1692       |
-| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32            | `Q4_1` | 85       | 80        | 85        | 93        | 1691       |
-| `RWKV-novel-4-World-7B-20230810-ctx128k` | 32            | `Q5_1` | 83       | 78        | 83        | 90        | 1115       |
+Measurements were made on CPU AMD Ryzen 9 5900X & GPU AMD Radeon RX 7900 XTX. The model is `RWKV-novel-4-World-7B-20230810-ctx128k`; 32 layers were offloaded to GPU.
+
+Latency per token is shown in ms.
+
+| Format | 1 thread | 2 threads | 4 threads | 8 threads | 24 threads |
+|--------|----------|-----------|-----------|-----------|------------|
+| `f16`  | 94       | 91        | 94        | 106       | 944        |
+| `Q4_0` | 83       | 77        | 75        | 110       | 1692       |
+| `Q4_1` | 85       | 80        | 85        | 93        | 1691       |
+| `Q5_1` | 83       | 78        | 83        | 90        | 1115       |
 
-Note: hipBLAS is same as cuBLAS.They only support `ggml_mul_mat()`, we still need to use few CPU resources to execute remaining operations.
+Note: as with cuBLAS, hipBLAS supports only `ggml_mul_mat()`; we still need to use some CPU resources to execute the remaining operations.
 
 ## How to use
 
diff --git a/docs/hipBLAS_on_Windows.md b/docs/hipBLAS_on_Windows.md
index b4316ec..4c366af 100644
--- a/docs/hipBLAS_on_Windows.md
+++ b/docs/hipBLAS_on_Windows.md
@@ -8,7 +8,6 @@
 Skip this step if you already have Build Tools installed.
 
 To install Build Tools, go to [Visual Studio Older Downloads](https://visualstudio.microsoft.com/vs/), download `Visual Studio 2022 and other Products` and run the installer.
-
 ## CMake
 
 Skip this step if you already have CMake installed: running `cmake --version` should output `cmake version x.y.z`.
@@ -21,7 +20,7 @@ Skip this step if you already have Build Tools installed.
 
 The [validation tools](https://rocm.docs.amd.com/en/latest/reference/validation_tools.html) are not supported on Windows, so you should confirm the `ROCm` version yourself.
 
-Fortunately `AMD` provides complete help documentation, you can use the help documentation to install [ROCM](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html)
+Fortunately, `AMD` provides complete help documentation; you can use it to install [ROCm](https://rocm.docs.amd.com/en/latest/deploy/windows/quick_start.html).
 
>**If you encounter [AMD ROCm Windows Installation Error 215](https://github.com/RadeonOpenCompute/ROCm/issues/2363), don't worry: ROCm has been installed correctly, but the Visual Studio plugin installation failed. It can be ignored.**
 
@@ -39,8 +38,7 @@ set CXX=C:\Program Files\AMD\ROCm\5.5\bin\clang++.exe
 
 Skip this step if you already have Ninja installed: running `ninja --version` should output `1.11.1`.
 
-Download latest `ninja-win.zip` from [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip.Then set as environment variables.
-I unzipped it in `C:\Program Files\ninja`, so I set it like this:
+Download the latest `ninja-win.zip` from the [GitHub Releases Page](https://github.com/ninja-build/ninja/releases/tag/v1.11.1) and unzip it, then set it as an environment variable. I unzipped it in `C:\Program Files\ninja`, so I set it like this:
 
 ```Commandline
 set ninja=C:\Program Files\ninja\ninja.exe
 ```
 
@@ -50,7 +48,7 @@
 The things different from the regular CPU build are `-DRWKV_HIPBLAS=ON`, `-G "Ninja"`, `-DCMAKE_C_COMPILER=clang`, `-DCMAKE_CXX_COMPILER=clang++`, `-DAMDGPU_TARGETS=gfx1100`
 
->**Notice** check the `clang` and `clang++` information:
+>**Notice**: check the `clang` and `clang++` information:
 ```Commandline
 clang --version
 clang++ --version
 ```
diff --git a/ggml b/ggml
index d925ed7..4b20bbd 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit d925ed7a96767192d422a97645f08ad86d5cc6f0
+Subproject commit 4b20bbdf1b6e586addf9d065518b594e94dfa43f
diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py
index ed8bce0..9956844 100644
--- a/python/convert_pytorch_to_ggml.py
+++ b/python/convert_pytorch_to_ggml.py
@@ -32,6 +32,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
     n_vocab: int = emb_weight.shape[0]
     n_embed: int = emb_weight.shape[1]
 
+    is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
+    is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
+
+    if is_v5_2:
+        print('Detected RWKV v5.2')
+    elif is_v5_1_or_2:
+        print('Detected RWKV v5.1')
+    else:
+        print('Detected RWKV v4')
+
     with open(dest_path, 'wb') as out_file:
         is_FP16: bool = data_type == 'FP16' or data_type == 'float16'
 
@@ -50,16 +60,27 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
         for k in state_dict.keys():
             tensor: torch.Tensor = state_dict[k].float()
 
-            # Same processing as in "RWKV_in_150_lines.py"
             if '.time_' in k:
-                # (1, 1, n_embed) -> (n_embed)
                 tensor = tensor.squeeze()
 
-            if '.time_decay' in k:
-                tensor = -torch.exp(tensor)
+                if is_v5_1_or_2:
+                    # v5 precomputes time_decay as exp(-exp(w)) and reshapes it
+                    # so that it broadcasts over the per-head state.
+                    if '.time_decay' in k:
+                        if is_v5_2:
+                            tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
+                        else:
+                            tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)
+
+                    if '.time_first' in k:
+                        tensor = torch.exp(tensor).reshape(-1, 1, 1)
+
+                    if '.time_faaaa' in k:
+                        tensor = tensor.unsqueeze(-1)
+                else:
+                    if '.time_decay' in k:
+                        tensor = -torch.exp(tensor)
 
-            # Keep 1-dim vectors in FP32
-            if is_FP16 and len(tensor.shape) > 1:
+            # Keep 1-dim vectors and small matrices in FP32
+            if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
                 tensor = tensor.half()
 
             shape = tensor.shape
diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py
index 6f8ce2d..4f57d7c 100644
--- a/python/merge_lora_into_ggml.py
+++ b/python/merge_lora_into_ggml.py
@@ -13,8 +13,9 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
     parser.add_argument('src_path', help='Path to source rwkv.cpp model')
+    parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
     parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
-    parser.add_argument('lora_alpha', type=int, help='Value of lora_alpha parameter used when training this LoRA checkpoint')
+    parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
     parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model')
     return parser.parse_args()
 
@@ -44,6 +45,10 @@ def write_parameter(out_file, key: str, parameter: torch.Tensor) -> None:
 
 def main() -> None:
     args = 
parse_args() + arch_version: str = args.rwkv_arch_version + + assert arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2', f'Invalid RWKV architecture version {arch_version}' + print(f'Reading {args.lora_path}') lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu') @@ -96,11 +101,23 @@ def main() -> None: # Same processing as in convert_pytorch_to_ggml.py if '.time_' in key: - # (1, 1, n_embed) -> (n_embed) replacement = replacement.squeeze() - if '.time_decay' in key: - replacement = -torch.exp(replacement) + if arch_version == 'v5.1' or arch_version == 'v5.2': + if '.time_decay' in key: + if arch_version == 'v5.2': + replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1) + else: + replacement = torch.exp(-torch.exp(replacement)).reshape(-1, 1, 1) + + if '.time_first' in key: + replacement = torch.exp(replacement).reshape(-1, 1, 1) + + if '.time_faaaa' in key: + replacement = replacement.unsqueeze(-1) + else: + if '.time_decay' in key: + replacement = -torch.exp(replacement) if parameter.dtype == torch.float16: replacement = replacement.half() diff --git a/rwkv.cpp b/rwkv.cpp index f7406bf..124602d 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -121,7 +121,11 @@ size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { size_t rwkv_get_state_len(const struct rwkv_context * ctx) { const struct rwkv_file_header & header = ctx->model->header; - return (size_t) header.n_embed * 5 * (size_t) header.n_layer; + if (ctx->model->arch_version_major >= 5) { + return (size_t) header.n_embed * (2 + ctx->model->head_size) * (size_t) header.n_layer; + } else { + return (size_t) header.n_embed * 5 * (size_t) header.n_layer; + } } // API function. diff --git a/rwkv_eval.inc b/rwkv_eval.inc index 37f3a9c..7997900 100644 --- a/rwkv_eval.inc +++ b/rwkv_eval.inc @@ -176,16 +176,18 @@ bool rwkv_eval_sequence_in_chunks( // API function. 
void rwkv_init_state(const struct rwkv_context * ctx, float * state) { + memset(state, 0, rwkv_get_state_len(ctx) * sizeof(float)); + + if (ctx->model->arch_version_major >= 5) { + return; + } + const struct rwkv_file_header & header = ctx->model->header; const size_t layer_size = (size_t) header.n_embed * 5; const size_t layer_zero = (size_t) header.n_embed * 4; const size_t layers_size = (size_t) header.n_layer * layer_size; for (size_t start = 0; start < layers_size; start += layer_size) { - for (size_t i = 0; i < layer_zero; i++) { - state[start + i] = 0.0F; - } - for (size_t i = layer_zero; i < layer_size; i++) { state[start + i] = -1e30F; } diff --git a/rwkv_file_format.inc b/rwkv_file_format.inc index d9b9d4c..390f7b8 100644 --- a/rwkv_file_format.inc +++ b/rwkv_file_format.inc @@ -129,21 +129,31 @@ struct rwkv_tensor_header { uint32_t dim_count; uint32_t key_length; uint32_t data_type; - uint32_t width; - uint32_t height; + uint32_t size0; + uint32_t size1; + uint32_t size2; size_t size() const; }; size_t rwkv_tensor_header::size() const { - return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height); + return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->size0, this->size1, this->size2); } static bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header)); - header.height = 1; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t) * 2, &header)); + header.size1 = 1; + header.size2 = 1; + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_SHAPE, + header.dim_count == 1 || header.dim_count == 2 || header.dim_count == 3, + "Tensor has an invalid shape (%" PRId32 " dimensions)", + header.dim_count + ); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); + RWKV_ASSERT_FALSE_MSG( RWKV_ERROR_DATA_TYPE, rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, @@ -151,15 +161,29 @@ static bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & he rwkv_type_to_string[header.data_type] ); - if (header.dim_count == 2) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height)); + if (header.dim_count >= 2) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.size1)); + } + + if (header.dim_count >= 3) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.size2)); } return true; } static bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? sizeof(uint32_t) : 0))); + size_t sub; + + if (header.dim_count == 1) { + sub = sizeof(uint32_t) * 2; + } else if (header.dim_count == 2) { + sub = sizeof(uint32_t); + } else { + sub = 0; + } + + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - sub)); return true; } @@ -204,9 +228,13 @@ static bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std:: name.c_str() ); - tensor = header.dim_count == 1 - ? 
ggml_new_tensor_1d(ctx, ggml_type, header.width) - : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + if (header.dim_count == 1) { + tensor = ggml_new_tensor_1d(ctx, ggml_type, header.size0); + } else if (header.dim_count == 2) { + tensor = ggml_new_tensor_2d(ctx, ggml_type, header.size0, header.size1); + } else { + tensor = ggml_new_tensor_3d(ctx, ggml_type, header.size0, header.size1, header.size2); + } RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor != NULL, "Failed to allocate tensor"); diff --git a/rwkv_gpu_offload.inc b/rwkv_gpu_offload.inc index cc47c64..e0b54b1 100644 --- a/rwkv_gpu_offload.inc +++ b/rwkv_gpu_offload.inc @@ -40,6 +40,10 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) offload(layer.att_receptance); offload(layer.att_output); + if (layer.att_gate != NULL) { + offload(layer.att_gate); + } + offload(layer.ffn_key); offload(layer.ffn_value); offload(layer.ffn_receptance); diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 095e503..a8a8c6d 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -2,9 +2,12 @@ struct rwkv_layer_state { struct ggml_tensor * ffn_xx; struct ggml_tensor * att_xx; + // Used in RWKV v4. struct ggml_tensor * att_aa; struct ggml_tensor * att_bb; struct ggml_tensor * att_pp; + // Used in RWKV v5+. + struct ggml_tensor * att_heads; }; // The computation graph holds ggml context and the ggml cgraph. @@ -71,7 +74,7 @@ static void rwkv_carry_x( carry = x; } else { // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); + x = rwkv_layer_norm(ctx, x, weight, bias); // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); @@ -111,7 +114,7 @@ static void rwkv_att_rkv( ); // r = torch.sigmoid(rw @ xr) - r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + r = rwkv_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); // k = kw @ xk k = ggml_mul_mat(ctx, layer.att_key, xk); // v = vw @ xv @@ -175,6 +178,175 @@ static struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tens return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); } +static struct ggml_tensor * rwkv_att_v5( + struct ggml_context * ctx, + struct ggml_tensor * x, + struct rwkv_layer layer, + struct rwkv_layer_state & state, + const int64_t head_count, + const int64_t head_size, + const uint32_t arch_version_minor +) { + size_t n_embed = x->ne[0]; + size_t sequence_length = x->ne[1]; + + x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + + struct ggml_tensor * x_prev; + + if (sequence_length > 1) { + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); + x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0); + x_prev = ggml_set_1d_inplace( + ctx, + x_prev, + ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) + ); + } else { + x_prev = state.att_xx; + } + + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.att_time_mix_k), + ggml_mul( + ctx, + x_prev, + rwkv_1_minus_x(ctx, layer.att_time_mix_k) + ) + ); + + struct ggml_tensor * xv = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.att_time_mix_v), + ggml_mul( + ctx, + x_prev, + rwkv_1_minus_x(ctx, layer.att_time_mix_v) + ) + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.att_time_mix_r), + ggml_mul( + ctx, + x_prev, + rwkv_1_minus_x(ctx, 
layer.att_time_mix_r) + ) + ); + + struct ggml_tensor * xg = NULL; + + if (arch_version_minor >= 2) { + xg = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.att_time_mix_g), + ggml_mul( + ctx, + x_prev, + rwkv_1_minus_x(ctx, layer.att_time_mix_g) + ) + ); + } + + state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); + + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); + + struct ggml_tensor * g = NULL; + + if (arch_version_minor >= 2) { + g = ggml_silu_inplace( + ctx, + ggml_mul_mat(ctx, layer.att_gate, xg) + ); + } + + struct ggml_tensor * tf = layer.att_time_faaaa != NULL ? + layer.att_time_faaaa : + layer.att_time_first; + + struct ggml_tensor * a = rwkv_transpose_then_cont( + ctx, + ggml_mul_mat( + ctx, + k, + rwkv_transpose_then_cont(ctx, v) + ) + ); + + struct ggml_tensor * tf_a = ggml_mul_inplace( + ctx, + ggml_repeat(ctx, tf, a), + a + ); + + struct ggml_tensor * x_new = ggml_new_tensor_2d(ctx, x->type, n_embed, sequence_length); + + struct ggml_tensor * last_state = state.att_heads; + + for (size_t t = 0; t < sequence_length; t++) { + struct ggml_tensor * s = ggml_reshape_3d(ctx, last_state, head_size, head_size, head_count); + + struct ggml_tensor * tf_a_s = ggml_add_inplace( + ctx, + rwkv_get_from_dim_3(ctx, tf_a, t), + s + ); + + struct ggml_tensor * x_new_vector = ggml_mul_mat( + ctx, + rwkv_get_from_dim_3(ctx, r, t), + rwkv_transpose_then_cont(ctx, tf_a_s) + ); + + struct ggml_tensor * td_s = ggml_mul_inplace( + ctx, + ggml_repeat(ctx, layer.att_time_decay, s), + s + ); + + s = ggml_add_inplace(ctx, td_s, rwkv_get_from_dim_3(ctx, a, t)); + + last_state = s; + + x_new = ggml_set_1d_inplace( + ctx, + x_new, + rwkv_flatten(ctx, x_new_vector), + t * n_embed * sizeof(float) + ); + } + + state.att_heads = last_state; + + x = x_new; + + // ggml_group_norm considers groups in the third dimension. + x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length); + x = ggml_group_norm_inplace(ctx, x, head_count); + // Convert back to a regular vector. 
+ x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); + x = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + x, + layer.att_ln_x_weight + ), + layer.att_ln_x_bias + ); + + if (arch_version_minor >= 2) { + x = ggml_mul_inplace(ctx, x, g); + } + + return ggml_mul_mat(ctx, layer.att_output, x); +} + static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); @@ -195,7 +367,7 @@ static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tens ); // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + struct ggml_tensor * r = rwkv_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); // k = torch.square(torch.relu(kw @ xk)) struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); @@ -211,22 +383,43 @@ static void rwkv_create_input_and_output_views( struct ggml_tensor * input, struct ggml_tensor * output, const size_t n_layer, - const size_t n_embed + const size_t n_embed, + const uint32_t arch_version_major, + const int64_t head_count, + const int64_t head_size ) { + size_t sz_float = sizeof(float); + for (size_t i = 0; i < n_layer; i++) { struct rwkv_layer_state & input_state = inputs[i]; - input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - input_state.att_aa = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); - struct rwkv_layer_state & output_state = outputs[i]; - output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + + if (arch_version_major >= 5) { + size_t vectors_per_layer = 2 + head_size; + + size_t att_heads_size = head_size * head_size * head_count; + + input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * vectors_per_layer + 0) * sz_float); + input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * vectors_per_layer + 1) * sz_float); + input_state.att_heads = ggml_view_1d(ctx, input, att_heads_size, n_embed * (i * vectors_per_layer + 2) * sz_float); + + output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * vectors_per_layer + 0) * sz_float); + output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * vectors_per_layer + 1) * sz_float); + output_state.att_heads = ggml_view_1d(ctx, output, att_heads_size, n_embed * (i * vectors_per_layer + 2) * sz_float); + } else { + input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sz_float); + input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sz_float); + input_state.att_aa = ggml_view_1d(ctx, input, 
n_embed, n_embed * (i * 5 + 2) * sz_float);
+            input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sz_float);
+            input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sz_float);
+
+            output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sz_float);
+            output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sz_float);
+            output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sz_float);
+            output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sz_float);
+            output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * sz_float);
+        }
+    }
 }
 
@@ -246,8 +439,12 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu
     // Creates a 1-element tensor.
     graph.tokens = ggml_new_i32(ctx, 0);
 
-    struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer);
-    struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer);
+    size_t vectors_per_layer = model.arch_version_major >= 5 ?
+        2 + model.head_size :
+        5;
+
+    struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * vectors_per_layer * n_layer);
+    struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * vectors_per_layer * n_layer);
 
     // We collect parts of input state here. Each part is (n_embed) vector.
     std::unique_ptr<struct rwkv_layer_state[]> inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
@@ -257,7 +454,7 @@
     std::unique_ptr<struct rwkv_layer_state[]> outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts");
 
-    rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed);
+    rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed, model.arch_version_major, model.head_count, model.head_size);
 
     graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
 
@@ -267,19 +464,37 @@
     // x = self.layer_norm(x, self.w.blocks[0].ln0)
     x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);
 
-    for (size_t i = 0; i < model.header.n_layer; i++) {
+    for (size_t i = 0; i < n_layer; i++) {
         struct rwkv_layer & layer = model.layers[i];
         struct rwkv_layer_state state = inputs[i];
-        x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state));
+
+        x = model.arch_version_major >= 5 ? 
+ ggml_add_inplace(ctx, x, rwkv_att_v5( + ctx, + x, + layer, + state, + model.head_count, + model.head_size, + model.arch_version_minor + )) : + ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); struct rwkv_layer_state & output_state = outputs[i]; + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + + if (model.arch_version_major >= 5) { + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_heads, output_state.att_heads)); + } else { + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + } } graph.pre_logits_nodes = graph.cgraph->n_nodes; @@ -319,18 +534,11 @@ static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, str RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); - struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); - - size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + size_t required_context_size = ggml_total_size_for_tensor_data(graph.ggml_ctx) + // With the node limit set 80K, this overhead would be 28 MB. + rwkv_ggml_overhead() - + tensor_alignment - // For some reason, `ggml_allocr_alloc_graph` underestimates required memory amount. - // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. - // 40 MB seems to be enough for Raven 14B model when GGML_MAX_NODES is set to default value of 4096. - + size_t(40) * 1024 * 1024; + + tensor_alignment; - ggml_allocr_free(allocator); ggml_free(graph.ggml_ctx); // 2. Create the real ggml context. @@ -356,8 +564,12 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c graph.tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sequence_length); - struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + size_t vectors_per_layer = model.arch_version_major >= 5 ? + 2 + model.head_size : + 5; + + struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * vectors_per_layer * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * vectors_per_layer * n_layer); // We collect parts of input state here. Each part is (n_embed) vector. 
std::unique_ptr<struct rwkv_layer_state[]> inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
@@ -367,7 +579,7 @@
     std::unique_ptr<struct rwkv_layer_state[]> outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts");
 
-    rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed);
+    rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed, model.arch_version_major, model.head_count, model.head_size);
 
     graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
 
@@ -379,33 +591,54 @@
     for (size_t i = 0; i < model.header.n_layer; i++) {
         struct rwkv_layer & layer = model.layers[i];
 
-        struct rwkv_layer_state state = inputs[i];
-
-        struct ggml_tensor * x0 = x, * x_prev;
-        rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx);
-
-        struct ggml_tensor * r, * k, * v;
-        rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v);
-
-        ggml_build_forward_expand(graph.cgraph.get(), r);
+        struct rwkv_layer_state state = inputs[i];
 
-        for (uint32_t t = 0; t < sequence_length; t++) {
-            struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t);
-            struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t);
-            struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t);
-            struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp);
-            ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, wkv, xt));
+        if (model.arch_version_major >= 5) {
+            x = ggml_add_inplace(ctx, x, rwkv_att_v5(
+                ctx,
+                x,
+                layer,
+                state,
+                model.head_count,
+                model.head_size,
+                model.arch_version_minor
+            ));
+        } else {
+            struct ggml_tensor * x0 = x, * x_prev;
+            rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx);
+
+            struct ggml_tensor * r, * k, * v;
+            rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v);
+
+            ggml_build_forward_expand(graph.cgraph.get(), r);
+
+            for (size_t t = 0; t < sequence_length; t++) {
+                struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t);
+                struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t);
+                struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t);
+                struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp);
+                ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, wkv, xt));
+            }
+
+            x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev)));
         }
 
-        x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev)));
+        // TODO Can we skip ffn for all but the last token, the same way we skip unembedding? 
x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); struct rwkv_layer_state & output_state = outputs[i]; - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + + if (model.arch_version_major >= 5) { + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_heads, output_state.att_heads)); + } else { + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + } } graph.pre_logits_nodes = graph.cgraph->n_nodes; @@ -442,18 +675,11 @@ static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); - struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); - - size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + size_t required_context_size = ggml_total_size_for_tensor_data(graph.ggml_ctx) + // With the node limit set 80K, this overhead would be 28 MB. + rwkv_ggml_overhead() - + tensor_alignment - // For some reason, `ggml_allocr_alloc_graph` underestimates required memory amount. - // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. - // 40 MB per token seems to be enough for Raven 14B model. It works for sequence_length at least up to 71. - + sequence_length * 40 * 1024 * 1024; + + tensor_alignment; - ggml_allocr_free(allocator); ggml_free(graph.ggml_ctx); // 2. Create the real ggml context. diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index 3cf0392..fef0ea9 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -6,6 +6,7 @@ struct rwkv_layer { struct ggml_tensor * att_time_mix_k; struct ggml_tensor * att_time_mix_v; struct ggml_tensor * att_time_mix_r; + // Removed in RWKV v5.2; set to NULL for this and newer models. struct ggml_tensor * att_time_first; struct ggml_tensor * att_time_decay; struct ggml_tensor * att_key; @@ -13,6 +14,15 @@ struct rwkv_layer { struct ggml_tensor * att_receptance; struct ggml_tensor * att_output; + // Added in RWKV v5.1; set to NULL for earlier models (v4). + struct ggml_tensor * att_ln_x_weight; + struct ggml_tensor * att_ln_x_bias; + + // Added in RWKV v5.2; set to NULL for earlier models (v4, v5.1). + struct ggml_tensor * att_time_faaaa; + struct ggml_tensor * att_time_mix_g; + struct ggml_tensor * att_gate; + struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; @@ -32,6 +42,11 @@ struct rwkv_model { struct ggml_context * ggml_ctx; struct rwkv_file_header header; + uint32_t arch_version_major; + uint32_t arch_version_minor; + // Added in RWKV v5.1; set to 0 for earlier models (v4). 
+    int64_t head_count;
+    int64_t head_size;
 
     struct ggml_tensor * emb;
 
@@ -74,7 +89,7 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) {
     RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias));
 
     uint32_t n_layer = model.header.n_layer;
-    std::unique_ptr<struct rwkv_layer[]> layers(new(std::nothrow) struct rwkv_layer[n_layer]);
+    std::unique_ptr<struct rwkv_layer[]> layers(new(std::nothrow) struct rwkv_layer[n_layer]());
     RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers");
 
     model.layers = std::move(layers);
 
@@ -89,13 +104,29 @@
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r));
-        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first));
+
+        if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) {
+            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa));
+        } else {
+            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first));
+        }
+
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output));
 
+        if (model.arch_version_major >= 5) {
+            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight));
+            RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias));
+
+            if (model.arch_version_minor >= 2) {
+                RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g));
+                RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate));
+            }
+        }
+
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight));
         RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias));
 
@@ -142,13 +173,34 @@
         parameters[std::move(name)] = tensor;
     }
 
+    model.arch_version_major = 4;
+    model.arch_version_minor = 0;
+
+    if (parameters.find("blocks.0.att.ln_x.weight") != parameters.end()) {
+        model.arch_version_major = 5;
+
+        if (parameters.find("blocks.0.att.gate.weight") != parameters.end()) {
+            model.arch_version_minor = 2;
+        } else {
+            model.arch_version_minor = 1;
+        }
+    }
+
     std::unordered_map<std::string, struct ggml_tensor *> & parameters_ref = parameters;
-    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) {
-        struct ggml_tensor * tensor = parameters_ref[key];
-        RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key);
-        dest = tensor;
-        return true;
-    }));
+    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(
+        model,
+        [&](const 
char * key, struct ggml_tensor *& dest) { + struct ggml_tensor * tensor = parameters_ref[key]; + RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); + dest = tensor; + return true; + } + )); + + if (model.arch_version_major >= 5) { + model.head_count = model.layers[0].att_time_decay->ne[2]; + model.head_size = model.layers[0].ln1_weight->ne[0] / model.head_count; + } // Verify order of dimensions. struct ggml_tensor * emb = model.emb; diff --git a/rwkv_operators.inc b/rwkv_operators.inc index c24c91d..0862035 100644 --- a/rwkv_operators.inc +++ b/rwkv_operators.inc @@ -96,8 +96,8 @@ struct ggml_tensor * rwkv_1_minus_x(struct ggml_context * ctx, struct ggml_tenso } // Element-wise sigmoid(x) -struct ggml_tensor * rwkv_sigmoid(struct ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_sigmoid_impl, 1, NULL); +struct ggml_tensor * rwkv_sigmoid_inplace(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1_inplace(ctx, x, rwkv_sigmoid_impl, 1, NULL); } // Element-wise max(x, y) @@ -110,3 +110,26 @@ struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tens // Looks like ggml_norm does the first part, we only need to apply weight & bias. return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias); } + +static struct ggml_tensor * rwkv_transpose_then_cont(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_cont(ctx, ggml_transpose(ctx, x)); +} + +static struct ggml_tensor * rwkv_get_from_dim_3(struct ggml_context * ctx, struct ggml_tensor * x, int64_t index) { + return ggml_view_4d( + ctx, + x, + x->ne[0], + x->ne[1], + x->ne[2], + 1, + x->nb[1], + x->nb[2], + x->nb[3], + index * (x->ne[0] * x->ne[1] * x->ne[2]) * sizeof(float) + ); +} + +static struct ggml_tensor * rwkv_flatten(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_view_1d(ctx, x, ggml_nelements(x), 0); +} diff --git a/rwkv_quantize.inc b/rwkv_quantize.inc index 93ab098..b3ddfc9 100644 --- a/rwkv_quantize.inc +++ b/rwkv_quantize.inc @@ -66,14 +66,14 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const max_out_size = in_size; } - size_t f32_size = rwkv_tensor_nbytes(GGML_TYPE_F32, header.width, header.height); + size_t f32_size = rwkv_tensor_nbytes(GGML_TYPE_F32, header.size0, header.size1, header.size2); if (f32_size > max_in_size) { max_in_size = f32_size; } } - size_t out_size = rwkv_tensor_nbytes(out_type, header.width, header.height); + size_t out_size = rwkv_tensor_nbytes(out_type, header.size0, header.size1, header.size2); if (out_size > max_out_size) { max_out_size = out_size; @@ -105,7 +105,15 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file.file, header.key_length, name), "Failed to read tensor name"); const char * name_str = name.c_str(); - RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); + RWKV_MSG( + "%*s - [%5" PRId32 ", %5" PRId32 ", %5" PRId32 "], type = %6s ", + (int) max_key_length, + name_str, + header.size0, + header.size1, + header.size2, + rwkv_type_to_string[header.data_type] + ); data = header.data_type == TYPE_FP16 ? 
out_buf : in_buf; size_t orig_size = header.size(), new_size = orig_size; @@ -114,6 +122,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const // Quantize only 2D tensors, except embedding and head matrices. // Embedding and head take not too much space, especially in bigger models; // but they significantly increase perplexity when quantized. + // In RWKV v5, time_decay and time_first/time_faaaa are 3D tensors, so they are not quantized. if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) && header.dim_count == 2 && name != "emb.weight" && @@ -121,7 +130,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const ) { RWKV_MSG("quantizing... "); - size_t nelements = (size_t) header.width * (size_t) header.height; + size_t nelements = (size_t) header.size0 * (size_t) header.size1 * (size_t) header.size2; if (header.data_type == TYPE_FP16) { ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); diff --git a/rwkv_utilities.inc b/rwkv_utilities.inc index 9b10b22..44be324 100644 --- a/rwkv_utilities.inc +++ b/rwkv_utilities.inc @@ -1,11 +1,11 @@ -static size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t width, const int64_t height) { - return (ggml_type_size(type) * width * height) / ggml_blck_size(type); +static size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t size0, const int64_t size1, const int64_t size2) { + return (ggml_type_size(type) * size0 * size1 * size2) / ggml_blck_size(type); } // For some reason, ggml_nbytes calculates the size in a way // incompatible with rwkv.cpp; we need our own function for that. static size_t rwkv_tensor_nbytes(const struct ggml_tensor * tensor) { - return rwkv_tensor_nbytes(tensor->type, tensor->ne[0], tensor->ne[1]); + return rwkv_tensor_nbytes(tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2]); } // Minimum amount of memory required for a ggml context, not counting the tensor data. 
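As a quick check of the updated size formula, here is a standalone sketch of `rwkv_tensor_nbytes()`. The block parameters (elements per block, bytes per block) are assumptions based on common ggml type layouts, not values read from this codebase.

```python
# (block_size, type_size) pairs; the quantized entry is an assumption for illustration.
BLOCK_PARAMS = {
    'FP32': (1, 4),
    'FP16': (1, 2),
    'Q5_1': (32, 24),  # assumed: 32 elements packed into 24 bytes
}

def tensor_nbytes(type_name: str, size0: int, size1: int = 1, size2: int = 1) -> int:
    # Mirrors rwkv_tensor_nbytes() above: (type_size * element_count) / block_size.
    block_size, type_size = BLOCK_PARAMS[type_name]
    return (type_size * size0 * size1 * size2) // block_size

assert tensor_nbytes('FP16', 2560, 2560) == 13_107_200
assert tensor_nbytes('Q5_1', 2560, 2560) == 4_915_200
```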
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index be1b94c..60c783a 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -21,11 +21,23 @@ function(rwkv_add_test source)
     endif()
 endfunction()
 
-file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
-file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
-file(COPY tiny-rwkv-660K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
-file(COPY tiny-rwkv-660K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
-file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-4v0-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-4v0-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-4v0-660K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-4v0-660K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY expected-logits-4v0-660K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+
+file(COPY tiny-rwkv-5v1-730K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v1-730K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v1-730K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v1-730K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY expected-logits-5v1-730K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+
+file(COPY tiny-rwkv-5v2-730K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v2-730K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v2-730K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-5v2-730K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY expected-logits-5v2-730K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 
 rwkv_add_test(test_ggml_basics.c)
 rwkv_add_test(test_quantized_matmul_on_gpu.c)
diff --git a/tests/assertions.inc b/tests/assertions.inc
index df5ba64..25127ae 100644
--- a/tests/assertions.inc
+++ b/tests/assertions.inc
@@ -3,12 +3,19 @@
 
 #include <stdlib.h>
 
+bool late_abort = false;
+bool must_abort = false;
+
 #define ASSERT(x, ...) {\
     if (!(x)) {\
         fprintf(stderr, "*** Assertion failed ***\n");\
         fprintf(stderr, __VA_ARGS__);\
         fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\
-        abort();\
+        if (late_abort) {\
+            must_abort = true;\
+        } else {\
+            abort();\
+        }\
     }\
 }
diff --git a/tests/expected_logits.bin b/tests/expected-logits-4v0-660K.bin
similarity index 100%
rename from tests/expected_logits.bin
rename to tests/expected-logits-4v0-660K.bin
diff --git a/tests/expected-logits-5v1-730K.bin b/tests/expected-logits-5v1-730K.bin
new file mode 100644
index 0000000..6637dee
Binary files /dev/null and b/tests/expected-logits-5v1-730K.bin differ
diff --git a/tests/expected-logits-5v2-730K.bin b/tests/expected-logits-5v2-730K.bin
new file mode 100644
index 0000000..94ca07e
Binary files /dev/null and b/tests/expected-logits-5v2-730K.bin differ
diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc
index 8907a75..9269d36 100644
--- a/tests/logit_difference_validator.inc
+++ b/tests/logit_difference_validator.inc
@@ -13,18 +13,23 @@
 // Also test multithreading.
 #define N_THREADS 2
 
-void load_expected_logits(float * expected_logits) {
-    FILE * file = fopen("expected_logits.bin", "rb");
-    ASSERT(file != NULL, "Failed to open expected_logits.bin");
+void load_expected_logits(float * expected_logits, const char * version) {
+    char file_name[128];
+    sprintf(file_name, "expected-logits-%s.bin", version);
+    FILE * file = fopen(file_name, "rb");
+    ASSERT(file != NULL, "Failed to open %s", file_name);
     size_t elements_read = fread(expected_logits, sizeof(float), N_VOCAB, file);
-    ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
+    ASSERT(elements_read == N_VOCAB, "Failed to read %s, read %zd elements", file_name, elements_read);
     fclose(file);
 }
 
-void test_model(const char * model_path, const float * expected_logits, const float max_diff) {
-    fprintf(stderr, "Testing %s\n", model_path);
+void test_model(const char * version, const char * format, const float * expected_logits, const float max_diff) {
+    char file_name[128];
+    sprintf(file_name, "tiny-rwkv-%s-%s.bin", version, format);
 
-    struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS);
+    fprintf(stderr, "Testing %s\n", file_name);
+
+    struct rwkv_context * model = rwkv_init_from_file(file_name, N_THREADS);
 
     enum rwkv_error_flags error = rwkv_get_last_error(NULL);
     ASSERT(error == 0, "Unexpected error %d", error);
@@ -60,10 +65,9 @@
         diff_sum += logits[i] - expected_logits[i];
     }
 
-    fprintf(stderr, "Serial difference sum: %f\n", diff_sum);
+    fprintf(stderr, "Serial difference sum: %f, expected %f\n", diff_sum, max_diff);
 
-    // When something breaks, difference would be way more than 10
-    ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big serial difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);
+    ASSERT(fabsf(diff_sum) <= fabsf(max_diff) * 1.05F, "Too big serial difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);
 
     // ---
 
@@ -76,10 +80,9 @@
         diff_sum += logits[i] - expected_logits[i];
     }
 
-    fprintf(stderr, "Sequence difference sum: %f\n", diff_sum);
+    fprintf(stderr, "Sequence difference sum: %f, expected %f\n", diff_sum, max_diff);
 
-    // When something breaks, difference would be way more than 10
-    ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);
+    ASSERT(fabsf(diff_sum) <= fabsf(max_diff) * 1.05F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);
 
     // ---
 
diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c
index 9087fca..40d0495 100644
--- a/tests/test_context_cloning.c
+++ b/tests/test_context_cloning.c
@@ -8,7 +8,7 @@
 #include "assertions.inc"
 
 int main(void) {
-    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-5v2-730K-FP32.bin", 2);
 
     ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL));
 
diff --git a/tests/test_eval_sequence_in_chunks.c b/tests/test_eval_sequence_in_chunks.c
index 7c3b10a..804c244 100644
--- a/tests/test_eval_sequence_in_chunks.c
+++ b/tests/test_eval_sequence_in_chunks.c
@@ -10,7 +10,7 @@
 void test_on_prompt(const char * prompt, const size_t prompt_length) {
     fprintf(stderr, "Calculating expected state and logits for prompt of size %zd\n", prompt_length);
 
-    struct rwkv_context * ctx = 
rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-5v2-730K-FP32.bin", 2); ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); diff --git a/tests/test_logit_calculation_skipping.c b/tests/test_logit_calculation_skipping.c index 2765a19..c9e9774 100644 --- a/tests/test_logit_calculation_skipping.c +++ b/tests/test_logit_calculation_skipping.c @@ -14,7 +14,7 @@ const char prompt[TOKEN_COUNT + 1] = "hello world"; void test_serial_mode(void) { fprintf(stderr, "Testing serial mode\n"); - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-5v2-730K-FP32.bin", 2); ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); @@ -54,7 +54,7 @@ void test_serial_mode(void) { void test_sequential_mode(void) { fprintf(stderr, "Testing sequential mode\n"); - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-5v2-730K-FP32.bin", 2); ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); diff --git a/tests/test_quantization_format_compatibility.c b/tests/test_quantization_format_compatibility.c index e652e6c..4edb370 100644 --- a/tests/test_quantization_format_compatibility.c +++ b/tests/test_quantization_format_compatibility.c @@ -6,16 +6,39 @@ #include "logit_difference_validator.inc" +#define VERSION_COUNT 3 + int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); - float * expected_logits = calloc(N_VOCAB, sizeof(float)); - load_expected_logits(expected_logits); + const char * versions[VERSION_COUNT] = { + "4v0-660K", + "5v1-730K", + "5v2-730K" + }; + + // See the explanation of huge expected differences for v5 models in test_tiny_rwkv.c + const float differences[VERSION_COUNT * 2] = { + // 4v0 + -000.170404F, + +000.278034F, + // 5v1 + -163.439407F, + -018.017435F, + // 5v2 + +025.273308F, + +048.068733F + }; + + for (int i = 0; i < VERSION_COUNT; i++) { + float * expected_logits = calloc(N_VOCAB, sizeof(float)); + load_expected_logits(expected_logits, versions[i]); - test_model("tiny-rwkv-660K-Q5_0.bin", expected_logits, -0.170404F); - test_model("tiny-rwkv-660K-Q5_1.bin", expected_logits, +0.278034F); + test_model(versions[i], "Q5_0", expected_logits, differences[i * 2 + 0]); + test_model(versions[i], "Q5_1", expected_logits, differences[i * 2 + 1]); - free(expected_logits); + free(expected_logits); + } return 0; } diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index b3a45cc..291142a 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -6,69 +6,143 @@ #include "logit_difference_validator.inc" +#define VERSION_COUNT 3 +#define FORMAT_COUNT 7 + int main(void) { + late_abort = true; + fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); // Silences the overly verbose output during quantization. rwkv_set_print_errors(NULL, false); - float * expected_logits = calloc(N_VOCAB, sizeof(float)); - load_expected_logits(expected_logits); + const char * versions[VERSION_COUNT] = { + "4v0-660K", + "5v1-730K", + "5v2-730K" + }; + + const char * formats[FORMAT_COUNT] = { + "FP32", + "FP16", + "Q4_0", + "Q4_1", + "Q5_0", + "Q5_1", + "Q8_0" + }; - // Somehow when using cuBLAS the result of Q4_1 is different from CPU only. 
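Note how the validator's pass criterion changed above: from an absolute margin, `fabsf(max_diff) + 0.01F`, to a relative one, `fabsf(max_diff) * 1.05F`, which scales with the much larger expected sums of the v5 models. A toy comparison of the two checks is sketched below; the expected value is a real entry from the tables in this PR, while the measured value is made up for illustration:

```c
// Toy comparison of the old absolute-slack check and the new relative-slack
// check from logit_difference_validator.inc.
#include <math.h>
#include <stdbool.h>
#include <stdio.h>

static bool old_check(float diff_sum, float max_diff) {
    return fabsf(diff_sum) <= fabsf(max_diff) + 0.01F;
}

static bool new_check(float diff_sum, float max_diff) {
    return fabsf(diff_sum) <= fabsf(max_diff) * 1.05F;
}

int main(void) {
    // Expected difference for the 5v1 Q5_0 model is large; a run that is only
    // ~2% off would fail the old fixed-margin check but pass the new one.
    float expected = -163.439407F;
    float measured = -166.5F; // hypothetical measurement, not from the PR

    printf("old check: %s\n", old_check(measured, expected) ? "pass" : "FAIL");
    printf("new check: %s\n", new_check(measured, expected) ? "pass" : "FAIL");

    return 0;
}
```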
diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c
index b3a45cc..291142a 100644
--- a/tests/test_tiny_rwkv.c
+++ b/tests/test_tiny_rwkv.c
@@ -6,69 +6,143 @@
 
 #include "logit_difference_validator.inc"
 
+#define VERSION_COUNT 3
+#define FORMAT_COUNT 7
+
 int main(void) {
+    late_abort = true;
+
     fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string());
 
     // Silences the overly verbose output during quantization.
     rwkv_set_print_errors(NULL, false);
 
-    float * expected_logits = calloc(N_VOCAB, sizeof(float));
-    load_expected_logits(expected_logits);
+    const char * versions[VERSION_COUNT] = {
+        "4v0-660K",
+        "5v1-730K",
+        "5v2-730K"
+    };
+
+    const char * formats[FORMAT_COUNT] = {
+        "FP32",
+        "FP16",
+        "Q4_0",
+        "Q4_1",
+        "Q5_0",
+        "Q5_1",
+        "Q8_0"
+    };
 
-    // Somehow when using cuBLAS the result of Q4_1 is different from CPU only.
-    float expected_difference_sum[14] = {
-        +0.000000F, // FP32
+    const float expected_difference_sum_full[VERSION_COUNT * 2] = {
+        // 4v0
+        +0.001000F, // FP32
         -0.005320F, // FP16
+        // 5v1
+        +0.001000F, // FP32
+        -0.289921F, // FP16
+        // 5v2
+        +0.001000F, // FP32
+        +0.206919F  // FP16
+    };
-
-        -0.160030F, // Q4_0
-#if defined(GGML_USE_CUBLAS)
-        -0.547409F, // Q4_1
-#else
-        -0.370606F, // Q4_1
-#endif
-        -0.170404F, // Q5_0
-        +0.278034F, // Q5_1
-        +0.071216F, // Q8_0
-
-        +0.154614F, // Q4_0
-#if defined(GGML_USE_CUBLAS)
-        -0.539827F, // Q4_1
-#else
-        -0.372169F, // Q4_1
-#endif
-        -0.170043F, // Q5_0
-        +0.294953F, // Q5_1
-        +0.065571F, // Q8_0
+
+    // *** Why the hell is the expected logit difference sum for v4 models < 1, while for v5 models it can be as high as 160? ***
+    //
+    // Due to a mistake in the Tiny RWKV v4 training code, all FFN layers were zeroed-out during training.
+    // The output of v4 models is basically incoherent -- there are "words" and spaces between them, but these words consist of random characters.
+    // I guess that since there is not much "intelligence" to lose, the logit difference sum after quantization stays pretty low.
+    //
+    // In contrast, Tiny RWKV v5 models were trained correctly, and their FFN layers have an OK-looking weight distribution.
+    // v5 models produce mostly real English words and, sometimes, whole word combinations that make sense. The structure of the output is also correct.
+    // Since there are real numbers in the FFN layers now, I expect quantization to have a far larger effect on the output than for v4.
+    //
+    // For reference, RWKV v4 169M gives a logit difference sum of -2395.1636 after quantizing FP32 to Q5_1. So, such orders of magnitude are not unheard of.
+    //
+    // In any case, the logit difference sum works well here for verifying that inference was not broken by a change.
+
+    const float expected_difference_sum_quantized_FP32[VERSION_COUNT * (FORMAT_COUNT - 2)] = {
+        // 4v0
+        -000.160030F, // Q4_0
+        -000.547409F, // Q4_1
+        -000.170404F, // Q5_0
+        +000.278034F, // Q5_1
+        +000.076282F, // Q8_0
+        // 5v1
+        +117.932594F, // Q4_0
+        -026.712271F, // Q4_1
+        -163.439407F, // Q5_0
+        -018.017435F, // Q5_1
+        +000.585238F, // Q8_0
+        // 5v2
+        +035.271305F, // Q4_0
+        +061.719509F, // Q4_1
+        +025.273308F, // Q5_0
+        +048.068733F, // Q5_1
+        -009.441034F  // Q8_0
+    };
+
+    const float expected_difference_sum_quantized_FP16[VERSION_COUNT * (FORMAT_COUNT - 2)] = {
+        // 4v0
+        +000.154614F, // Q4_0
+        -000.539827F, // Q4_1
+        -000.170043F, // Q5_0
+        +000.294953F, // Q5_1
+        +000.070944F, // Q8_0
+        // 5v1
+        +119.471931F, // Q4_0
+        -028.245888F, // Q4_1
+        -159.870956F, // Q5_0
+        -039.708530F, // Q5_1
+        -000.962695F, // Q8_0
+        // 5v2
+        +034.135971F, // Q4_0
+        +059.066830F, // Q4_1
+        +021.588751F, // Q5_0
+        +029.726818F, // Q5_1
+        -007.242277F  // Q8_0
     };
 
-    test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
-    test_model("tiny-rwkv-660K-FP16.bin", expected_logits, expected_difference_sum[1]);
-
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", "Q4_0");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_0.bin", "Q5_0");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q8_0.bin", "Q8_0");
-
-    test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
-    test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
-    test_model("tiny-rwkv-660K-FP32-Q5_0.bin", expected_logits, expected_difference_sum[4]);
-    test_model("tiny-rwkv-660K-FP32-Q5_1.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP32-Q8_0.bin", expected_logits, expected_difference_sum[6]);
-
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", "Q4_0");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", "Q4_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_0.bin", "Q5_0");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q5_1.bin", "Q5_1");
-    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q8_0.bin", "Q8_0");
-
-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
-    test_model("tiny-rwkv-660K-FP16-Q5_0.bin", expected_logits, expected_difference_sum[9]);
-    test_model("tiny-rwkv-660K-FP16-Q5_1.bin", expected_logits, expected_difference_sum[10]);
-    test_model("tiny-rwkv-660K-FP16-Q8_0.bin", expected_logits, expected_difference_sum[11]);
-
-    free(expected_logits);
+    for (int i_version = 0; i_version < VERSION_COUNT; i_version++) {
+        float * expected_logits = calloc(N_VOCAB, sizeof(float));
+        load_expected_logits(expected_logits, versions[i_version]);
+
+        for (int i_format = 0; i_format < FORMAT_COUNT; i_format++) {
+            if (i_format < 2) {
+                test_model(versions[i_version], formats[i_format], expected_logits, expected_difference_sum_full[i_version * 2 + i_format]);
+
+                continue;
+            }
+
+            char source_file_name[128];
+            char dest_format[128];
+            char dest_file_name[128];
+
+            // ---
+
+            sprintf(source_file_name, "tiny-rwkv-%s-FP32.bin", versions[i_version]);
+            sprintf(dest_format, "FP32-to-%s", formats[i_format]);
+            sprintf(dest_file_name, "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);
+
+            rwkv_quantize_model_file(source_file_name, dest_file_name, formats[i_format]);
+
+            test_model(versions[i_version], dest_format, expected_logits, expected_difference_sum_quantized_FP32[i_version * (FORMAT_COUNT - 2) + (i_format - 2)]);
+
+            // ---
+
+            sprintf(source_file_name, "tiny-rwkv-%s-FP16.bin", versions[i_version]);
+            sprintf(dest_format, "FP16-to-%s", formats[i_format]);
+            sprintf(dest_file_name, "tiny-rwkv-%s-%s.bin", versions[i_version], dest_format);
+
+            rwkv_quantize_model_file(source_file_name, dest_file_name, formats[i_format]);
+
+            test_model(versions[i_version], dest_format, expected_logits, expected_difference_sum_quantized_FP16[i_version * (FORMAT_COUNT - 2) + (i_format - 2)]);
+        }
+
+        free(expected_logits);
+    }
+
+    if (must_abort) {
+        abort();
+    }
 
     return 0;
 }
diff --git a/tests/tiny-rwkv-660K-FP16.bin b/tests/tiny-rwkv-4v0-660K-FP16.bin
similarity index 100%
rename from tests/tiny-rwkv-660K-FP16.bin
rename to tests/tiny-rwkv-4v0-660K-FP16.bin
diff --git a/tests/tiny-rwkv-660K-FP32.bin b/tests/tiny-rwkv-4v0-660K-FP32.bin
similarity index 100%
rename from tests/tiny-rwkv-660K-FP32.bin
rename to tests/tiny-rwkv-4v0-660K-FP32.bin
diff --git a/tests/tiny-rwkv-660K-Q5_0.bin b/tests/tiny-rwkv-4v0-660K-Q5_0.bin
similarity index 100%
rename from tests/tiny-rwkv-660K-Q5_0.bin
rename to tests/tiny-rwkv-4v0-660K-Q5_0.bin
diff --git a/tests/tiny-rwkv-660K-Q5_1.bin b/tests/tiny-rwkv-4v0-660K-Q5_1.bin
similarity index 100%
rename from tests/tiny-rwkv-660K-Q5_1.bin
rename to tests/tiny-rwkv-4v0-660K-Q5_1.bin
diff --git a/tests/tiny-rwkv-5v1-730K-FP16.bin b/tests/tiny-rwkv-5v1-730K-FP16.bin
new file mode 100644
index 0000000..e8edbf8
Binary files /dev/null and b/tests/tiny-rwkv-5v1-730K-FP16.bin differ
diff --git a/tests/tiny-rwkv-5v1-730K-FP32.bin b/tests/tiny-rwkv-5v1-730K-FP32.bin
new file mode 100644
index 0000000..12377a9
Binary files /dev/null and b/tests/tiny-rwkv-5v1-730K-FP32.bin differ
diff --git a/tests/tiny-rwkv-5v1-730K-Q5_0.bin b/tests/tiny-rwkv-5v1-730K-Q5_0.bin
new file mode 100644
index 0000000..599c56d
Binary files /dev/null and b/tests/tiny-rwkv-5v1-730K-Q5_0.bin differ
diff --git a/tests/tiny-rwkv-5v1-730K-Q5_1.bin b/tests/tiny-rwkv-5v1-730K-Q5_1.bin
new file mode 100644
index 0000000..0c7a71a
Binary files /dev/null and b/tests/tiny-rwkv-5v1-730K-Q5_1.bin differ
diff --git a/tests/tiny-rwkv-5v2-730K-FP16.bin b/tests/tiny-rwkv-5v2-730K-FP16.bin
new file mode 100644
index 0000000..4d76c6f
Binary files /dev/null and b/tests/tiny-rwkv-5v2-730K-FP16.bin differ
diff --git a/tests/tiny-rwkv-5v2-730K-FP32.bin b/tests/tiny-rwkv-5v2-730K-FP32.bin
new file mode 100644
index 0000000..5e0fdc3
Binary files /dev/null and b/tests/tiny-rwkv-5v2-730K-FP32.bin differ
diff --git a/tests/tiny-rwkv-5v2-730K-Q5_0.bin b/tests/tiny-rwkv-5v2-730K-Q5_0.bin
new file mode 100644
index 0000000..d044361
Binary files /dev/null and b/tests/tiny-rwkv-5v2-730K-Q5_0.bin differ
diff --git a/tests/tiny-rwkv-5v2-730K-Q5_1.bin b/tests/tiny-rwkv-5v2-730K-Q5_1.bin
new file mode 100644
index 0000000..4f11975
Binary files /dev/null and b/tests/tiny-rwkv-5v2-730K-Q5_1.bin differ
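The expected sums in test_tiny_rwkv.c are stored as flattened version-by-format tables, and the loop addresses them with `i_version * (FORMAT_COUNT - 2) + (i_format - 2)`, skipping the two non-quantized formats that live in `expected_difference_sum_full`. A small standalone sketch of that mapping follows, using the same constants and name pattern as the test; the printed output is illustrative only:

```c
// Sketch of the flattened-table indexing in test_tiny_rwkv.c: rows are model
// versions, columns are the quantized formats (formats[2..6], Q4_0..Q8_0).
#include <stdio.h>

#define VERSION_COUNT 3
#define FORMAT_COUNT 7

int main(void) {
    const char * versions[VERSION_COUNT] = { "4v0-660K", "5v1-730K", "5v2-730K" };
    const char * formats[FORMAT_COUNT] = { "FP32", "FP16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0" };

    // FP32 and FP16 are handled by expected_difference_sum_full, so each
    // version contributes FORMAT_COUNT - 2 entries to the quantized tables.
    for (int i_version = 0; i_version < VERSION_COUNT; i_version++) {
        for (int i_format = 2; i_format < FORMAT_COUNT; i_format++) {
            int index = i_version * (FORMAT_COUNT - 2) + (i_format - 2);

            // File name the test would quantize FP32 into and then validate.
            char dest_file_name[128];
            sprintf(dest_file_name, "tiny-rwkv-%s-FP32-to-%s.bin", versions[i_version], formats[i_format]);

            printf("index %2d -> %s\n", index, dest_file_name);
        }
    }

    return 0;
}
```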