diff --git a/README.md b/README.md index 281b7f0..80be95e 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,15 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported. -This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it. +This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it. [RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to a Transformer with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths. Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through the [merge_lora_into_ggml.py script](python%2Fmerge_lora_into_ggml.py). -### Quality and performance +⚠️ **The Python API was restructured on 2023-09-20**; you may need to change paths/package names in your code when updating `rwkv.cpp`. + +## Quality and performance If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](python%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you. @@ -26,7 +28,7 @@ The table below is for reference only. Measurements were made on 4C/8T x86 CPU with | `FP16` | **15.623** | 117 | 2.82 | | `FP32` | **15.623** | 198 | 5.64 | -#### With cuBLAS +### With cuBLAS Measurements were made on Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token is shown in ms. @@ -124,80 +126,64 @@ This option would require a little more manual work, but you can use it with any ```commandline # Windows -python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 +python python\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 # Linux / MacOS -python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16 +python python/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16 ``` **Optionally**, quantize the model into one of the quantized formats from the table above: ```commandline # Windows -python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1 +python python\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1 # Linux / MacOS -python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1 +python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1 ``` ### 4. Run the model -**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/). +#### Using the command line -**Note**: change the model path with the non-quantized model for the full weights model. +**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/). 
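If PyTorch and tokenizers are missing, something like `pip install torch tokenizers` should cover both; the repository's [requirements.txt](python%2Frequirements.txt) presumably pins known-good versions.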
To generate some text, run: ```commandline # Windows -python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin +python python\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin # Linux / MacOS -python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin +python python/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin ``` To chat with a bot, run: ```commandline # Windows -python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin +python python\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin # Linux / MacOS -python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin +python python/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin ``` Edit [generate_completions.py](python%2Fgenerate_completions.py) or [chat_with_bot.py](python%2Fchat_with_bot.py) to change prompts and sampling settings. ---- - -Example of using `rwkv.cpp` in your custom Python script: - -```python -import rwkv_cpp_model -import rwkv_cpp_shared_library +#### Using in your own code -# Change to model paths used above (quantized or full weights) -model_path = r'C:\rwkv.cpp-169M.bin' +The short and simple script [inference_example.py](python%2Finference_example.py) demonstrates the use of `rwkv.cpp` in Python. +To use `rwkv.cpp` in C/C++, include the header [rwkv.h](rwkv.h). -model = rwkv_cpp_model.RWKVModel( - rwkv_cpp_shared_library.load_rwkv_shared_library(), - model_path, - thread_count=4, #need to adjust when use cuBLAS - gpu_layers_count=5 #only enabled when use cuBLAS -) +To use `rwkv.cpp` in any other language, see the [Bindings](#Bindings) section below. If your language is missing, you can try to bind to the C API using the tooling provided by your language. -logits, state = None, None - -for token in [1, 2, 3]: - logits, state = model.eval(token, state) - - print(f'Output logits: {logits}') +## Bindings -# Don't forget to free the memory after you've done working with the model -model.free() +These projects wrap `rwkv.cpp` for easier use in other languages/frameworks. -``` +* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv) +* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node) ## Compatibility @@ -214,13 +200,6 @@ For reference only, here is a list of latest versions of `rwkv.cpp` that have su See also [docs/FILE_FORMAT.md](docs/FILE_FORMAT.md) for version numbers of `rwkv.cpp` model files and their changelog. -## Bindings - -These projects wrap `rwkv.cpp` for easier use in other languages/frameworks. - -* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv) -* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node) - ## Contributing Please follow the code style described in [docs/CODE_STYLE.md](docs/CODE_STYLE.md). diff --git a/docs/CODE_STYLE.md b/docs/CODE_STYLE.md index 9fb0d8b..ccd670c 100644 --- a/docs/CODE_STYLE.md +++ b/docs/CODE_STYLE.md @@ -27,6 +27,10 @@ Overall, keep code in similar style as it was before. - Place braces on the same line as the statement. - Always add braces to `if`, `for`, `while`, `do` and other similar statements. - Prefix top-level function and struct names with `rwkv_`. +- Mark all functions that are not members of the public API as `static`. + - The public API is the set of functions defined in `rwkv.h`. +- Mark all immutable function arguments as `const`. + - This is not required for local variables. 
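For illustration, a small hypothetical helper written to these rules (it is not part of the codebase) might look like this:

```c
#include <stdint.h>

// Illustrative only: not public API (not declared in rwkv.h), hence `static`;
// the immutable argument is `const`; the top-level name is prefixed with
// `rwkv_`; braces sit on the same line and are used even for one-line bodies.
static uint32_t rwkv_count_set_bits(const uint32_t value) {
    uint32_t count = 0;

    for (uint32_t v = value; v != 0; v >>= 1) {
        if (v & 1) {
            count++;
        }
    }

    return count;
}
```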
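Likewise, for the README's "Using in your own code" note above, here is a minimal sketch of driving the C API declared in [rwkv.h](rwkv.h). The function names all appear elsewhere in this diff, but the exact signatures are assumptions; check the header before relying on them.

```c
#include <stdio.h>
#include <stdlib.h>

#include "rwkv.h"

int main(const int argc, const char * argv[]) {
    if (argc != 2) {
        fprintf(stderr, "Usage: %s MODEL_FILE\n", argv[0]);
        return EXIT_FAILURE;
    }

    // Load the model, using 4 threads for inference.
    struct rwkv_context * ctx = rwkv_init_from_file(argv[1], 4);

    if (ctx == NULL) {
        fprintf(stderr, "Error: 0x%.8X\n", rwkv_get_last_error(NULL));
        return EXIT_FAILURE;
    }

    float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
    float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));

    // Passing NULL as state_in starts from a blank state; afterwards the
    // state buffer is fed back in for each subsequent token (tokens 0 and 1
    // here are placeholders for real tokenizer output).
    rwkv_eval(ctx, 0, NULL, state, logits);
    rwkv_eval(ctx, 1, state, state, logits);

    printf("logits[0] = %f\n", logits[0]);

    free(state);
    free(logits);
    rwkv_free(ctx);

    return EXIT_SUCCESS;
}
```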
## Python diff --git a/extras/CMakeLists.txt b/extras/CMakeLists.txt index d4e7cd6..2a8ab3b 100644 --- a/extras/CMakeLists.txt +++ b/extras/CMakeLists.txt @@ -9,7 +9,5 @@ function(rwkv_add_extra source) endif() endfunction() -file(GLOB extras *.c) -foreach (extra ${extras}) - rwkv_add_extra(${extra}) -endforeach() +rwkv_add_extra(cpu_info.c) +rwkv_add_extra(quantize.c) diff --git a/extras/cpu_info.c b/extras/cpu_info.c index d3e6a20..74e36e4 100644 --- a/extras/cpu_info.c +++ b/extras/cpu_info.c @@ -1,7 +1,9 @@ -#include "rwkv.h" - #include <stdio.h> -int main() { +#include <rwkv.h> + +int main(void) { printf("%s", rwkv_get_system_info_string()); + + return 0; } diff --git a/extras/quantize.c b/extras/quantize.c index 92e4c37..578e632 100644 --- a/extras/quantize.c +++ b/extras/quantize.c @@ -1,11 +1,11 @@ -#include "ggml.h" -#include "rwkv.h" - #include <stdio.h> #include <stdlib.h> #include <string.h> -#ifdef _WIN32 +#include <rwkv.h> +#include <ggml.h> + +#if defined(_WIN32) bool QueryPerformanceFrequency(uint64_t* lpFrequency); bool QueryPerformanceCounter(uint64_t* lpPerformanceCount); @@ -22,7 +22,7 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount); #define TIME_DIFF(freq, start, end) (double) ((end.tv_nsec - start.tv_nsec) / 1000000) / 1000 #endif -enum ggml_type type_from_string(const char* string) { +static enum ggml_type type_from_string(const char * string) { if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0; if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1; if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0; @@ -31,9 +31,10 @@ enum ggml_type type_from_string(const char* string) { return GGML_TYPE_COUNT; } -int main(int argc, char * argv[]) { +int main(const int argc, const char * argv[]) { if (argc != 4 || type_from_string(argv[3]) == GGML_TYPE_COUNT) { - fprintf(stderr, "Usage: %s INPUT OUTPUT FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]); + fprintf(stderr, "Usage: %s INPUT_FILE OUTPUT_FILE FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]); + return EXIT_FAILURE; } @@ -48,11 +49,15 @@ int main(int argc, char * argv[]) { double diff = TIME_DIFF(freq, start, end); + fprintf(stderr, "Took %.3f s\n", diff); + if (success) { - fprintf(stderr, "Succeeded in %.3fs\n", diff); + fprintf(stderr, "Success\n"); + return EXIT_SUCCESS; } else { - fprintf(stderr, "Error in %.3fs: 0x%.8X\n", diff, rwkv_get_last_error(NULL)); + fprintf(stderr, "Error: 0x%.8X\n", rwkv_get_last_error(NULL)); + return EXIT_FAILURE; } } diff --git a/rwkv/20B_tokenizer.json b/python/20B_tokenizer.json similarity index 100% rename from rwkv/20B_tokenizer.json rename to python/20B_tokenizer.json diff --git a/rwkv/chat_with_bot.py b/python/chat_with_bot.py similarity index 95% rename from rwkv/chat_with_bot.py rename to python/chat_with_bot.py index 78cbecd..af2b9af 100644 --- a/rwkv/chat_with_bot.py +++ b/python/chat_with_bot.py @@ -6,14 +6,13 @@ import argparse import pathlib import copy +import json +import time import torch import sampling -import rwkv_cpp_model -import rwkv_cpp_shared_library -from rwkv_tokenizer import get_tokenizer -import json +from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model +from tokenizer_util import get_tokenizer from typing import List, Dict, Optional -import time # ======================================== Script settings ======================================== @@ -98,7 +97,7 @@ def load_thread_state(_thread: str) -> None: # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end. 
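# For example, assuming END_OF_LINE_TOKEN = 187 and DOUBLE_END_OF_LINE_TOKEN = 535
# (the values this script uses), [53, 535] becomes [53, 187, 187],
# while [535, 53] is left unchanged, because only the trailing token is split.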
# See https://github.com/BlinkDL/ChatRWKV/pull/110/files -def split_last_end_of_line(tokens): +def split_last_end_of_line(tokens: List[int]) -> List[int]: if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN: tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN] @@ -106,7 +105,7 @@ def split_last_end_of_line(tokens): # ================================================================================================= -processing_start = time.time() +processing_start: float = time.time() prompt_tokens = tokenizer_encode(init_prompt) prompt_token_count = len(prompt_tokens) @@ -114,7 +113,7 @@ def split_last_end_of_line(tokens): process_tokens(split_last_end_of_line(prompt_tokens)) -processing_duration = time.time() - processing_start +processing_duration: float = time.time() - processing_start print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token') @@ -125,11 +124,11 @@ def split_last_end_of_line(tokens): while True: # Read user input - user_input = input(f'> {user}{separator} ') - msg = user_input.replace('\\n', '\n').strip() + user_input: str = input(f'> {user}{separator} ') + msg: str = user_input.replace('\\n', '\n').strip() - temperature = TEMPERATURE - top_p = TOP_P + temperature: float = TEMPERATURE + top_p: float = TOP_P if '-temp=' in msg: temperature = float(msg.split('-temp=')[1].split(' ')[0]) diff --git a/rwkv/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py similarity index 93% rename from rwkv/convert_pytorch_to_ggml.py rename to python/convert_pytorch_to_ggml.py index 958fbe1..ed8bce0 100644 --- a/rwkv/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -16,7 +16,7 @@ def parse_args(): return parser.parse_args() def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: - n_layer = 0 + n_layer: int = 0 while f'blocks.{n_layer}.ln1.weight' in state_dict: n_layer += 1 @@ -28,9 +28,9 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None: emb_weight: torch.Tensor = state_dict['emb.weight'] - n_layer = get_layer_count(state_dict) - n_vocab = emb_weight.shape[0] - n_embed = emb_weight.shape[1] + n_layer: int = get_layer_count(state_dict) + n_vocab: int = emb_weight.shape[0] + n_embed: int = emb_weight.shape[1] with open(dest_path, 'wb') as out_file: is_FP16: bool = data_type == 'FP16' or data_type == 'float16' @@ -48,7 +48,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t )) for k in state_dict.keys(): - tensor = state_dict[k].float() + tensor: torch.Tensor = state_dict[k].float() # Same processing as in "RWKV_in_150_lines.py" if '.time_' in k: diff --git a/rwkv/convert_pytorch_to_ggml.test.py b/python/convert_pytorch_to_ggml.test.py similarity index 88% rename from rwkv/convert_pytorch_to_ggml.test.py rename to python/convert_pytorch_to_ggml.test.py index 501a85e..2a75d15 100644 --- a/rwkv/convert_pytorch_to_ggml.test.py +++ b/python/convert_pytorch_to_ggml.test.py @@ -5,7 +5,7 @@ from typing import Dict def test() -> None: - test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp' + test_file_path: str = 'convert_pytorch_rwkv_to_ggml_test.tmp' try: state_dict: Dict[str, torch.Tensor] = { @@ -15,8 +15,8 @@ def test() -> None: convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32') - with open(test_file_path, 'rb') as input: - actual_bytes: bytes = input.read() + with 
open(test_file_path, 'rb') as test_file: + actual_bytes: bytes = test_file.read() expected_bytes: bytes = struct.pack( '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf', diff --git a/rwkv/generate_completions.py b/python/generate_completions.py similarity index 76% rename from rwkv/generate_completions.py rename to python/generate_completions.py index 4cf2a3d..89a2e02 100644 --- a/rwkv/generate_completions.py +++ b/python/generate_completions.py @@ -1,11 +1,12 @@ # Generates completions from RWKV model based on a prompt. +# Usage example: python generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin 20B import argparse import time import sampling -import rwkv_cpp_model -import rwkv_cpp_shared_library -from rwkv_tokenizer import get_tokenizer +from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model +from tokenizer_util import get_tokenizer +from typing import List # ======================================== Script settings ======================================== @@ -13,7 +14,7 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml). -Besides usual **FP32**, it supports **FP16** and **quantized INT4** inference on CPU. This project is **CPU only**.""" +Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported.""" # How many completions to generate. generation_count: int = 3 @@ -35,7 +36,7 @@ tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) -prompt_tokens = tokenizer_encode(prompt) +prompt_tokens: List[int] = tokenizer_encode(prompt) library = rwkv_cpp_shared_library.load_rwkv_shared_library() print(f'System info: {library.rwkv_get_system_info_string()}') @@ -43,7 +44,7 @@ print('Loading RWKV model') model = rwkv_cpp_model.RWKVModel(library, args.model_path) -prompt_token_count = len(prompt_tokens) +prompt_token_count: int = len(prompt_tokens) print(f'{prompt_token_count} tokens in prompt') init_logits, init_state = None, None @@ -54,16 +55,18 @@ for GENERATION in range(generation_count): print(f'\n--- Generation {GENERATION} ---\n') print(prompt, end='[') - start = time.time() + + start: float = time.time() logits, state = init_logits.clone(), init_state.clone() for i in range(tokens_per_generation): - token = sampling.sample_logits(logits, temperature, top_p) + token: int = sampling.sample_logits(logits, temperature, top_p) print(tokenizer_decode([token]), end='', flush=True) logits, state = model.eval(token, state, state, logits) - delay = time.time() - start + delay: float = time.time() - start + print(']\n\nTook %.3f sec, %d ms per token' % (delay, delay / tokens_per_generation * 1000)) diff --git a/python/inference_example.py b/python/inference_example.py new file mode 100644 index 0000000..116694c --- /dev/null +++ b/python/inference_example.py @@ -0,0 +1,46 @@ +# Generates some text with an RWKV model. +# Usage example: python inference_example.py C:\rwkv.cpp-169M-Q5_1.bin 20B + +import argparse +import sampling +from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model +from tokenizer_util import get_tokenizer +from typing import List + +# Parse received arguments. 
+parser = argparse.ArgumentParser(description='Generate some text with an RWKV model') +parser.add_argument('model_path', help='Path to RWKV model in ggml format') +parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B') +args = parser.parse_args() + +# Load the model. +library = rwkv_cpp_shared_library.load_rwkv_shared_library() +model = rwkv_cpp_model.RWKVModel(library, args.model_path) + +# Set up the tokenizer. +tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) + +# Prepare the prompt. +prompt: str = """Once upon a time,""" +prompt_tokens: List[int] = tokenizer_encode(prompt) + +# Process the prompt. +init_logits, init_state = None, None + +for token in prompt_tokens: + init_logits, init_state = model.eval(token, init_state, init_state, init_logits) + +logits, state = init_logits.clone(), init_state.clone() + +# Generate and print the completion. +print(prompt, end='') + +for i in range(32): + token: int = sampling.sample_logits(logits, temperature=0.8, top_p=0.5) + + print(tokenizer_decode([token]), end='', flush=True) + + logits, state = model.eval(token, state, state, logits) + +# Don't forget to free the memory after you are done working with the model! +model.free() diff --git a/rwkv/measure_pexplexity.py b/python/measure_pexplexity.py similarity index 95% rename from rwkv/measure_pexplexity.py rename to python/measure_pexplexity.py index 815875c..f483277 100644 --- a/rwkv/measure_pexplexity.py +++ b/python/measure_pexplexity.py @@ -6,9 +6,9 @@ import time import argparse import torch -import rwkv_cpp_model -import rwkv_cpp_shared_library -from rwkv_tokenizer import get_tokenizer +from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model +from tokenizer_util import get_tokenizer +from typing import List def parse_args(): parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file') @@ -26,7 +26,7 @@ def parse_args(): _, tokenizer_encode = get_tokenizer(args.tokenizer) -tokens = tokenizer_encode(text) +tokens: List[int] = tokenizer_encode(text) token_count: int = len(tokens) print(f'{token_count} tokens in the text') diff --git a/rwkv/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py similarity index 100% rename from rwkv/merge_lora_into_ggml.py rename to python/merge_lora_into_ggml.py diff --git a/rwkv/prompt/Chinese-Chat.json b/python/prompt/Chinese-Chat.json similarity index 100% rename from rwkv/prompt/Chinese-Chat.json rename to python/prompt/Chinese-Chat.json diff --git a/rwkv/prompt/Chinese-QA.json b/python/prompt/Chinese-QA.json similarity index 100% rename from rwkv/prompt/Chinese-QA.json rename to python/prompt/Chinese-QA.json diff --git a/rwkv/prompt/English-Chat.json b/python/prompt/English-Chat.json similarity index 100% rename from rwkv/prompt/English-Chat.json rename to python/prompt/English-Chat.json diff --git a/rwkv/prompt/English-QA.json b/python/prompt/English-QA.json similarity index 100% rename from rwkv/prompt/English-QA.json rename to python/prompt/English-QA.json diff --git a/rwkv/prompt/Japanese-Chat.json b/python/prompt/Japanese-Chat.json similarity index 100% rename from rwkv/prompt/Japanese-Chat.json rename to python/prompt/Japanese-Chat.json diff --git a/rwkv/prompt/Japanese-QA.json b/python/prompt/Japanese-QA.json similarity index 100% rename from rwkv/prompt/Japanese-QA.json rename to python/prompt/Japanese-QA.json diff --git a/rwkv/quantize.py b/python/quantize.py similarity index 95% 
rename from rwkv/quantize.py rename to python/quantize.py index fe45da6..4d269cf 100644 --- a/rwkv/quantize.py +++ b/python/quantize.py @@ -3,7 +3,7 @@ # Usage: python quantize.py C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1 import argparse -import rwkv_cpp_shared_library +from rwkv_cpp import rwkv_cpp_shared_library def parse_args(): format_names = rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES diff --git a/rwkv/requirements.txt b/python/requirements.txt similarity index 100% rename from rwkv/requirements.txt rename to python/requirements.txt diff --git a/python/rwkv_cpp/__init__.py b/python/rwkv_cpp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rwkv/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py similarity index 77% rename from rwkv/rwkv_cpp_model.py rename to python/rwkv_cpp/rwkv_cpp_model.py index e612dd1..ae3b0d5 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/python/rwkv_cpp/rwkv_cpp_model.py @@ -1,7 +1,13 @@ import os import torch import multiprocessing -import rwkv_cpp_shared_library + +# I'm sure this is not strictly correct, but let's keep this crutch for now. +try: + import rwkv_cpp_shared_library +except ModuleNotFoundError: + from . import rwkv_cpp_shared_library + from typing import Tuple, Optional, List class RWKVModel: @@ -16,7 +22,7 @@ def __init__( thread_count: int = max(1, multiprocessing.cpu_count() // 2), gpu_layer_count: int = 0, **kwargs - ): + ) -> None: """ Loads the model and prepares it for inference. In case of any error, this method will throw an exception. @@ -40,31 +46,30 @@ def __init__( assert thread_count > 0, 'Thread count must be > 0' assert gpu_layer_count >= 0, 'GPU layer count must be >= 0' - self._library = shared_library + self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library - self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) + self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count) if gpu_layer_count > 0: self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layer_count) - self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx) - self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx) + self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx) + self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx) - self._valid = True + self._valid: bool = True @property - def n_vocab(self): + def n_vocab(self) -> int: return self._library.rwkv_get_n_vocab(self._ctx) @property - def n_embed(self): + def n_embed(self) -> int: return self._library.rwkv_get_n_embed(self._ctx) @property - def n_layer(self): + def n_layer(self) -> int: return self._library.rwkv_get_n_layer(self._ctx) - def eval( self, token: int, @@ -96,19 +101,19 @@ def eval( assert self._valid, 'Model was freed' if state_in is not None: - validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_tensor(logits_out, 'logits_out', 
self._logits_buffer_element_count) + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') @@ -138,7 +143,7 @@ def eval_sequence( this limit when using large models and/or large sequence lengths. Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models. - If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. + If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. To get rid of the assertion failure, reduce the model size and/or sequence length. In case of any error, this method will throw an exception. @@ -163,19 +168,19 @@ def eval_sequence( assert self._valid, 'Model was freed' if state_in is not None: - validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') @@ -189,7 +194,7 @@ def eval_sequence( return logits_out, state_out - def free(self): + def free(self) -> None: """ Frees all allocated resources. In case of any error, this method will throw an exception. @@ -202,13 +207,13 @@ def free(self): self._library.rwkv_free(self._ctx) - def __del__(self): + def __del__(self) -> None: # Free the context on GC in case user forgot to call free() explicitly. 
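# hasattr() is needed because __del__ can run on a partially-constructed object
# whose __init__ raised before _valid was assigned.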
if hasattr(self, '_valid') and self._valid: self.free() -def validate_tensor(buf: torch.Tensor, name: str, size: int) -> None: - assert buf.device == torch.device('cpu'), f'{name} is not on CPU' - assert buf.dtype == torch.float32, f'{name} is not of type float32' - assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' - assert buf.is_contiguous(), f'{name} is not contiguous' + def _validate_tensor(self, buf: torch.Tensor, name: str, size: int) -> None: + assert buf.device == torch.device('cpu'), f'{name} is not on CPU' + assert buf.dtype == torch.float32, f'{name} is not of type float32' + assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' + assert buf.is_contiguous(), f'{name} is not contiguous' diff --git a/rwkv/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py similarity index 85% rename from rwkv/rwkv_cpp_shared_library.py rename to python/rwkv_cpp/rwkv_cpp_shared_library.py index edc4736..951ed5b 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -2,9 +2,9 @@ import sys import ctypes import pathlib -from typing import Optional, List +from typing import Optional, List, Tuple, Callable -QUANTIZED_FORMAT_NAMES = ( +QUANTIZED_FORMAT_NAMES: Tuple[str, str, str, str, str] = ( 'Q4_0', 'Q4_1', 'Q5_0', @@ -17,15 +17,15 @@ class RWKVContext: - def __init__(self, ptr: ctypes.pointer): - self.ptr = ptr + def __init__(self, ptr: ctypes.pointer) -> None: + self.ptr: ctypes.pointer = ptr class RWKVSharedLibrary: """ Python wrapper around rwkv.cpp shared library. """ - def __init__(self, shared_library_path: str): + def __init__(self, shared_library_path: str) -> None: """ Loads the shared library from specified file. In case of any error, this method will throw an exception. @@ -146,7 +146,7 @@ def rwkv_eval( this limit when using large models and/or large sequence lengths. Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models. - If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. + If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. To get rid of the assertion failure, reduce the model size and/or sequence length. Throws an exception in case of any error. Error messages would be printed to stderr. @@ -297,7 +297,6 @@ def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: return self.library.rwkv_get_n_vocab(ctx.ptr) - def load_rwkv_shared_library() -> RWKVSharedLibrary: """ Attempts to find rwkv.cpp shared library and load it. @@ -313,25 +312,40 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary: else: file_name = 'librwkv.so' - repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent - - paths = [ - # If we are in "rwkv" directory - f'../bin/Release/{file_name}', - # If we are in repo root directory - f'bin/Release/{file_name}', - # If we compiled in build directory - f'build/bin/Release/{file_name}', - # If we compiled in build directory - f'build/{file_name}', - # Search relative to this file - str(repo_root_dir / 'bin' / 'Release' / file_name), - # Fallback - str(repo_root_dir / file_name) + # Possible sub-paths to the library relative to the repo dir. + child_paths: List[Callable[[pathlib.Path], pathlib.Path]] = [ + # No lookup for Debug config here. 
+ # I assume that if a user wants to debug the library, + # they will be able to find the library and set the exact path explicitly. + lambda p: p / 'bin' / 'Release' / file_name, + lambda p: p / 'bin' / file_name, + # Some people prefer to build in the "build" subdirectory. + lambda p: p / 'build' / 'bin' / 'Release' / file_name, + lambda p: p / 'build' / file_name, + # Fallback. + lambda p: p / file_name + ] + + working_dir: pathlib.Path = pathlib.Path(os.path.abspath(os.getcwd())) + + parent_paths: List[pathlib.Path] = [ + # Possible repo dirs relative to the working dir. + # ./python/rwkv_cpp + working_dir.parent.parent, + # ./python + working_dir.parent, + # . + working_dir, + # Repo dir relative to this Python file. + pathlib.Path(os.path.abspath(__file__)).parent.parent.parent ] - for path in paths: - if os.path.isfile(path): - return RWKVSharedLibrary(path) + for parent_path in parent_paths: + for child_path in child_paths: + full_path: pathlib.Path = child_path(parent_path) + + if os.path.isfile(full_path): + return RWKVSharedLibrary(str(full_path)) - return RWKVSharedLibrary(paths[-1]) + assert False, (f'Failed to find {file_name} automatically; ' + f'you need to find the library and create RWKVSharedLibrary specifying the path to it') diff --git a/rwkv/rwkv_vocab_v20230424.txt b/python/rwkv_cpp/rwkv_vocab_v20230424.txt similarity index 100% rename from rwkv/rwkv_vocab_v20230424.txt rename to python/rwkv_cpp/rwkv_vocab_v20230424.txt diff --git a/rwkv/rwkv_tokenizer.py b/python/rwkv_cpp/rwkv_world_tokenizer.py similarity index 71% rename from rwkv/rwkv_tokenizer.py rename to python/rwkv_cpp/rwkv_world_tokenizer.py index cf033b7..ca864ef 100644 --- a/rwkv/rwkv_tokenizer.py +++ b/python/rwkv_cpp/rwkv_world_tokenizer.py @@ -1,23 +1,19 @@ import os -import tokenizers import pathlib from typing import List, Set, Tuple, Callable # Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py class Trie: - __slots__ = tuple('ch,to,values,front'.split(',')) + __slots__ = ('ch', 'to', 'values', 'front') - to: List - values: Set - - def __init__(self, front=None, ch=None): + def __init__(self, front=None, ch=None) -> None: self.ch = ch - self.to = [None for _ in range(256)] - self.values = set() + self.to: List = [None for _ in range(256)] + self.values: Set = set() self.front = front - def add(self, key: bytes, idx: int = 0, val=None): + def add(self, key: bytes, idx: int = 0, val=None) -> 'Trie': if idx == len(key): if val is None: val = key @@ -33,7 +29,7 @@ def add(self, key: bytes, idx: int = 0, val=None): return self.to[ch].add(key, idx=idx + 1, val=val) - def find_longest(self, key: bytes, idx: int = 0): + def find_longest(self, key: bytes, idx: int = 0) -> Tuple[int, 'Trie', set]: u: Trie = self ch: int = key[idx] ret = None @@ -54,12 +50,12 @@ def find_longest(self, key: bytes, idx: int = 0): return ret - def __repr__(self): + def __repr__(self) -> str: fr = self ret = [] - while fr != None: - if fr.ch != None: + while fr is not None: + if fr.ch is not None: ret.append(fr.ch) fr = fr.front @@ -68,7 +64,7 @@ def __repr__(self): class WorldTokenizer: - def __init__(self, file_path): + def __init__(self, file_path) -> None: self.index_to_token = {} with open(file_path, 'r', encoding='utf-8') as f: @@ -82,14 +78,14 @@ def __init__(self, file_path): assert len(x) == int(line[line.rindex(' '):]) self.index_to_token[idx] = x - self.token2idx = {} + self.token_to_index = {} for k, v in self.index_to_token.items(): - self.token2idx[v] = int(k) + self.token_to_index[v] 
= int(k) self.root = Trie() - for t, i in self.token2idx.items(): + for t, i in self.token_to_index.items(): _ = self.root.add(t, val=(t, i)) def encode_bytes(self, src: bytes) -> List[int]: @@ -116,19 +112,14 @@ def decode(self, tokens: List[int]) -> str: # Downstream code needs to detect \uFFFD and attempt to postpone decoding until more tokens arrive and UTF-8 sequences are complete. return self.decode_bytes(tokens).decode('utf-8', errors='replace') -def get_tokenizer(tokenizer: str = '20B') -> Tuple[ +def get_world_tokenizer_v20230424() -> Tuple[ Callable[[List[int]], str], Callable[[str], List[int]] ]: + """ + Loads the default World tokenizer, commonly used in RWKV v4 World models. + Returns a tuple of `decode` and `encode` functions. + """ parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent - - if tokenizer == 'world': - print('Loading world tokenizer') - tokenizer: WorldTokenizer = WorldTokenizer(parent / 'rwkv_vocab_v20230424.txt') - return tokenizer.decode, tokenizer.encode - elif tokenizer == '20B': - print('Loading 20B tokenizer') - tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json')) - return tokenizer.decode, lambda x: tokenizer.encode(x).ids - else: - assert False, f'Unknown tokenizer {tokenizer}' + tokenizer: WorldTokenizer = WorldTokenizer(parent / 'rwkv_vocab_v20230424.txt') + return tokenizer.decode, tokenizer.encode diff --git a/rwkv/rwkv_tokenizer.test.py b/python/rwkv_cpp/rwkv_world_tokenizer.test.py similarity index 73% rename from rwkv/rwkv_tokenizer.test.py rename to python/rwkv_cpp/rwkv_world_tokenizer.test.py index bade393..bc295d6 100644 --- a/rwkv/rwkv_tokenizer.test.py +++ b/python/rwkv_cpp/rwkv_world_tokenizer.test.py @@ -1,8 +1,8 @@ -from rwkv_tokenizer import get_tokenizer +import rwkv_world_tokenizer from typing import List -def test(): - tokenizer_decode, tokenizer_encode = get_tokenizer('world') +def test() -> None: + decode, encode = rwkv_world_tokenizer.get_world_tokenizer_v20230424() test_string: str = 'I\'ll \'d test блабла 以下は、]) -> <|endoftext|><|padding|> int' @@ -10,10 +10,10 @@ def test(): 10079, 1682, 3463, 295, 125, 25258, 7588, 2318, 125, 790, 125, 49520, 125, 63, 21888] - actual_tokens: List[int] = tokenizer_encode(test_string) + actual_tokens: List[int] = encode(test_string) assert actual_tokens == expected_tokens, f'\nActual: {actual_tokens}\nExpected: {expected_tokens}' - decoded_string: str = tokenizer_decode(actual_tokens) + decoded_string: str = decode(actual_tokens) assert test_string == decoded_string, f'\nDecoding mismatch: \n{decoded_string}' print('All tests pass') diff --git a/rwkv/sampling.py b/python/sampling.py similarity index 92% rename from rwkv/sampling.py rename to python/sampling.py index 270a4be..3d86e93 100644 --- a/rwkv/sampling.py +++ b/python/sampling.py @@ -4,7 +4,7 @@ from torch.nn import functional as F def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int: - probs = F.softmax(out.cpu(), dim=-1).numpy() + probs: np.ndarray = F.softmax(out.cpu(), dim=-1).numpy() return sample_probs(probs, temperature, top_p, logit_bias) @@ -16,7 +16,7 @@ def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8 top_p = 1.0 if logit_bias is not None and len(logit_bias) > 0: - logits = np.log(probs) + logits: np.ndarray = np.log(probs) ids, values = zip(*logit_bias.items()) logits[list(ids)] += values diff --git a/python/tokenizer_util.py b/python/tokenizer_util.py new file 
mode 100644 index 0000000..1f5e06d --- /dev/null +++ b/python/tokenizer_util.py @@ -0,0 +1,21 @@ +import os +import tokenizers +import pathlib +from rwkv_cpp import rwkv_world_tokenizer +from typing import List, Tuple, Callable + +def get_tokenizer(tokenizer: str = '20B') -> Tuple[ + Callable[[List[int]], str], + Callable[[str], List[int]] +]: + parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent + + if tokenizer == 'world': + print('Loading world tokenizer') + return rwkv_world_tokenizer.get_world_tokenizer_v20230424() + elif tokenizer == '20B': + print('Loading 20B tokenizer') + tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json')) + return tokenizer.decode, lambda x: tokenizer.encode(x).ids + else: + assert False, f'Unknown tokenizer {tokenizer}' diff --git a/rwkv.cpp b/rwkv.cpp index 19c394d..f7406bf 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -2,12 +2,6 @@ #include "ggml.h" #include "ggml-alloc.h" -#ifdef GGML_USE_CUBLAS -#include "ggml/src/ggml-cuda.h" -#elif defined(GGML_USE_CLBLAST) -#include "ggml/src/ggml-opencl.h" -#endif - #include #include #include @@ -25,1169 +19,37 @@ #include #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) -#define stat _stat64 -#define fstat _fstat64 -#define ftell _ftelli64 -#define fseek _fseeki64 - -#ifndef NDEBUG -#include -#define RWKV_MAYBE_BREAK __debugbreak() -#endif +# define stat _stat64 +# define fstat _fstat64 +# define ftell _ftelli64 +# define fseek _fseeki64 +# if !defined(NDEBUG) +# include +# define RWKV_MAYBE_BREAK __debugbreak() +# endif #else -#if !defined(__APPLE__) -#define ftell ftello -#define fseek fseeko -#endif +# if !defined(__APPLE__) +# define ftell ftello +# define fseek fseeko +# endif #endif static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB"); static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB"); -// --- Error handling --- - -thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE; -thread_local bool global_print_errors = true; - -inline enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) { - return static_cast(static_cast(a) | static_cast(b)); -} - -inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) { - return a = a | b; -} - -#define RWKV_MSG(...) do { if (global_print_errors) fprintf(stderr, __VA_ARGS__); } while (0) -#define RWKV_CTX_MSG(ctx, ...) do { if (ctx->print_errors) fprintf(stderr, __VA_ARGS__); } while (0) - -// If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL. -#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) do { \ - if (!(x)) { \ - global_last_error |= ERR_VAL; \ - RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL. -#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) do { \ - if (!(x)) { \ - global_last_error |= ERR_VAL; \ - RWKV_MSG(__VA_ARGS__); \ - RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL. -#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) 
do { \ - if (!(x)) { \ - ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ - RWKV_CTX_MSG(ctx, __VA_ARGS__); \ - RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL. -#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) do { \ - if (!(x)) { \ - ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ - RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, returns RET_VAL. -#define RWKV_ENSURE(RET_VAL, x) do { \ - if (!(x)) { \ - RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, prints a message to stderr, and returns RET_VAL. -#define RWKV_ENSURE_MSG(RET_VAL, x, ...) do { \ - if (!(x)) { \ - RWKV_MSG(__VA_ARGS__); \ - RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -// If the condition x is false, prints a message to stderr, and returns RET_VAL. -#define RWKV_CTX_ENSURE_MSG(ctx, RET_VAL, x, ...) do { \ - if (!(x)) { \ - ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \ - RWKV_CTX_MSG(ctx, __VA_ARGS__); \ - RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - RWKV_MAYBE_BREAK; \ - return RET_VAL; \ - } } while (0) - -#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__) -#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__) - -#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__) - -#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x) -#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x) - -#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x) - -#define RWKV_ENSURE_OR_FALSE(x) RWKV_ENSURE(false, x) -#define RWKV_ENSURE_OR_NULL(x) RWKV_ENSURE(NULL, x) -#define RWKV_ENSURE_OR_FALSE_MSG(x, ...) RWKV_ENSURE_MSG(false, x, __VA_ARGS__) - -// --- Utilities --- - -size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t width, const int64_t height) { - return (ggml_type_size(type) * width * height) / ggml_blck_size(type); -} - -// For some reason, ggml_nbytes calculates the size in a way incompatible with rwkv.cpp -size_t rwkv_tensor_nbytes(const struct ggml_tensor * tensor) { - return rwkv_tensor_nbytes(tensor->type, tensor->ne[0], tensor->ne[1]); -} - -size_t rwkv_ggml_overhead() { - return ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead(); -} - -struct ggml_context * rwkv_init_ggml_context(const size_t memory_size, const bool no_alloc) { - struct ggml_init_params init_params = { - memory_size, - NULL, - no_alloc - }; - - return ggml_init(init_params); -} - -// --- IO utilities --- - -// Reads a single uint32 value from a file. -bool rwkv_fread_uint32(FILE * file, uint32_t & dest) { - return fread((void *) &dest, sizeof(uint32_t), 1, file) == 1; -} - -// Reads a single string value from a file. -bool rwkv_fread_string(FILE * file, size_t length, std::string & dest) { - dest.resize(length); - return fread((void *) dest.data(), length, 1, file) == 1; -} - -// Reads a single data buffer from a file. 
-bool rwkv_fread_data(FILE * file, size_t length, void * dest) { - return fread(dest, length, 1, file) == 1; -} - -// Writes a single uint32 value to a file. -bool rwkv_fwrite_uint32(FILE * file, const uint32_t value) { - return fwrite((const void *) &value, sizeof(uint32_t), 1, file); -} - -// Writes a single string value to a file. -bool rwkv_fwrite_string(FILE * file, const std::string & value) { - return fwrite((const void *) value.data(), value.length(), 1, file) == 1; -} - -// Writes a single data buffer to a file. -bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { - return fwrite(data, length, 1, file) == 1; -} - -// --- File handling --- - -#define TYPE_UNKNOWN TYPE_COUNT - -enum rwkv_type { - TYPE_FP32, - TYPE_FP16, - TYPE_Q4_0, - TYPE_Q4_1, - TYPE_Q4_1_O, // Unsupported - TYPE_Q4_2, // Unsupported - TYPE_Q4_3, // Unsupported - TYPE_Q5_0, - TYPE_Q5_1, - TYPE_Q8_0, - TYPE_COUNT -}; - -#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT - -extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { - GGML_TYPE_F32, /* FP32 */ - GGML_TYPE_F16, /* FP16 */ - GGML_TYPE_Q4_0, /* Q4_0 */ - GGML_TYPE_Q4_1, /* Q4_1 */ - GGML_TYPE_UNKNOWN, /* Q4_1_O */ - GGML_TYPE_UNKNOWN, /* Q4_2 */ - GGML_TYPE_UNKNOWN, /* Q4_3 */ - GGML_TYPE_Q5_0, /* Q5_0 */ - GGML_TYPE_Q5_1, /* Q5_1 */ - GGML_TYPE_Q8_0, /* Q8_0 */ - GGML_TYPE_COUNT /* COUNT */ -}; - -extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { - TYPE_FP32, /* FP32 */ - TYPE_FP16, /* FP16 */ - TYPE_Q4_0, /* Q4_0 */ - TYPE_Q4_1, /* Q4_1 */ - TYPE_Q4_2, /* Q4_2 */ - TYPE_Q4_3, /* Q4_3 */ - TYPE_Q5_0, /* Q5_0 */ - TYPE_Q5_1, /* Q5_1 */ - TYPE_Q8_0, /* Q8_0 */ - TYPE_COUNT, /* Q8_1 */ - TYPE_COUNT, /* I8 */ - TYPE_COUNT, /* I16 */ - TYPE_COUNT, /* I32 */ - TYPE_COUNT, /* COUNT */ -}; - -extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"FP32", "FP16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; - -enum rwkv_type rwkv_type_from_string(const char * str) { - for (int ord = 0; ord < TYPE_COUNT; ord++) { - if (strcmp(str, rwkv_type_to_string[ord]) == 0) { - return (enum rwkv_type) ord; - } - } - - return TYPE_UNKNOWN; -} - -struct rwkv_file_header { - uint32_t magic; - uint32_t version; - uint32_t n_vocab; - uint32_t n_embed; - uint32_t n_layer; - uint32_t data_type; -}; - -bool rwkv_is_file_version_in_range(uint32_t version) { - return version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX; -} - -bool rwkv_fread_file_header(FILE * file, struct rwkv_file_header & header, bool verify_data_type = true) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_file_header), &header)); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_MAGIC, header.magic == RWKV_FILE_MAGIC); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_VERSION, rwkv_is_file_version_in_range(header.version), "Unsupported file version %" PRId32, header.version); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Model data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); - - if (verify_data_type) { - enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; - - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_DATA_TYPE, - ggml_type != GGML_TYPE_UNKNOWN, - "Models in %s format cannot be loaded anymore because the format was removed.\n" - "You need to quantize the model into another format or use an older version of rwkv.cpp.\n" - "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info", - 
rwkv_type_to_string[header.data_type] - ); - - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_DATA_TYPE, - (!ggml_is_quantized(ggml_type) || header.version == RWKV_FILE_VERSION_1), - "The quantized model file in %s format was created with an old version of rwkv.cpp and can not be loaded anymore.\n" - "You need to requantize the model or use an older version of rwkv.cpp.\n" - "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info", - rwkv_type_to_string[header.data_type] - ); - } - - return true; -} - -bool rwkv_fwrite_file_header(FILE * file, const struct rwkv_file_header & header) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_file_header))); - return true; -} - -struct rwkv_tensor_header { - uint32_t dim_count; - uint32_t key_length; - uint32_t data_type; - uint32_t width; - uint32_t height; - - size_t size() const; -}; - -size_t rwkv_tensor_header::size() const { - return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height); -} - -struct rwkv_tensor { - struct rwkv_tensor_header header; - std::string name; - uint8_t * data; -}; - -bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header)); - header.height = 1; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_DATA_TYPE, - rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, - "Tensor data type (%s) is no longer supported", - rwkv_type_to_string[header.data_type] - ); - - if (header.dim_count == 2) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height)); - } - - return true; -} - -bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? 
sizeof(uint32_t) : 0))); - return true; -} - -bool rwkv_fskip_tensor_name_and_data(FILE * file, const struct rwkv_tensor_header & header) { - return fseek(file, header.key_length + header.size(), SEEK_CUR) == 0; -} - -bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) { - return fseek(file, header.size(), SEEK_CUR) == 0; -} - -bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) { - RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, header)); - RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, rwkv_fskip_tensor_name_and_data(file, header)); - return true; -} - -bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) { - size_t data_size = output.header.size(); - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, output.header.key_length, output.name)); - - if (buffer) { - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, data_size, buffer)); - } else { - output.data = NULL; - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_name_and_data(file, output.header)); - } - - return true; -} - -bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) { - RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, output.header)); - RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_data(file, output, buffer)); - return true; -} - -bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); - - enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); - - tensor = header.dim_count == 1 - ? ggml_new_tensor_1d(ctx, ggml_type, header.width) - : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - ggml_set_name(tensor, name.c_str()); - - // Tensor data may be NULL if no_alloc is true - if (tensor->data != NULL) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_tensor_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); - } else { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_data(file, header), "Failed to skip tensor data from %s", name.c_str()); - } - - return true; -} - -bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { - struct rwkv_tensor_header header; - RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); - return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); -} - -bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { - RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header)); - RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name)); - RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, tensor.header.size())); - return true; -} - -// --- Model loading --- +#include "rwkv_error_handling.inc" -struct rwkv_layer { - struct ggml_tensor * ln1_weight; - struct ggml_tensor * ln1_bias; +#include "rwkv_utilities.inc" - // RWKV, also called "attention" by the author. 
- struct ggml_tensor * att_time_mix_k; - struct ggml_tensor * att_time_mix_v; - struct ggml_tensor * att_time_mix_r; - struct ggml_tensor * att_time_first; - struct ggml_tensor * att_time_decay; - struct ggml_tensor * att_key; - struct ggml_tensor * att_value; - struct ggml_tensor * att_receptance; - struct ggml_tensor * att_output; +#include "rwkv_file_format.inc" - struct ggml_tensor * ln2_weight; - struct ggml_tensor * ln2_bias; - - // FFN. - struct ggml_tensor * ffn_time_mix_k; - struct ggml_tensor * ffn_time_mix_r; - struct ggml_tensor * ffn_key; - struct ggml_tensor * ffn_value; - struct ggml_tensor * ffn_receptance; -}; - -// The model holds all parameter tensors and the ggml context containing them. -// Each tensor has data and can be used in computations happening in other contexts. -struct rwkv_model { - // This context holds all parameter tensors. - // It must not be used for computations. - struct ggml_context * ggml_ctx; - - struct rwkv_file_header header; - - struct ggml_tensor * emb; - - struct ggml_tensor * ln0_weight; - struct ggml_tensor * ln0_bias; - - std::unique_ptr layers; - - struct ggml_tensor * ln_out_weight; - struct ggml_tensor * ln_out_bias; - - struct ggml_tensor * head; - - // How many layers were offloaded to the GPU. - size_t offloaded_layer_count; - - // How many RWKV contexts reference this model. - int reference_count; -}; - -struct rwkv_file { - FILE * file; - - rwkv_file(FILE * file): file(file) {} - - ~rwkv_file() { - if (file) { - fclose(file); - } - } -}; - -// https://stackoverflow.com/a/6458689 -template -bool rwkv_set_params(struct rwkv_model & model, F callback) { - RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); - RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); - RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); - - uint32_t n_layer = model.header.n_layer; - std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); - model.layers = std::move(layers); - - for (uint32_t i = 0; i < n_layer; i++) { - char buffer[128]; - size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); - - rwkv_layer & layer = model.layers[i]; - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); - 
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); - } - - RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); - RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); - RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); - - return true; -} - -// Creates a ggml context and loads all parameter tensors from a model file. -bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model & model) { - struct stat file_stat; - - std::unordered_map parameters; - - rwkv_file file(fopen(file_path, "rb")); - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path); - // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header"); - - model.ggml_ctx = rwkv_init_ggml_context( - // ggml tensors must be aligned; assuming here that overhead of parameter headers, included in the file size, will account for that. - file_stat.st_size + rwkv_ggml_overhead(), - false - ); - - std::string name; - - struct ggml_tensor * tensor; - - while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, model.ggml_ctx, name, tensor), "Failed to read a model parameter"); - - parameters[std::move(name)] = tensor; - } - - std::unordered_map & parameters_ref = parameters; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { - struct ggml_tensor * tensor = parameters_ref[key]; - RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); - dest = tensor; - return true; - })); - - // Verify order of dimensions - struct ggml_tensor * emb = model.emb; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - - return true; -} - -// --- Operators --- - -void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - GGML_ASSERT(dest->type == GGML_TYPE_F32); - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(dest)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_are_same_shape(src, dest)); - - // Assuming 2D tensors. 
- int64_t element_count = src->ne[0] * src->ne[1]; - float * src_data = (float *) src->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = 0; i < element_count; i++) { - dest_data[i] = expf(src_data[i]); - } - - // Suppress warnings for unused parameters. - (void) ith; - (void) nth; - (void) userdata; -} - -void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - GGML_ASSERT(dest->type == GGML_TYPE_F32); - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(dest)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_are_same_shape(src, dest)); - - // Assuming 2D tensors. - int64_t element_count = src->ne[0] * src->ne[1]; - float * src_data = (float *) src->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = 0; i < element_count; i++) { - dest_data[i] = 1.0F - src_data[i]; - } - - // Suppress warnings for unused parameters. - (void) ith; - (void) nth; - (void) userdata; -} - -void rwkv_sigmoid_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { - GGML_ASSERT(dest->type == GGML_TYPE_F32); - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(dest)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_are_same_shape(src, dest)); - - // Assuming 2D tensors. - int64_t element_count = src->ne[0] * src->ne[1]; - float * src_data = (float *) src->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = 0; i < element_count; i++) { - dest_data[i] = 1.0F / (1.0F + expf(-src_data[i])); - } - - // Suppress warnings for unused parameters. - (void) ith; - (void) nth; - (void) userdata; -} - -void rwkv_max_impl( - struct ggml_tensor * dest, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - int ith, - int nth, - void * userdata -) { - GGML_ASSERT(dest->type == GGML_TYPE_F32); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(dest)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_are_same_shape(src0, dest)); - GGML_ASSERT(ggml_are_same_shape(src1, dest)); - - // Assuming 2D tensors. - int64_t element_count = src0->ne[0] * src0->ne[1]; - float * src0_data = (float *) src0->data; - float * src1_data = (float *) src1->data; - float * dest_data = (float *) dest->data; - - for (int64_t i = 0; i < element_count; i++) { - dest_data[i] = fmaxf(src0_data[i], src1_data[i]); - } - - // Suppress warnings for unused parameters. 
- (void) ith; - (void) nth; - (void) userdata; -} - -struct ggml_tensor * rwkv_exp(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL); -} - -struct ggml_tensor * rwkv_1_minus_x(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_1_minus_x_impl, 1, NULL); -} - -struct ggml_tensor * rwkv_sigmoid(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_custom1(ctx, x, rwkv_sigmoid_impl, 1, NULL); -} - -struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) { - return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL); -} - -struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { - // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` - // Looks like ggml_norm does the first part, we only need to apply weight & bias. - return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias); -} - -// --- Implementation --- - -// View tensors of a state of a single layer. -struct rwkv_layer_state { - struct ggml_tensor * ffn_xx; - struct ggml_tensor * att_xx; - struct ggml_tensor * att_aa; - struct ggml_tensor * att_bb; - struct ggml_tensor * att_pp; -}; - -// The computation graph holds ggml context and the ggml cgraph. -// It can be either a serial or a sequential graph. -struct rwkv_computation_graph { - struct ggml_context * ggml_ctx; - // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap. - std::unique_ptr cgraph; - - // Input tensors. - struct ggml_tensor * tokens; - struct ggml_tensor * input_state; - std::unique_ptr input_layers; - - // Output tensors. - struct ggml_tensor * output_state; - std::unique_ptr output_layers; - struct ggml_tensor * logits; - - // ggml graph counters before the graph was extended with logits tensor. - int pre_logits_nodes; - int pre_logits_leafs; - // ggml graph counters after the graph was extended with logits tensor. - int post_logits_nodes; - int post_logits_leafs; -}; - -// The context holds the model and both serial and sequential computation graphs. -struct rwkv_context { - struct rwkv_model * model; - - // The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode). - struct rwkv_computation_graph serial_graph; - // The sequence graph implements the "sequence mode" (or transformer/GPT mode) that processes multiple tokens at a time. - // This can be an order of magnitude or so faster than serial execution if used properly. - struct rwkv_computation_graph sequential_graph; - size_t last_used_sequence_length; - - uint32_t n_threads; - - enum rwkv_error_flags last_error; - bool print_errors; -}; - -void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { - bool * ptr = ctx ? &ctx->print_errors : &global_print_errors; - *ptr = print_errors; -} - -bool rwkv_get_print_errors(struct rwkv_context * ctx) { - return ctx ? ctx->print_errors : global_print_errors; -} - -enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { - enum rwkv_error_flags * ptr = ctx ? 
&ctx->last_error : &global_last_error; - enum rwkv_error_flags value = *ptr; - *ptr = RWKV_ERROR_NONE; - return value; -} - -void rwkv_carry_x(struct ggml_context * ctx, - struct ggml_tensor * weight, - struct ggml_tensor * bias, - struct ggml_tensor *& x, - struct ggml_tensor *& x_prev, - struct ggml_tensor *& carry -) { - const size_t n_embed = x->ne[0]; - const size_t sequence_len = x->ne[1]; - - if (sequence_len == 1) { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, weight, bias); - - // xx = state[5*i+0] - x_prev = carry; - - // state[5*i+0] = x - carry = x; - } else { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); - - // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); - x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); - x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); - - // state[5*i+0] = x[-1,:] - carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); - } -} +#include "rwkv_model_loading.inc" -void rwkv_att_rkv( - struct ggml_context * ctx, - struct rwkv_layer layer, - struct ggml_tensor * x, - struct ggml_tensor * x_prev, - struct ggml_tensor *& r, - struct ggml_tensor *& k, - struct ggml_tensor *& v -) { - // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x, layer.att_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) - ); +#include "rwkv_operators.inc" - // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) - struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x, layer.att_time_mix_v), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) - ); - - // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x, layer.att_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) - ); - - // r = torch.sigmoid(rw @ xr) - r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); - // k = kw @ xk - k = ggml_mul_mat(ctx, layer.att_key, xk); - // v = vw @ xv - v = ggml_mul_mat(ctx, layer.att_value, xv); -} - -struct ggml_tensor * rwkv_att_wkv( - struct ggml_context * ctx, - struct ggml_tensor * att_time_first, - struct ggml_tensor * att_time_decay, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor *& aa, - struct ggml_tensor *& bb, - struct ggml_tensor *& pp -) { - // ww = time_first + k - struct ggml_tensor * ww = ggml_add(ctx, att_time_first, k); - // qq = torch.maximum(pp, ww) - struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); - // e1 = torch.exp(pp - qq) - struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); - // e2 = torch.exp(ww - qq) - struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - - // a = e1 * aa + e2 * v - struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); - // b = e1 * bb + e2 - struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); - - // ww = pp + time_decay - ww = ggml_add(ctx, pp, att_time_decay); - // qq = torch.maximum(ww, k) - qq = rwkv_max(ctx, ww, k); - // e1 = torch.exp(ww - qq) - e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - // e2 = torch.exp(k[t] - qq) - e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); - 
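The `qq = max(...)` / `exp(... - qq)` pattern above, together with the state updates below, computes an exp-weighted average in max-shifted form: `pp` carries the running maximum exponent, so `exp` never sees a large argument and cannot overflow. A scalar reference sketch of one WKV step for a single channel (illustrative, mirroring the math in this function):

```c++
#include <algorithm>
#include <cmath>

// One numerically stable WKV step for a single channel.
// aa/bb accumulate the exp-weighted numerator/denominator; pp tracks the
// running maximum exponent so that std::exp never overflows.
void wkv_step(float time_first, float time_decay, float k, float v,
              float & aa, float & bb, float & pp, float & wkv) {
    // Output for the current token: weighted average a / b.
    float ww = time_first + k;
    float qq = std::max(pp, ww);
    float e1 = std::exp(pp - qq);
    float e2 = std::exp(ww - qq);
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2);

    // Decay the accumulators and fold in the current token.
    ww = pp + time_decay;
    qq = std::max(ww, k);
    e1 = std::exp(ww - qq);
    e2 = std::exp(k - qq);
    aa = e1 * aa + e2 * v;
    bb = e1 * bb + e2;
    pp = qq;
}
```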
- // state[5 * i + 2] = e1 * aa + e2 * v - // state[5 * i + 3] = e1 * bb + e2 - // state[5 * i + 4] = qq - aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); - bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); - pp = qq; - - // wkv = a / b - return ggml_div(ctx, a, b); -} - -struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x_prev; - rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); - - struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); - - struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); - - // ow @ (r * xx) - return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); -} - -struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x_prev; - rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); - - // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul(ctx, x, layer.ffn_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) - ); - - // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul(ctx, x, layer.ffn_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) - ); - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); - - // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); - - // r * (vw @ k) - return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); -} - -void rwkv_create_input_and_output_views( - struct rwkv_layer_state * inputs, - struct rwkv_layer_state * outputs, - struct ggml_tensor * input, - struct ggml_tensor * output, - struct ggml_context * ctx, - size_t n_layer, - size_t n_embed -) { - for (size_t i = 0; i < n_layer; i++) { - struct rwkv_layer_state & input_state = inputs[i]; - input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - input_state.att_aa = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); - - struct rwkv_layer_state & output_state = outputs[i]; - output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); - } -} - -// Creates and sets the input and output ggml tensors, builds the computation graph. 
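Both graph builders that follow rely on `rwkv_create_input_and_output_views` above: the entire recurrent state lives in one flat `n_embed * 5 * n_layer` float tensor, and each layer's five parts are 1D views into it. The offset arithmetic, spelled out as a small illustrative helper:

```c++
#include <cstddef>

// Element offset of one state part in the flat state buffer. Layer i owns
// five consecutive (n_embed)-sized vectors, in the order:
// ffn_xx (0), att_xx (1), att_aa (2), att_bb (3), att_pp (4).
size_t state_part_offset(size_t n_embed, size_t layer, size_t part /* 0..4 */) {
    return n_embed * (layer * 5 + part);
}
// The byte offset passed to ggml_view_1d is this value times sizeof(float).
```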
-bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) { - graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - - struct rwkv_file_header & header = model.header; - const size_t n_vocab = header.n_vocab; - const size_t n_embed = header.n_embed; - const size_t n_layer = header.n_layer; - - struct ggml_context * ctx = graph.ggml_ctx; - - // Creates a 1-element tensor. - graph.tokens = ggml_new_i32(ctx, 0); - - struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - - // We collect parts of input state here. Each part is (n_embed) vector. - std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); - - // We collect parts of output state here. Each part is (n_embed) vector. - std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); - - rwkv_create_input_and_output_views(inputs.get(), outputs.get(), input, output, ctx, n_layer, n_embed); - - graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); - - // x = self.w.emb.weight[token] - struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); - - // x = self.layer_norm(x, self.w.blocks[0].ln0) - x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); - - for (size_t i = 0; i < model.header.n_layer; i++) { - struct rwkv_layer & layer = model.layers[i]; - - struct rwkv_layer_state state = inputs[i]; - x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); - x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); - - struct rwkv_layer_state & output_state = outputs[i]; - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); - } - - graph.pre_logits_nodes = graph.cgraph->n_nodes; - graph.pre_logits_leafs = graph.cgraph->n_leafs; - - // x = self.layer_norm(x[-1,:], self.w.ln_out) - x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); - - // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); - - graph.post_logits_nodes = graph.cgraph->n_nodes; - graph.post_logits_leafs = graph.cgraph->n_leafs; - - graph.input_state = input; - graph.input_layers = std::move(inputs); - - graph.output_state = output; - graph.output_layers = std::move(outputs); - - return true; -} - -// Stolen from llama.cpp. -static const size_t tensor_alignment = 32; - -// Prepares the computation graph for inference, measuring and allocating all input and output tensors. -bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, struct rwkv_computation_graph & graph) { - if (graph.ggml_ctx) { - ggml_free(graph.ggml_ctx); - - graph.ggml_ctx = NULL; - } - - // 1. Measure the space required for the ggml context. 
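This measure-then-build pattern, used here and in the sequential variant below, constructs the same graph twice: once in a no-alloc context so the measuring allocator can report the required size, then again in a real context of that size. A condensed outline of the steps that follow (comments only, mirroring the surrounding code):

```c++
// ctx  = rwkv_init_ggml_context(rwkv_ggml_overhead(), /* no_alloc = */ true);
// build the graph;                  // tensors get shapes, but no data buffers
// size = ggml_allocr_alloc_graph(measure_allocator, cgraph)
//      + rwkv_ggml_overhead() + tensor_alignment + safety padding;
// free the measuring allocator and context;
// ctx  = rwkv_init_ggml_context(size, /* no_alloc = */ false);
// build the same graph again, now backed by real memory.
```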
- graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); - - RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); - - struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); - - size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + - + rwkv_ggml_overhead() - + tensor_alignment - // For some reason, calculation above does not result in enough memory allocated. - // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. - // 64 MB seems to be enough for Raven 14B model. - + size_t(64) * 1024 * 1024; - - ggml_allocr_free(allocator); - ggml_free(graph.ggml_ctx); - - // 2. Create the real ggml context. - graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); - - RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); - - return true; -} - -// --- - -// Creates and sets the input and output ggml tensors, builds the computation graph. -bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { - graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - - struct rwkv_file_header & header = model.header; - const size_t n_vocab = header.n_vocab; - const size_t n_embed = header.n_embed; - const size_t n_layer = header.n_layer; - - struct ggml_context * ctx = graph.ggml_ctx; - - graph.tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sequence_length); - - struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - - // We collect parts of input state here. Each part is (n_embed) vector. - std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); - - // We collect parts of output state here. Each part is (n_embed) vector. 
- std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); - - rwkv_create_input_and_output_views(inputs.get(), outputs.get(), input, output, ctx, n_layer, n_embed); - - graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); - - // x = self.w.emb.weight[token] - struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); - - // x = self.layer_norm(x, self.w.blocks[0].ln0) - x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); - - for (size_t i = 0; i < model.header.n_layer; i++) { - struct rwkv_layer & layer = model.layers[i]; - struct rwkv_layer_state state = inputs[i]; - - struct ggml_tensor * x0 = x, * x_prev; - rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); - - struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); - - ggml_build_forward_expand(graph.cgraph.get(), r); - - for (uint32_t t = 0; t < sequence_length; t++) { - struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, wkv, xt)); - } - - x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); - x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); - - struct rwkv_layer_state & output_state = outputs[i]; - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); - } - - graph.pre_logits_nodes = graph.cgraph->n_nodes; - graph.pre_logits_leafs = graph.cgraph->n_leafs; - - // x = self.layer_norm(x[-1,:], self.w.ln_out) - x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_length - 1)), model.ln_out_weight, model.ln_out_bias); - - // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); - - graph.post_logits_nodes = graph.cgraph->n_nodes; - graph.post_logits_leafs = graph.cgraph->n_leafs; - - graph.input_state = input; - graph.input_layers = std::move(inputs); - - graph.output_state = output; - graph.output_layers = std::move(outputs); - - return true; -} - -// Prepares the computation graph for inference, measuring and allocating all input and output tensors. -bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { - if (graph.ggml_ctx) { - ggml_free(graph.ggml_ctx); - - graph.ggml_ctx = NULL; - } - - // 1. Measure the space required for the ggml context. 
- graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); - - RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); - - struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); - - size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + - + rwkv_ggml_overhead() - + tensor_alignment - // For some reason, calculation above does not result in enough memory allocated. - // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. - // 64 MB per token seems to be enough for Raven 14B model. It works for sequence_length = 5; not tested on larger lengths. - + sequence_length * 64 * 1024 * 1024; - - ggml_allocr_free(allocator); - ggml_free(graph.ggml_ctx); - - // 2. Create the real ggml context. - graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); - - RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); - - return true; -} - -// --- +#include "rwkv_graph.inc" +// API function. struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { global_last_error = RWKV_ERROR_NONE; @@ -1205,6 +67,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t return ctx.release(); } +// API function. struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) { std::unique_ptr clone(new(std::nothrow) struct rwkv_context()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, clone, "Failed to allocate rwkv_context"); @@ -1223,187 +86,55 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32 return clone.release(); } -bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - const auto offload = [&](struct ggml_tensor * tensor) { - // TODO Support multi-GPU - tensor->backend = GGML_BACKEND_GPU; -#ifdef GGML_USE_CUBLAS - ggml_cuda_transform_tensor(tensor->data, tensor); -#elif defined(GGML_USE_CLBLAST) - ggml_cl_transform_tensor(tensor->data, tensor); -#endif - }; - - const size_t n_gpu = std::min(n_layers, ctx->model->header.n_layer); - - if (ctx->model->offloaded_layer_count < n_gpu) { - for (size_t & i = ctx->model->offloaded_layer_count; i < n_gpu; i++) { - const struct rwkv_layer & layer = ctx->model->layers[i]; - - // TODO Also offload other operations to GPU with ggml_cuda_assign_buffers - offload(layer.att_key); - offload(layer.att_value); - offload(layer.att_receptance); - offload(layer.att_output); - - offload(layer.ffn_key); - offload(layer.ffn_value); - offload(layer.ffn_receptance); - } - - return true; - } -#endif - return false; -} - -void rwkv_set_inputs(const struct rwkv_context * ctx, const struct rwkv_computation_graph & graph, const float * state_in) { - if (state_in) { - memcpy(graph.input_state->data, state_in, rwkv_tensor_nbytes(graph.input_state)); - } else { - rwkv_init_state(ctx, (float *) graph.input_state->data); - } -} - -void rwkv_get_outputs(const struct rwkv_computation_graph & graph, float * state_out, float * logits_out) { - if (state_out) { - memcpy(state_out, graph.output_state->data, rwkv_tensor_nbytes(graph.output_state)); - } +#include "rwkv_gpu_offload.inc" - if (logits_out) { - memcpy(logits_out, graph.logits->data, rwkv_tensor_nbytes(graph.logits)); - } -} - -void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_t n_threads, const bool compute_logits) { - // Short 
circuit computation of logits if they are not needed. - if (!compute_logits) { - graph.cgraph->n_nodes = graph.pre_logits_nodes; - graph.cgraph->n_leafs = graph.pre_logits_leafs; - } else { - graph.cgraph->n_nodes = graph.post_logits_nodes; - graph.cgraph->n_leafs = graph.post_logits_leafs; - } - - struct ggml_cplan * plan = ggml_graph_plan(graph.cgraph.get(), n_threads); - - std::unique_ptr work_data{ new(std::nothrow) uint8_t[plan->work_size] }; - plan->work_data = work_data.get(); - - ggml_graph_compute(graph.cgraph.get(), plan); - - free(plan); -} - -bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { - ctx->last_error = RWKV_ERROR_NONE; - - const struct rwkv_file_header & header = ctx->model->header; - const size_t n_vocab = header.n_vocab; - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 .. %zu)", token, n_vocab - 1); - - rwkv_set_inputs(ctx, ctx->serial_graph, state_in); - ggml_set_i32(ctx->serial_graph.tokens, token); - - rwkv_eval_graph(ctx->serial_graph, ctx->n_threads, logits_out != NULL); - - rwkv_get_outputs(ctx->serial_graph, state_out, logits_out); - - return true; -} +#include "rwkv_eval.inc" -bool rwkv_eval_sequence( - struct rwkv_context * ctx, - const uint32_t * sequence, - const size_t sequence_len, - const float * state_in, - float * state_out, - float * logits_out -) { - ctx->last_error = RWKV_ERROR_NONE; - - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0"); - - const size_t n_vocab = ctx->model->header.n_vocab; - - if (sequence) { - for (size_t i = 0; i < sequence_len; i++) { - const uint32_t token = sequence[i]; - - RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token at index %zu (%" PRId32 ") is out of range (0 .. %zu)", i, token, n_vocab - 1); - } - } - - if (ctx->last_used_sequence_length != sequence_len) { - RWKV_ENSURE_OR_FALSE(rwkv_measure_and_build_sequential_context(*ctx->model, ctx->sequential_graph, sequence_len)); - - ctx->last_used_sequence_length = sequence_len; - } - - // Allow building the sequence graph without actually evaluating, by specifying sequence = NULL. - if (sequence) { - rwkv_set_inputs(ctx, ctx->sequential_graph, state_in); - memcpy(ctx->sequential_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t)); - - rwkv_eval_graph(ctx->sequential_graph, ctx->n_threads, logits_out != NULL); - - rwkv_get_outputs(ctx->sequential_graph, state_out, logits_out); - } - - return true; -} - -// Provided for compatibility. +// API function. +// Provided for backwards compatibility. extern "C" RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { return rwkv_get_state_len(ctx); } -// Provided for compatibility. +// API function. +// Provided for backwards compatibility. extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { return rwkv_get_logits_len(ctx); } -extern "C" RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { +// API function. +size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_vocab; } -extern "C" RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { +// API function. +size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_embed; } -extern "C" RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { +// API function. 
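The 2 GB caveat exists because `ftell`/`fseek` traffic in `long`, which is 32-bit on some platforms; the code therefore sizes files via `fstat` on the file descriptor. A minimal sketch of that approach (assuming a POSIX-style `fstat`; on Windows the `_stat64` family plays this role):

```c++
#include <cstdio>
#include <sys/stat.h>

// Returns the file size without relying on 32-bit ftell()/fseek().
// On common platforms, fstat() reports st_size as a 64-bit off_t.
long long file_size_64(FILE * file) {
    struct stat file_stat;

    if (fstat(fileno(file), &file_stat) != 0) {
        return -1; // stat failed
    }

    return (long long) file_stat.st_size;
}
```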
+size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_layer; } +// API function. size_t rwkv_get_state_len(const struct rwkv_context * ctx) { const struct rwkv_file_header & header = ctx->model->header; return (size_t) header.n_embed * 5 * (size_t) header.n_layer; } +// API function. size_t rwkv_get_logits_len(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_vocab; } -void rwkv_init_state(const struct rwkv_context * ctx, float * state) { - const struct rwkv_file_header & header = ctx->model->header; - const size_t layer_size = (size_t) header.n_embed * 5; - const size_t layer_zero = (size_t) header.n_embed * 4; - const size_t layers_size = (size_t) header.n_layer * layer_size; - - for (size_t start = 0; start < layers_size; start += layer_size) { - for (size_t i = 0; i < layer_zero; i++) { - state[start + i] = 0.0F; - } - - for (size_t i = layer_zero; i < layer_size; i++) { - state[start + i] = -1e30F; - } +// API function. +void rwkv_free(struct rwkv_context * ctx) { + if (ctx == NULL) { + return; } -} -void rwkv_free(struct rwkv_context * ctx) { if (--ctx->model->reference_count == 0) { ggml_free(ctx->model->ggml_ctx); @@ -1419,169 +150,28 @@ void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); } -bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { - global_last_error = RWKV_ERROR_NONE; - - enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]); - - RWKV_MSG("Loading model from '%s'\n", in_path); - - struct stat in_stat; - - struct rwkv_file in_file(fopen(in_path, "rb")); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, in_file.file, "Failed to open %s for reading", in_path); - - // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length. 
- RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(in_file.file), &in_stat) == 0, "failed to stat file %s", in_path); - - struct rwkv_file out_file(fopen(out_path, "wb")); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, out_file.file, "Failed to open %s for writing", out_path); - - struct rwkv_file_header in_header; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file.file, in_header), "Invalid file header"); - - enum ggml_type in_type = rwkv_type_to_ggml[in_header.data_type]; - RWKV_ASSERT_FALSE_MSG( - RWKV_ERROR_FILE, - in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, - "Unsupported input data type (%s); needs to be FP32 or FP16", - rwkv_type_to_string[rwkv_type_from_ggml[in_type]] - ); - - struct rwkv_file_header out_header = in_header; - out_header.version = RWKV_FILE_VERSION; - out_header.data_type = rwkv_type_from_ggml[out_type]; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fwrite_file_header(out_file.file, out_header), "Failed to write file header"); - - // Process parameters - size_t orig_total_size = 0; - size_t new_total_size = 0; - - // Required to init the F16 tables - // Doesn't crash if ggml_init fails - ggml_free(ggml_init({ 0, NULL, true })); - - size_t max_in_size = 0; - size_t max_out_size = 0; - size_t max_key_length = 0; - - while (ftell(in_file.file) < in_stat.st_size) { - struct rwkv_tensor_header header; - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file.file, header)); - - size_t in_size = header.size(); - - if (in_size > max_in_size) { - max_in_size = in_size; - } - - // f16 type tensors get relocated to out and then converted into f32 at in - if (header.data_type == TYPE_FP16) { - if (in_size > max_out_size) { - max_out_size = in_size; - } - - size_t f32_size = rwkv_tensor_nbytes(GGML_TYPE_F32, header.width, header.height); - - if (f32_size > max_in_size) { - max_in_size = f32_size; - } - } - - size_t out_size = rwkv_tensor_nbytes(out_type, header.width, header.height); - - if (out_size > max_out_size) { - max_out_size = out_size; - } - - if (header.key_length > max_key_length) { - max_key_length = header.key_length; - } - } - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); - - // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! - int64_t hist_all[16] {}; - - std::unique_ptr scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); - - uint8_t * in_buf = scratch.get(); - uint8_t * out_buf = in_buf + max_in_size; - - struct rwkv_tensor tensor; - struct rwkv_tensor_header & header = tensor.header; - std::string & name = tensor.name; - uint8_t *& data = tensor.data; - - while (ftell(in_file.file) < in_stat.st_size) { - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(in_file.file, header), "Failed to read tensor header"); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file.file, header.key_length, name), "Failed to read tensor name"); - - const char * name_str = name.c_str(); - RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); - - data = header.data_type == TYPE_FP16 ? 
out_buf : in_buf; - size_t orig_size = header.size(), new_size = orig_size; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); - - // Quantize only 2D tensors, except embedding and head matrices. - // Embedding and head take not too much space, especially in bigger models; - // but they significantly increase perplexity when quantized. - if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") { - RWKV_MSG("quantizing... "); - - size_t nelements = (size_t) header.width * (size_t) header.height; - - if (header.data_type == TYPE_FP16) { - ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); - } - - int64_t hist_cur[16] {}; - new_size = ggml_quantize_chunk(out_type, (const float *) in_buf, out_buf, 0, nelements, hist_cur); - header.data_type = rwkv_type_from_ggml[out_type]; - data = out_buf; - - RWKV_MSG("size = %8.2f MB -> %8.2f MB | hist: ", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); - - for (int i = 0; i < 16; i++) { - RWKV_MSG("%5.3f ", hist_cur[i] / (float) nelements); - hist_all[i] += hist_cur[i]; - } - - RWKV_MSG("\n"); - } else { - RWKV_MSG("size = %8.3f MB\n", orig_size / 1024.0 / 1024.0); - } - - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file.file, tensor), "Failed to write tensor %s", name_str); - orig_total_size += orig_size; - new_total_size += new_size; - } - - RWKV_MSG("original size = %8.2f MB\n", orig_total_size / 1024.0 / 1024.0); - RWKV_MSG("quantized size = %8.2f MB\n", new_total_size / 1024.0 / 1024.0); - RWKV_MSG("compression ratio = %8.2f\n", orig_total_size / float(new_total_size)); - - int64_t sum_all = 0; - - for (int i = 0; i < 16; i++) { - sum_all += hist_all[i]; - } - - RWKV_MSG("hist: "); - - for (int i = 0; i < 16; ++i) { - printf("%5.3f ", hist_all[i] / float(sum_all)); - } +// API function. +void rwkv_set_print_errors(struct rwkv_context * ctx, const bool print_errors) { + bool * ptr = ctx ? &ctx->print_errors : &global_print_errors; + *ptr = print_errors; +} - RWKV_MSG("\n"); +// API function. +bool rwkv_get_print_errors(const struct rwkv_context * ctx) { + return ctx ? ctx->print_errors : global_print_errors; +} - return true; +// API function. +enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { + enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error; + enum rwkv_error_flags value = *ptr; + *ptr = RWKV_ERROR_NONE; + return value; } +#include "rwkv_quantize.inc" + +// API function. const char * rwkv_get_system_info_string(void) { static std::string s; diff --git a/rwkv.h b/rwkv.h index 87bca55..4b5ddee 100644 --- a/rwkv.h +++ b/rwkv.h @@ -5,9 +5,9 @@ #include #include -#ifdef RWKV_SHARED +#if defined(RWKV_SHARED) # if defined(_WIN32) && !defined(__MINGW32__) -# ifdef RWKV_BUILD +# if defined(RWKV_BUILD) # define RWKV_API __declspec(dllexport) # else # define RWKV_API __declspec(dllimport) @@ -29,7 +29,7 @@ // Default file version is the latest version. #define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX -#ifdef __cplusplus +#if defined(__cplusplus) extern "C" { #endif @@ -73,11 +73,11 @@ extern "C" { // If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors, // as well as the default for new context. // - print_errors: whether error messages should be automatically printed. 
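As a usage sketch of this switch (the `configure_error_printing` helper is made up for illustration): passing `NULL` adjusts the global default that applies before any context exists, while a non-`NULL` context gets its own setting:

```c++
#include "rwkv.h"

// Silence rwkv.cpp error printing globally (covers model load and
// quantization errors), then re-enable it for one specific context.
void configure_error_printing(struct rwkv_context * ctx) {
    rwkv_set_print_errors(NULL, false); // global default: quiet
    rwkv_set_print_errors(ctx, true);   // this context: verbose
}
```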
-    RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
+    RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, const bool print_errors);

     // Gets whether errors are automatically printed to stderr.
     // - ctx: the context to retrieve the setting for, or NULL for the global setting.
-    RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
+    RWKV_API bool rwkv_get_print_errors(const struct rwkv_context * ctx);

     // Retrieves and clears the error flags.
     // - ctx: the context to retrieve the error for, or NULL for the global error.
@@ -110,7 +110,13 @@ extern "C" {
     // - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
     // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
-    RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
+    RWKV_API bool rwkv_eval(
+        struct rwkv_context * ctx,
+        const uint32_t token,
+        const float * state_in,
+        float * state_out,
+        float * logits_out
+    );

     // Evaluates the model for a sequence of tokens.
     // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
@@ -135,7 +141,14 @@ extern "C" {
     // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
     // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
-    RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
+    RWKV_API bool rwkv_eval_sequence(
+        struct rwkv_context * ctx,
+        const uint32_t * tokens,
+        const size_t sequence_len,
+        const float * state_in,
+        float * state_out,
+        float * logits_out
+    );

     // Returns the number of tokens in the given model's vocabulary.
     // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
@@ -183,7 +196,7 @@ extern "C" {
     // Returns system information string.
     RWKV_API const char * rwkv_get_system_info_string(void);

-#ifdef __cplusplus
+#if defined(__cplusplus)
 }
 #endif

diff --git a/rwkv_error_handling.inc b/rwkv_error_handling.inc
new file mode 100644
index 0000000..f4868f1
--- /dev/null
+++ b/rwkv_error_handling.inc
@@ -0,0 +1,95 @@
+thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE;
+thread_local bool global_print_errors = true;
+
+inline static enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) {
+    return static_cast<enum rwkv_error_flags>(static_cast<int>(a) | static_cast<int>(b));
+}
+
+inline static enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) {
+    return a = a | b;
+}
+
+// Prints a message to stderr if error printing is enabled globally.
+#define RWKV_MSG(...) do { if (global_print_errors) fprintf(stderr, __VA_ARGS__); } while (0)
+
+// Prints a message to stderr if error printing is enabled in the context.
+#define RWKV_CTX_MSG(ctx, ...) do { if (ctx->print_errors) fprintf(stderr, __VA_ARGS__); } while (0)
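The `operator|`/`operator|=` overloads above exist because C++ will not OR two enumerators and assign the resulting `int` back into the enum without a cast. With them, the assert macros that follow can accumulate error categories directly into `global_last_error`. A small fragment for illustration (assumes it sits in the same translation unit as the overloads):

```c++
// Relies on the operator overloads defined above and the
// RWKV_ERROR_* enumerators from rwkv.h.
static enum rwkv_error_flags example_combine(void) {
    enum rwkv_error_flags err = RWKV_ERROR_NONE;
    err |= RWKV_ERROR_FILE;            // operator|= keeps the enum type
    return err | RWKV_ERROR_FILE_READ; // operator| likewise
}
```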
+// If the condition x is false, adds ERR_VAL to the last error, and returns RET_VAL.
+#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) do { \
+    if (!(x)) { \
+        global_last_error |= ERR_VAL; \
+        RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, adds ERR_VAL to the last error, prints a message to stderr, and returns RET_VAL.
+#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) do { \
+    if (!(x)) { \
+        global_last_error |= ERR_VAL; \
+        RWKV_MSG(__VA_ARGS__); \
+        RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL.
+#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) do { \
+    if (!(x)) { \
+        ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
+        RWKV_CTX_MSG(ctx, __VA_ARGS__); \
+        RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, adds ERR_VAL to the ctx's last error, and returns RET_VAL.
+#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) do { \
+    if (!(x)) { \
+        ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
+        RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, returns RET_VAL.
+#define RWKV_ENSURE(RET_VAL, x) do { \
+    if (!(x)) { \
+        RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, prints a message to stderr, and returns RET_VAL.
+#define RWKV_ENSURE_MSG(RET_VAL, x, ...) do { \
+    if (!(x)) { \
+        RWKV_MSG(__VA_ARGS__); \
+        RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+// If the condition x is false, adds ERR_VAL to the ctx's last error, prints a message to stderr, and returns RET_VAL.
+#define RWKV_CTX_ENSURE_MSG(ctx, ERR_VAL, RET_VAL, x, ...) do { \
+    if (!(x)) { \
+        ((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
+        RWKV_CTX_MSG(ctx, __VA_ARGS__); \
+        RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
+        RWKV_MAYBE_BREAK; \
+        return RET_VAL; \
+    } } while (0)
+
+#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__)
+#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__)
+
+#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__)
+
+#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x)
+#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x)
+
+#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x)
+
+#define RWKV_ENSURE_OR_FALSE(x) RWKV_ENSURE(false, x)
+#define RWKV_ENSURE_OR_NULL(x) RWKV_ENSURE(NULL, x)
+#define RWKV_ENSURE_OR_FALSE_MSG(x, ...) RWKV_ENSURE_MSG(false, x, __VA_ARGS__)
diff --git a/rwkv_eval.inc b/rwkv_eval.inc
new file mode 100644
index 0000000..38bbb33
--- /dev/null
+++ b/rwkv_eval.inc
@@ -0,0 +1,116 @@
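Before the file contents, here is the caller's view of the evaluation API these helpers serve, per the documentation in `rwkv.h` (the model path and token IDs below are placeholders):

```c++
#include <cstdio>
#include <vector>

#include "rwkv.h"

// Minimal serial-mode evaluation sketch; error checks omitted for brevity.
int main() {
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", /* n_threads = */ 4);

    std::vector<float> state(rwkv_get_state_len(ctx));
    std::vector<float> logits(rwkv_get_logits_len(ctx));

    // First pass: state_in = NULL tells rwkv_eval to start from the initial state.
    rwkv_eval(ctx, 510, NULL, state.data(), logits.data());

    // Subsequent passes feed the previous state back in.
    rwkv_eval(ctx, 3158, state.data(), state.data(), logits.data());

    std::printf("logit[0] = %f\n", logits[0]);

    rwkv_free(ctx);
    return 0;
}
```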
+// Copies state from an input buffer to the ggml tensor of the graph.
+static void rwkv_set_inputs(const struct rwkv_context * ctx, const struct rwkv_computation_graph & graph, const float * state_in) {
+    if (state_in) {
+        memcpy(graph.input_state->data, state_in, rwkv_tensor_nbytes(graph.input_state));
+    } else {
+        rwkv_init_state(ctx, (float *) graph.input_state->data);
+    }
+}
+
+// Copies state and logits from ggml tensors of the graph to output buffers.
+static void rwkv_get_outputs(const struct rwkv_computation_graph & graph, float * state_out, float * logits_out) {
+    if (state_out) {
+        memcpy(state_out, graph.output_state->data, rwkv_tensor_nbytes(graph.output_state));
+    }
+
+    if (logits_out) {
+        memcpy(logits_out, graph.logits->data, rwkv_tensor_nbytes(graph.logits));
+    }
+}
+
+// Evaluates a computation graph, optionally skipping logit computation.
+static void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_t n_threads, const bool compute_logits) {
+    if (!compute_logits) {
+        graph.cgraph->n_nodes = graph.pre_logits_nodes;
+        graph.cgraph->n_leafs = graph.pre_logits_leafs;
+    } else {
+        graph.cgraph->n_nodes = graph.post_logits_nodes;
+        graph.cgraph->n_leafs = graph.post_logits_leafs;
+    }
+
+    struct ggml_cplan * plan = ggml_graph_plan(graph.cgraph.get(), n_threads);
+
+    std::unique_ptr<uint8_t []> work_data{ new(std::nothrow) uint8_t[plan->work_size] };
+    plan->work_data = work_data.get();
+
+    ggml_graph_compute(graph.cgraph.get(), plan);
+
+    free(plan);
+}
+
+// API function.
+bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
+    ctx->last_error = RWKV_ERROR_NONE;
+
+    const struct rwkv_file_header & header = ctx->model->header;
+    const size_t n_vocab = header.n_vocab;
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 .. %zu)", token, n_vocab - 1);
+
+    rwkv_set_inputs(ctx, ctx->serial_graph, state_in);
+    ggml_set_i32(ctx->serial_graph.tokens, token);
+
+    rwkv_eval_graph(ctx->serial_graph, ctx->n_threads, logits_out != NULL);
+
+    rwkv_get_outputs(ctx->serial_graph, state_out, logits_out);
+
+    return true;
+}
+
+// API function.
+bool rwkv_eval_sequence(
+    struct rwkv_context * ctx,
+    const uint32_t * sequence,
+    const size_t sequence_len,
+    const float * state_in,
+    float * state_out,
+    float * logits_out
+) {
+    ctx->last_error = RWKV_ERROR_NONE;
+
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0");
+
+    if (sequence) {
+        const size_t n_vocab = ctx->model->header.n_vocab;
+
+        for (size_t i = 0; i < sequence_len; i++) {
+            const uint32_t token = sequence[i];
+
+            RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token at index %zu (%" PRId32 ") is out of range (0 .. %zu)", i, token, n_vocab - 1);
+        }
+    }
+
+    if (ctx->last_used_sequence_length != sequence_len) {
+        RWKV_ENSURE_OR_FALSE(rwkv_measure_and_build_sequential_context(*ctx->model, ctx->sequential_graph, sequence_len));
+
+        ctx->last_used_sequence_length = sequence_len;
+    }
+
+    if (sequence) {
+        rwkv_set_inputs(ctx, ctx->sequential_graph, state_in);
+        memcpy(ctx->sequential_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t));
+
+        rwkv_eval_graph(ctx->sequential_graph, ctx->n_threads, logits_out != NULL);
+
+        rwkv_get_outputs(ctx->sequential_graph, state_out, logits_out);
+    }
+
+    return true;
+}
+
+// API function.
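`rwkv_init_state` below zero-fills four of each layer's five state parts and sets `att_pp` to `-1e30`, a stand-in for negative infinity: the first `max(pp, ww)` in the WKV update then always picks the real exponent. An equivalent single-layer sketch (illustrative helper, not part of the API):

```c++
#include <cstddef>
#include <vector>

// Initial state for a single layer, mirroring rwkv_init_state below:
// ffn_xx, att_xx, att_aa, att_bb start at zero; att_pp starts at -1e30.
std::vector<float> initial_layer_state(size_t n_embed) {
    std::vector<float> state(n_embed * 5, 0.0F);

    for (size_t i = n_embed * 4; i < n_embed * 5; i++) {
        state[i] = -1e30F; // stand-in for -infinity
    }

    return state;
}
```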
+void rwkv_init_state(const struct rwkv_context * ctx, float * state) { + const struct rwkv_file_header & header = ctx->model->header; + const size_t layer_size = (size_t) header.n_embed * 5; + const size_t layer_zero = (size_t) header.n_embed * 4; + const size_t layers_size = (size_t) header.n_layer * layer_size; + + for (size_t start = 0; start < layers_size; start += layer_size) { + for (size_t i = 0; i < layer_zero; i++) { + state[start + i] = 0.0F; + } + + for (size_t i = layer_zero; i < layer_size; i++) { + state[start + i] = -1e30F; + } + } +} diff --git a/rwkv_file_format.inc b/rwkv_file_format.inc new file mode 100644 index 0000000..d9b9d4c --- /dev/null +++ b/rwkv_file_format.inc @@ -0,0 +1,223 @@ +// Data types + +#define TYPE_UNKNOWN TYPE_COUNT + +enum rwkv_type { + TYPE_FP32, + TYPE_FP16, + TYPE_Q4_0, + TYPE_Q4_1, + TYPE_Q4_1_O, // Unsupported + TYPE_Q4_2, // Unsupported + TYPE_Q4_3, // Unsupported + TYPE_Q5_0, + TYPE_Q5_1, + TYPE_Q8_0, + TYPE_COUNT +}; + +#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT + +static const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { + GGML_TYPE_F32, /* FP32 */ + GGML_TYPE_F16, /* FP16 */ + GGML_TYPE_Q4_0, /* Q4_0 */ + GGML_TYPE_Q4_1, /* Q4_1 */ + GGML_TYPE_UNKNOWN, /* Q4_1_O */ + GGML_TYPE_UNKNOWN, /* Q4_2 */ + GGML_TYPE_UNKNOWN, /* Q4_3 */ + GGML_TYPE_Q5_0, /* Q5_0 */ + GGML_TYPE_Q5_1, /* Q5_1 */ + GGML_TYPE_Q8_0, /* Q8_0 */ + GGML_TYPE_COUNT /* COUNT */ +}; + +static const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { + TYPE_FP32, /* FP32 */ + TYPE_FP16, /* FP16 */ + TYPE_Q4_0, /* Q4_0 */ + TYPE_Q4_1, /* Q4_1 */ + TYPE_Q4_2, /* Q4_2 */ + TYPE_Q4_3, /* Q4_3 */ + TYPE_Q5_0, /* Q5_0 */ + TYPE_Q5_1, /* Q5_1 */ + TYPE_Q8_0, /* Q8_0 */ + TYPE_COUNT, /* Q8_1 */ + TYPE_COUNT, /* I8 */ + TYPE_COUNT, /* I16 */ + TYPE_COUNT, /* I32 */ + TYPE_COUNT, /* COUNT */ +}; + +static const char * rwkv_type_to_string[TYPE_COUNT + 1] = { + "FP32", + "FP16", + "Q4_0", + "Q4_1", + "Q4_1_O", + "Q4_2", + "Q4_3", + "Q5_0", + "Q5_1", + "Q8_0", + "unknown" +}; + +static enum rwkv_type rwkv_type_from_string(const char * str) { + for (int i = 0; i < TYPE_COUNT; i++) { + if (strcmp(str, rwkv_type_to_string[i]) == 0) { + return (enum rwkv_type) i; + } + } + + return TYPE_UNKNOWN; +} + +// rwkv_file_header + +struct rwkv_file_header { + uint32_t magic; + uint32_t version; + uint32_t n_vocab; + uint32_t n_embed; + uint32_t n_layer; + uint32_t data_type; +}; + +static bool rwkv_is_file_version_in_range(const uint32_t version) { + return version >= RWKV_FILE_VERSION_MIN && version <= RWKV_FILE_VERSION_MAX; +} + +static bool rwkv_fread_file_header(FILE * file, struct rwkv_file_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_file_header), &header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_MAGIC, header.magic == RWKV_FILE_MAGIC); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_VERSION, rwkv_is_file_version_in_range(header.version), "Unsupported file version %" PRId32, header.version); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Model data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); + + enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + ggml_type != GGML_TYPE_UNKNOWN, + "Models in %s format cannot be loaded anymore because the format was removed.\n" + "You need to quantize the model into another format or use an older version of rwkv.cpp.\n" + "See 
https://github.com/saharNooby/rwkv.cpp#compatibility for more info", + rwkv_type_to_string[header.data_type] + ); + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + (!ggml_is_quantized(ggml_type) || header.version == RWKV_FILE_VERSION_1), + "The quantized model file in %s format was created with an old version of rwkv.cpp and can not be loaded anymore.\n" + "You need to requantize the model or use an older version of rwkv.cpp.\n" + "See https://github.com/saharNooby/rwkv.cpp#compatibility for more info", + rwkv_type_to_string[header.data_type] + ); + + return true; +} + +static bool rwkv_fwrite_file_header(FILE * file, const struct rwkv_file_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_file_header))); + + return true; +} + +// rwkv_tensor_header + +struct rwkv_tensor_header { + uint32_t dim_count; + uint32_t key_length; + uint32_t data_type; + uint32_t width; + uint32_t height; + + size_t size() const; +}; + +size_t rwkv_tensor_header::size() const { + return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height); +} + +static bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header)); + header.height = 1; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1); + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_DATA_TYPE, + rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN, + "Tensor data type (%s) is no longer supported", + rwkv_type_to_string[header.data_type] + ); + + if (header.dim_count == 2) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height)); + } + + return true; +} + +static bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) { + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? 
sizeof(uint32_t) : 0))); + + return true; +} + +static bool rwkv_fread_tensor_header_skip_name_and_data(FILE * file, struct rwkv_tensor_header & header) { + RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, header)); + + RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, fseek(file, header.key_length + header.size(), SEEK_CUR) == 0); + + return true; +} + +// rwkv_tensor + +struct rwkv_tensor { + struct rwkv_tensor_header header; + std::string name; + uint8_t * data; +}; + +static bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header)); + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name)); + RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, tensor.header.size())); + return true; +} + +// Reading ggml tensors + +static bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { + struct rwkv_tensor_header header; + RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); + + enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_UNSUPPORTED, + ggml_type != GGML_TYPE_UNKNOWN, + "Unsupported data type %s in parameter %s", + rwkv_type_to_string[header.data_type], + name.c_str() + ); + + tensor = header.dim_count == 1 + ? ggml_new_tensor_1d(ctx, ggml_type, header.width) + : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor != NULL, "Failed to allocate tensor"); + + ggml_set_name(tensor, name.c_str()); + + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_FILE_READ, + rwkv_fread_data(file, rwkv_tensor_nbytes(tensor), tensor->data), + "Failed to read data of parameter %s", + name.c_str() + ); + + return true; +} diff --git a/rwkv_gpu_offload.inc b/rwkv_gpu_offload.inc new file mode 100644 index 0000000..3564d78 --- /dev/null +++ b/rwkv_gpu_offload.inc @@ -0,0 +1,51 @@ +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + +#if defined(GGML_USE_CUBLAS) +# include "ggml/src/ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +# include "ggml/src/ggml-opencl.h" +#endif + +// API function. +bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { + const auto offload = [&](struct ggml_tensor * tensor) { + // TODO Support multi-GPU + tensor->backend = GGML_BACKEND_GPU; +#if defined(GGML_USE_CUBLAS) + ggml_cuda_transform_tensor(tensor->data, tensor); +#elif defined(GGML_USE_CLBLAST) + ggml_cl_transform_tensor(tensor->data, tensor); +#endif + }; + + const size_t n_gpu = std::min(n_layers, ctx->model->header.n_layer); + + if (ctx->model->offloaded_layer_count >= n_gpu) { + return false; + } + + for (size_t & i = ctx->model->offloaded_layer_count; i < n_gpu; i++) { + const struct rwkv_layer & layer = ctx->model->layers[i]; + + // TODO Also offload other supported operations to GPU + offload(layer.att_key); + offload(layer.att_value); + offload(layer.att_receptance); + offload(layer.att_output); + + offload(layer.ffn_key); + offload(layer.ffn_value); + offload(layer.ffn_receptance); + } + + return true; +} + +#else + +// API function. 
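Callers can treat the return value uniformly: `false` means no additional layers were offloaded, whether because the build lacks a GPU backend (the fallback below) or because the requested layers were already resident. A usage sketch (`try_offload` is a made-up helper):

```c++
#include <cstdio>

#include "rwkv.h"

// Request offload of up to 24 layers; report when nothing new moved to GPU.
void try_offload(struct rwkv_context * ctx) {
    if (!rwkv_gpu_offload_layers(ctx, 24)) {
        std::printf("No additional layers were offloaded.\n");
    }
}
```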
+bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { + return false; +} + +#endif diff --git a/rwkv_graph.inc b/rwkv_graph.inc new file mode 100644 index 0000000..e52c9d2 --- /dev/null +++ b/rwkv_graph.inc @@ -0,0 +1,463 @@ +// View tensors of a state of a single layer. +struct rwkv_layer_state { + struct ggml_tensor * ffn_xx; + struct ggml_tensor * att_xx; + struct ggml_tensor * att_aa; + struct ggml_tensor * att_bb; + struct ggml_tensor * att_pp; +}; + +// The computation graph holds ggml context and the ggml cgraph. +// It can be either a serial or a sequential graph. +struct rwkv_computation_graph { + struct ggml_context * ggml_ctx; + // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap. + std::unique_ptr<struct ggml_cgraph> cgraph; + + // Input tensors. + struct ggml_tensor * tokens; + struct ggml_tensor * input_state; + std::unique_ptr<struct rwkv_layer_state []> input_layers; + + // Output tensors. + struct ggml_tensor * output_state; + std::unique_ptr<struct rwkv_layer_state []> output_layers; + struct ggml_tensor * logits; + + // ggml graph counters before the graph was extended with logits tensor. + int pre_logits_nodes; + int pre_logits_leafs; + // ggml graph counters after the graph was extended with logits tensor. + int post_logits_nodes; + int post_logits_leafs; +}; + +// The context holds the model and both serial and sequential computation graphs. +struct rwkv_context { + struct rwkv_model * model; + + // The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode). + struct rwkv_computation_graph serial_graph; + // The sequence graph implements the "sequence mode" (or transformer/GPT mode) that processes multiple tokens at a time. + // This can be an order of magnitude or so faster than serial execution if used properly. 
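+ // In the sequence graph, the R/K/V projections are computed for the whole sequence at once;
+ // only the WKV recurrence remains a per-token loop (see rwkv_build_sequential_graph below).
+ // Both graphs are reached through the public API; a sketch (token values are illustrative):
+ //
+ //   rwkv_eval(ctx, token, state_in, state_out, logits);                        // serial graph
+ //   rwkv_eval_sequence(ctx, tokens, token_count, state_in, state_out, logits); // sequential graph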
+ struct rwkv_computation_graph sequential_graph; + size_t last_used_sequence_length; + + uint32_t n_threads; + + enum rwkv_error_flags last_error; + bool print_errors; +}; + +static void rwkv_carry_x( + struct ggml_context * ctx, + struct ggml_tensor * weight, + struct ggml_tensor * bias, + struct ggml_tensor *& x, + struct ggml_tensor *& x_prev, + struct ggml_tensor *& carry +) { + const size_t n_embed = x->ne[0]; + const size_t sequence_len = x->ne[1]; + + if (sequence_len == 1) { + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, weight, bias); + + // xx = state[5*i+0] + x_prev = carry; + + // state[5*i+0] = x + carry = x; + } else { + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); + + // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); + x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); + x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); + + // state[5*i+0] = x[-1,:] + carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + } +} + +static void rwkv_att_rkv( + struct ggml_context * ctx, + struct rwkv_layer layer, + struct ggml_tensor * x, + struct ggml_tensor * x_prev, + struct ggml_tensor *& r, + struct ggml_tensor *& k, + struct ggml_tensor *& v +) { + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace(ctx, + ggml_mul(ctx, x, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ); + + // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + struct ggml_tensor * xv = ggml_add_inplace(ctx, + ggml_mul(ctx, x, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ); + + // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace(ctx, + ggml_mul(ctx, x, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + // k = kw @ xk + k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + v = ggml_mul_mat(ctx, layer.att_value, xv); +} + +static struct ggml_tensor * rwkv_att_wkv( + struct ggml_context * ctx, + struct ggml_tensor * att_time_first, + struct ggml_tensor * att_time_decay, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor *& aa, + struct ggml_tensor *& bb, + struct ggml_tensor *& pp +) { + // ww = time_first + k + struct ggml_tensor * ww = ggml_add(ctx, att_time_first, k); + // qq = torch.maximum(pp, ww) + struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); + // e1 = torch.exp(pp - qq) + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); + // e2 = torch.exp(ww - qq) + struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + + // a = e1 * aa + e2 * v + struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + // b = e1 * bb + e2 + struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + + // ww = pp + time_decay + ww = ggml_add(ctx, pp, att_time_decay); + // qq = torch.maximum(ww, k) + qq = rwkv_max(ctx, ww, k); + // e1 = torch.exp(ww - qq) + e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + // e2 = torch.exp(k[t] - qq) + e2 = rwkv_exp(ctx, 
ggml_sub(ctx, k, qq)); + + // state[5 * i + 2] = e1 * aa + e2 * v + // state[5 * i + 3] = e1 * bb + e2 + // state[5 * i + 4] = qq + aa = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + bb = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + pp = qq; + + // wkv = a / b + return ggml_div(ctx, a, b); +} + +static struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); + + struct ggml_tensor * r, * k, * v; + rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); + + struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); + + // ow @ (r * wkv) + return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); +} + +static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); + + // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ); + + // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); +} + +static void rwkv_create_input_and_output_views( + struct ggml_context * ctx, + struct rwkv_layer_state * inputs, + struct rwkv_layer_state * outputs, + struct ggml_tensor * input, + struct ggml_tensor * output, + const size_t n_layer, + const size_t n_embed +) { + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_layer_state & input_state = inputs[i]; + input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + input_state.att_aa = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + + struct rwkv_layer_state & output_state = outputs[i]; + output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + } +} + +// Serial graph (token-by-token eval) + +// Creates and sets the input and output ggml tensors, builds the computation graph.
+static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) { + graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + + struct rwkv_file_header & header = model.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; + + struct ggml_context * ctx = graph.ggml_ctx; + + // Creates a 1-element tensor. + graph.tokens = ggml_new_i32(ctx, 0); + + struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr<struct rwkv_layer_state []> inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); + + // We collect parts of output state here. Each part is (n_embed) vector. + std::unique_ptr<struct rwkv_layer_state []> outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); + + rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed); + + graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); + + // x = self.w.emb.weight[token] + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); + + // x = self.layer_norm(x, self.w.blocks[0].ln0) + x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); + + for (size_t i = 0; i < model.header.n_layer; i++) { + struct rwkv_layer & layer = model.layers[i]; + + struct rwkv_layer_state state = inputs[i]; + x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + struct rwkv_layer_state & output_state = outputs[i]; + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + } + + graph.pre_logits_nodes = graph.cgraph->n_nodes; + graph.pre_logits_leafs = graph.cgraph->n_leafs; + + // x = self.layer_norm(x[-1,:], self.w.ln_out) + x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); + + // x = (self.w.head.weight @ x).float() + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); + + graph.post_logits_nodes = graph.cgraph->n_nodes; + graph.post_logits_leafs = graph.cgraph->n_leafs; + + graph.input_state = input; + graph.input_layers = std::move(inputs); + + graph.output_state = output; + graph.output_layers = std::move(outputs); + + return true; +} + +// Copy-pasted from llama.cpp. +static const size_t tensor_alignment = 32; + +// Prepares the computation graph for inference, measuring and allocating all input and output tensors. +static bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, struct rwkv_computation_graph & graph) { + if (graph.ggml_ctx) { + ggml_free(graph.ggml_ctx); + + graph.ggml_ctx = NULL; + } + + // 1. Measure the space required for the ggml context. 
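+ // no_alloc = true: tensors are created with metadata only, so the graph can be built once
+ // just to let ggml_allocr measure how much memory the real context will need.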
+ graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); + + RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); + + struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); + + size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + + rwkv_ggml_overhead() + + tensor_alignment + // For some reason, calculation above does not result in enough memory allocated. + // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. + // 64 MB seems to be enough for Raven 14B model. + + size_t(64) * 1024 * 1024; + + ggml_allocr_free(allocator); + ggml_free(graph.ggml_ctx); + + // 2. Create the real ggml context. + graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); + + RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); + + return true; +} + +// Sequential graph + +// Creates and sets the input and output ggml tensors, builds the computation graph. +static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { + graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + + struct rwkv_file_header & header = model.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; + + struct ggml_context * ctx = graph.ggml_ctx; + + graph.tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sequence_length); + + struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr<struct rwkv_layer_state []> inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); + + // We collect parts of output state here. Each part is (n_embed) vector. 
+ std::unique_ptr<struct rwkv_layer_state []> outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); + + rwkv_create_input_and_output_views(ctx, inputs.get(), outputs.get(), input, output, n_layer, n_embed); + + graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); + + // x = self.w.emb.weight[token] + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); + + // x = self.layer_norm(x, self.w.blocks[0].ln0) + x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); + + for (size_t i = 0; i < model.header.n_layer; i++) { + struct rwkv_layer & layer = model.layers[i]; + struct rwkv_layer_state state = inputs[i]; + + struct ggml_tensor * x0 = x, * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); + + struct ggml_tensor * r, * k, * v; + rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); + + ggml_build_forward_expand(graph.cgraph.get(), r); + + for (uint32_t t = 0; t < sequence_length; t++) { + struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, wkv, xt)); + } + + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + struct rwkv_layer_state & output_state = outputs[i]; + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); + } + + graph.pre_logits_nodes = graph.cgraph->n_nodes; + graph.pre_logits_leafs = graph.cgraph->n_leafs; + + // x = self.layer_norm(x[-1,:], self.w.ln_out) + x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_length - 1)), model.ln_out_weight, model.ln_out_bias); + + // x = (self.w.head.weight @ x).float() + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); + + graph.post_logits_nodes = graph.cgraph->n_nodes; + graph.post_logits_leafs = graph.cgraph->n_leafs; + + graph.input_state = input; + graph.input_layers = std::move(inputs); + + graph.output_state = output; + graph.output_layers = std::move(outputs); + + return true; +} + +// Prepares the computation graph for inference, measuring and allocating all input and output tensors. +static bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { + if (graph.ggml_ctx) { + ggml_free(graph.ggml_ctx); + + graph.ggml_ctx = NULL; + } + + // 1. Measure the space required for the ggml context. 
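+ // Same measure-then-allocate scheme as for the serial graph; here the required size also grows with sequence_length.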
+ graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); + + RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); + + struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); + + size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + + rwkv_ggml_overhead() + + tensor_alignment + // For some reason, calculation above does not result in enough memory allocated. + // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. + // 64 MB per token seems to be enough for Raven 14B model. It works for sequence_length up to 71 at least. + + sequence_length * 64 * 1024 * 1024; + + ggml_allocr_free(allocator); + ggml_free(graph.ggml_ctx); + + // 2. Create the real ggml context. + graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); + + RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); + + return true; +} diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc new file mode 100644 index 0000000..e19af58 --- /dev/null +++ b/rwkv_model_loading.inc @@ -0,0 +1,158 @@ +struct rwkv_layer { + struct ggml_tensor * ln1_weight; + struct ggml_tensor * ln1_bias; + + // RWKV, also called "attention" by the author. + struct ggml_tensor * att_time_mix_k; + struct ggml_tensor * att_time_mix_v; + struct ggml_tensor * att_time_mix_r; + struct ggml_tensor * att_time_first; + struct ggml_tensor * att_time_decay; + struct ggml_tensor * att_key; + struct ggml_tensor * att_value; + struct ggml_tensor * att_receptance; + struct ggml_tensor * att_output; + + struct ggml_tensor * ln2_weight; + struct ggml_tensor * ln2_bias; + + // FFN. + struct ggml_tensor * ffn_time_mix_k; + struct ggml_tensor * ffn_time_mix_r; + struct ggml_tensor * ffn_key; + struct ggml_tensor * ffn_value; + struct ggml_tensor * ffn_receptance; +}; + +// The model holds all parameter tensors and the ggml context containing them. +// Each tensor has data and can be used in computations happening in other contexts. +struct rwkv_model { + // This context holds all parameter tensors. + // It must not be used for computations. + struct ggml_context * ggml_ctx; + + struct rwkv_file_header header; + + struct ggml_tensor * emb; + + struct ggml_tensor * ln0_weight; + struct ggml_tensor * ln0_bias; + + std::unique_ptr<struct rwkv_layer []> layers; + + struct ggml_tensor * ln_out_weight; + struct ggml_tensor * ln_out_bias; + + struct ggml_tensor * head; + + // How many layers were offloaded to the GPU. + size_t offloaded_layer_count; + + // How many RWKV contexts reference this model. 
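+ // The model is shared between cloned contexts and is freed only when the last context referencing it is freed.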
+ int reference_count; +}; + +struct rwkv_file { + FILE * file; + + rwkv_file(FILE * file): file(file) {} + + ~rwkv_file() { + if (file) { + fclose(file); + } + } +}; + +// https://stackoverflow.com/a/6458689 +template <typename F> +static bool rwkv_set_params(struct rwkv_model & model, F callback) { + RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); + + uint32_t n_layer = model.header.n_layer; + std::unique_ptr<struct rwkv_layer []> layers(new(std::nothrow) struct rwkv_layer[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); + model.layers = std::move(layers); + + for (uint32_t i = 0; i < n_layer; i++) { + char buffer[128]; + size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); + + rwkv_layer & layer = model.layers[i]; + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + } + + RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); + RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); + RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); + + return true; +} + +// Creates a ggml context and loads all parameter tensors from a model file. +static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model & model) { + struct stat file_stat; + + std::unordered_map<std::string, struct ggml_tensor *> parameters; + + rwkv_file file(fopen(file_path, "rb")); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path); + // Be very careful when changing this code. 
It must support files larger than 2 GB by using 64-bit functions to get the file length. + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header"); + + model.ggml_ctx = rwkv_init_ggml_context( + // ggml tensors must be aligned; assuming here that overhead of parameter headers, included in the file size, will account for that. + file_stat.st_size + rwkv_ggml_overhead(), + false + ); + + std::string name; + + struct ggml_tensor * tensor; + + while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, model.ggml_ctx, name, tensor), "Failed to read a model parameter"); + + parameters[std::move(name)] = tensor; + } + + std::unordered_map<std::string, struct ggml_tensor *> & parameters_ref = parameters; + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { + struct ggml_tensor * tensor = parameters_ref[key]; + RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); + dest = tensor; + return true; + })); + + // Verify order of dimensions. + struct ggml_tensor * emb = model.emb; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); + + return true; +} diff --git a/rwkv_operators.inc b/rwkv_operators.inc new file mode 100644 index 0000000..c24c91d --- /dev/null +++ b/rwkv_operators.inc @@ -0,0 +1,112 @@ +static void rwkv_validate_tensors_for_custom_unary_op(struct ggml_tensor * dest, const struct ggml_tensor * src) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_are_same_shape(src, dest)); + // Verify that the shape is 2D. 
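+ // ggml sets ne[] entries of unused dimensions to 1, so checking ne[2] and ne[3] is sufficient.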
+ GGML_ASSERT(dest->ne[2] == 1); + GGML_ASSERT(dest->ne[3] == 1); +} + +#define SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP() { (void) ith; (void) nth; (void) userdata; } + +static void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + rwkv_validate_tensors_for_custom_unary_op(dest, src); + + int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; + + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = expf(src_data[i]); + } + + SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); +} + +static void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + rwkv_validate_tensors_for_custom_unary_op(dest, src); + + int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; + + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = 1.0F - src_data[i]; + } + + SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); +} + +static void rwkv_sigmoid_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + rwkv_validate_tensors_for_custom_unary_op(dest, src); + + int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; + + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = 1.0F / (1.0F + expf(-src_data[i])); + } + + SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); +} + +static void rwkv_max_impl( + struct ggml_tensor * dest, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + int ith, + int nth, + void * userdata +) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_are_same_shape(src0, dest)); + GGML_ASSERT(ggml_are_same_shape(src1, dest)); + // Verify that the shape is 2D. 
+ GGML_ASSERT(dest->ne[2] == 1); + GGML_ASSERT(dest->ne[3] == 1); + + int64_t element_count = src0->ne[0] * src0->ne[1]; + float * src0_data = (float *) src0->data; + float * src1_data = (float *) src1->data; + float * dest_data = (float *) dest->data; + + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = fmaxf(src0_data[i], src1_data[i]); + } + + SUPPRESS_UNUSED_WARNINGS_IN_CUSTOM_OP(); +} + +// Element-wise exp(x) +struct ggml_tensor * rwkv_exp(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL); +} + +// Element-wise 1 - x +struct ggml_tensor * rwkv_1_minus_x(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_1_minus_x_impl, 1, NULL); +} + +// Element-wise sigmoid(x) +struct ggml_tensor * rwkv_sigmoid(struct ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_sigmoid_impl, 1, NULL); +} + +// Element-wise max(x, y) +struct ggml_tensor * rwkv_max(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) { + return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL); +} + +struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { + // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` + // Looks like ggml_norm does the first part, we only need to apply weight & bias. + return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias); +} diff --git a/rwkv_quantize.inc b/rwkv_quantize.inc new file mode 100644 index 0000000..93ab098 --- /dev/null +++ b/rwkv_quantize.inc @@ -0,0 +1,171 @@ +// API function. +bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { + global_last_error = RWKV_ERROR_NONE; + + enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)]; + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, + ggml_is_quantized(out_type), + "Unsupported output data type (%s)", + rwkv_type_to_string[rwkv_type_from_ggml[out_type]] + ); + + RWKV_MSG("Loading model from '%s'\n", in_path); + + struct stat in_stat; + + struct rwkv_file in_file(fopen(in_path, "rb")); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, in_file.file, "Failed to open %s for reading", in_path); + + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. 
+ RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(in_file.file), &in_stat) == 0, "Failed to stat file %s", in_path); + + struct rwkv_file out_file(fopen(out_path, "wb")); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, out_file.file, "Failed to open %s for writing", out_path); + + struct rwkv_file_header in_header; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(in_file.file, in_header), "Invalid file header"); + + enum ggml_type in_type = rwkv_type_to_ggml[in_header.data_type]; + RWKV_ASSERT_FALSE_MSG( + RWKV_ERROR_FILE, + in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, + "Unsupported input data type (%s); needs to be FP32 or FP16", + rwkv_type_to_string[rwkv_type_from_ggml[in_type]] + ); + + struct rwkv_file_header out_header = in_header; + out_header.version = RWKV_FILE_VERSION; + out_header.data_type = rwkv_type_from_ggml[out_type]; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fwrite_file_header(out_file.file, out_header), "Failed to write file header"); + + // Process parameters. + size_t orig_total_size = 0; + size_t new_total_size = 0; + + // Required to init the F16 tables. + // Doesn't crash if ggml_init fails. + ggml_free(ggml_init({ 0, NULL, true })); + + size_t max_in_size = 0; + size_t max_out_size = 0; + size_t max_key_length = 0; + + while (ftell(in_file.file) < in_stat.st_size) { + struct rwkv_tensor_header header; + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_skip_name_and_data(in_file.file, header)); + + size_t in_size = header.size(); + + if (in_size > max_in_size) { + max_in_size = in_size; + } + + if (header.data_type == TYPE_FP16) { + if (in_size > max_out_size) { + max_out_size = in_size; + } + + size_t f32_size = rwkv_tensor_nbytes(GGML_TYPE_F32, header.width, header.height); + + if (f32_size > max_in_size) { + max_in_size = f32_size; + } + } + + size_t out_size = rwkv_tensor_nbytes(out_type, header.width, header.height); + + if (out_size > max_out_size) { + max_out_size = out_size; + } + + if (header.key_length > max_key_length) { + max_key_length = header.key_length; + } + } + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(in_file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); + + // This is a histogram of quantized values. If it shows a single 1.0 followed by all 0.0, something went very wrong! + int64_t hist_all[16] {}; + + std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); + + uint8_t * in_buf = scratch.get(); + uint8_t * out_buf = in_buf + max_in_size; + + struct rwkv_tensor tensor; + struct rwkv_tensor_header & header = tensor.header; + std::string & name = tensor.name; + uint8_t *& data = tensor.data; + + while (ftell(in_file.file) < in_stat.st_size) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(in_file.file, header), "Failed to read tensor header"); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(in_file.file, header.key_length, name), "Failed to read tensor name"); + + const char * name_str = name.c_str(); + RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); + + data = header.data_type == TYPE_FP16 ? 
out_buf : in_buf; + size_t orig_size = header.size(), new_size = orig_size; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); + + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take not too much space, especially in bigger models; + // but they significantly increase perplexity when quantized. + if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) && + header.dim_count == 2 && + name != "emb.weight" && + name != "head.weight" + ) { + RWKV_MSG("quantizing... "); + + size_t nelements = (size_t) header.width * (size_t) header.height; + + if (header.data_type == TYPE_FP16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); + } + + int64_t hist_cur[16] {}; + new_size = ggml_quantize_chunk(out_type, (const float *) in_buf, out_buf, 0, nelements, hist_cur); + header.data_type = rwkv_type_from_ggml[out_type]; + data = out_buf; + + RWKV_MSG("size = %8.2f MB -> %8.2f MB | hist: ", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + + for (int i = 0; i < 16; i++) { + RWKV_MSG("%5.3f ", hist_cur[i] / (float) nelements); + hist_all[i] += hist_cur[i]; + } + + RWKV_MSG("\n"); + } else { + RWKV_MSG("size = %8.3f MB\n", orig_size / 1024.0 / 1024.0); + } + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_tensor(out_file.file, tensor), "Failed to write tensor %s", name_str); + orig_total_size += orig_size; + new_total_size += new_size; + } + + RWKV_MSG("original size = %8.2f MB\n", orig_total_size / 1024.0 / 1024.0); + RWKV_MSG("quantized size = %8.2f MB\n", new_total_size / 1024.0 / 1024.0); + RWKV_MSG("compression ratio = %8.2f\n", orig_total_size / float(new_total_size)); + + int64_t sum_all = 0; + + for (int i = 0; i < 16; i++) { + sum_all += hist_all[i]; + } + + RWKV_MSG("hist: "); + + for (int i = 0; i < 16; ++i) { + printf("%5.3f ", hist_all[i] / float(sum_all)); + } + + RWKV_MSG("\n"); + + return true; +} diff --git a/rwkv_utilities.inc b/rwkv_utilities.inc new file mode 100644 index 0000000..9b10b22 --- /dev/null +++ b/rwkv_utilities.inc @@ -0,0 +1,52 @@ +static size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t width, const int64_t height) { + return (ggml_type_size(type) * width * height) / ggml_blck_size(type); +} + +// For some reason, ggml_nbytes calculates the size in a way +// incompatible with rwkv.cpp; we need our own function for that. +static size_t rwkv_tensor_nbytes(const struct ggml_tensor * tensor) { + return rwkv_tensor_nbytes(tensor->type, tensor->ne[0], tensor->ne[1]); +} + +// Minimum amount of memory required for a ggml context, not counting the tensor data. +static size_t rwkv_ggml_overhead() { + return ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead(); +} + +static struct ggml_context * rwkv_init_ggml_context(const size_t memory_size, const bool no_alloc) { + struct ggml_init_params init_params = { + memory_size, + NULL, + no_alloc + }; + + return ggml_init(init_params); +} + +// IO utilities + +// Reads a single uint32 value from a file. +static bool rwkv_fread_uint32(FILE * file, uint32_t & dest) { + return fread((void *) &dest, sizeof(uint32_t), 1, file) == 1; +} + +// Reads a single string value from a file. +static bool rwkv_fread_string(FILE * file, const size_t length, std::string & dest) { + dest.resize(length); + return fread((void *) dest.data(), length, 1, file) == 1; +} + +// Reads a single data buffer from a file. 
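+// Returns false on a short read; callers wrap these helpers in RWKV_ASSERT_* macros to attach error flags.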
+static bool rwkv_fread_data(FILE * file, const size_t length, void * dest) { + return fread(dest, length, 1, file) == 1; +} + +// Writes a single string value to a file. +static bool rwkv_fwrite_string(FILE * file, const std::string & value) { + return fwrite((const void *) value.data(), value.length(), 1, file) == 1; +} + +// Writes a single data buffer to a file. +static bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { + return fwrite(data, length, 1, file) == 1; +} diff --git a/tests/assertions.inc b/tests/assertions.inc new file mode 100644 index 0000000..df5ba64 --- /dev/null +++ b/tests/assertions.inc @@ -0,0 +1,27 @@ +#ifndef ASSERTIONS_INC +#define ASSERTIONS_INC + +#include <stdio.h> + +#define ASSERT(x, ...) {\ + if (!(x)) {\ + fprintf(stderr, "*** Assertion failed ***\n");\ + fprintf(stderr, __VA_ARGS__);\ + fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ + abort();\ + }\ + } + +#define ASSERT_ELEMENT_F32(tensor, i, expected_value) {\ + float actual = ((float *) tensor->data)[i];\ + ASSERT(\ + fabsf(actual - expected_value) <= 0.0000001F,\ + "At %s[%d]: expected %f, actual %f",\ + #tensor,\ + i,\ + (double) expected_value,\ + (double) actual\ + );\ + } + +#endif diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc index d449004..9ff0591 100644 --- a/tests/logit_difference_validator.inc +++ b/tests/logit_difference_validator.inc @@ -1,16 +1,16 @@ -// TODO Move to inc -#define ASSERT(x, ...) {\ - if (!(x)) {\ - fprintf(stderr, "*** Assertion failed ***\n");\ - fprintf(stderr, __VA_ARGS__);\ - fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ - }\ - } +#ifndef LOGIT_DIFFERENCE_VALIDATOR_INC +#define LOGIT_DIFFERENCE_VALIDATOR_INC + +#include <stdio.h> +#include <stdlib.h> + +#include <rwkv.h> -// RWKV Tiny is a byte-level model +#include "assertions.inc" + +// RWKV Tiny is a byte-level model. #define N_VOCAB 256 -// Also test multithreading +// Also test multithreading. 
#define N_THREADS 2 void load_expected_logits(float * expected_logits) { @@ -28,7 +28,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl enum rwkv_error_flags error = rwkv_get_last_error(NULL); ASSERT(error == 0, "Unexpected error %d", error); -#ifdef GGML_USE_CUBLAS +#if defined(GGML_USE_CUBLAS) ASSERT(rwkv_gpu_offload_layers(model, rwkv_get_n_layer(model)), "Failed to offload layers to GPU"); #endif @@ -36,17 +36,21 @@ void test_model(const char * model_path, const float * expected_logits, const fl ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model"); - float * state = malloc(sizeof(float) * rwkv_get_state_len(model)); - float * logits = malloc(sizeof(float) * n_vocab); + float * state = calloc(rwkv_get_state_len(model), sizeof(float)); + float * logits = calloc(n_vocab, sizeof(float)); - char * prompt = "\"in"; - uint32_t prompt_seq[] = { '"', 'i', 'n' }; + ASSERT(state != NULL, "Failed to allocate state"); + ASSERT(logits != NULL, "Failed to allocate logits"); + const char * prompt = "\"in"; + const uint32_t prompt_seq[] = { '"', 'i', 'n' }; const size_t prompt_length = strlen(prompt); + // --- + rwkv_init_state(model, state); - for (size_t i = 0; i < prompt_length; i++) { + for (size_t i = 0; prompt[i] != 0; i++) { rwkv_eval(model, prompt[i], state, state, logits); } @@ -56,10 +60,12 @@ void test_model(const char * model_path, const float * expected_logits, const fl diff_sum += logits[i] - expected_logits[i]; } - fprintf(stderr, "Difference sum: %f\n", diff_sum); + fprintf(stderr, "Serial difference sum: %f\n", diff_sum); // When something breaks, difference would be way more than 10 - ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big serial difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + + // --- rwkv_init_state(model, state); rwkv_eval_sequence(model, prompt_seq, prompt_length, state, state, logits); @@ -75,8 +81,12 @@ void test_model(const char * model_path, const float * expected_logits, const fl // When something breaks, difference would be way more than 10 ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + // --- + rwkv_free(model); free(state); free(logits); } + +#endif diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index e911f98..9087fca 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -1,72 +1,57 @@ // Tests that evaluation works after the context was cloned. 
-#include - #include #include #include -int main() { - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); +#include + +#include "assertions.inc" + +int main(void) { + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); - if (!ctx) { - enum rwkv_error_flags error = rwkv_get_last_error(NULL); - fprintf(stderr, "Unexpected error 0x%.8X\n", error); - return EXIT_FAILURE; - } + ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); - float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); - float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); - if (!state || !logits) { - fprintf(stderr, "Failed to allocate state or logits\n"); - return EXIT_FAILURE; - } + ASSERT(state != NULL, "Failed to allocate state"); + ASSERT(logits != NULL, "Failed to allocate logits"); - const unsigned char prompt[12] = "hello world"; + const uint8_t prompt[12] = "hello world"; - rwkv_eval(ctx, prompt[0], NULL, state, logits); + rwkv_eval(ctx, prompt[0], NULL, state, logits); - for (int i = 1; prompt[i] != 0; i++) { - rwkv_eval(ctx, prompt[i], state, state, logits); - } + for (size_t i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx, prompt[i], state, state, logits); + } - float * expected_logits = logits; + float * expected_logits = logits; - logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); - if (!logits) { - fprintf(stderr, "Failed to allocate logits\n"); - return EXIT_FAILURE; - } + ASSERT(logits != NULL, "Failed to allocate logits"); - struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); + struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); - if (ctx == ctx2) { - fprintf(stderr, "Same context was returned\n"); - return EXIT_FAILURE; - } + ASSERT(ctx != ctx2, "Same context was returned"); // The cloned context should work fine after the original context was freed. - rwkv_free(ctx); + rwkv_free(ctx); - rwkv_eval(ctx2, prompt[0], NULL, state, logits); + rwkv_eval(ctx2, prompt[0], NULL, state, logits); - for (int i = 1; prompt[i] != 0; i++) { - rwkv_eval(ctx2, prompt[i], state, state, logits); - } + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx2, prompt[i], state, state, logits); + } - if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx2) * sizeof(float))) { - fprintf(stderr, "Results are not identical :(\n"); - return EXIT_FAILURE; - } else { - fprintf(stdout, "Results are identical, success!\n"); - } + ASSERT(memcmp(expected_logits, logits, rwkv_get_logits_len(ctx2) * sizeof(float)) == 0, "Results are not identical"); - rwkv_free(ctx2); + rwkv_free(ctx2); - free(expected_logits); - free(logits); - free(state); + free(expected_logits); + free(logits); + free(state); - return EXIT_SUCCESS; + return 0; } diff --git a/tests/test_ggml_basics.c b/tests/test_ggml_basics.c index d99cab2..e767387 100644 --- a/tests/test_ggml_basics.c +++ b/tests/test_ggml_basics.c @@ -1,29 +1,15 @@ // Tests that ggml basics work. -#include - -#include #include +#include #include +#include + +#include "assertions.inc" + #define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value -// TODO Move to inc -#define ASSERT(x, ...) 
{\ - if (!(x)) {\ - fprintf(stderr, "*** Assertion failed ***\n");\ - fprintf(stderr, __VA_ARGS__);\ - fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ - }\ - } - -#define ASSERT_ELEMENT_F32(tensor, i, expected_value) {\ - float actual = ((float *) tensor->data)[i];\ - ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, (double) expected_value, (double) actual);\ - } - -// Tests simple computation in a single context. -static void test_computation(void) { +void test_simple_computation(void) { struct ggml_init_params params = { .mem_size = 16 * 1024, .mem_buffer = NULL, @@ -64,9 +50,8 @@ static void test_computation(void) { ggml_free(ctx); } -// Tests that operations on tensors from different contexts work. // RWKV model loading code depends on this behavior. -static void test_tensors_from_different_contexts(void) { +void test_computation_on_tensors_from_different_contexts(void) { struct ggml_init_params params = { .mem_size = 16 * 1024, .mem_buffer = NULL, @@ -104,9 +89,9 @@ static void test_tensors_from_different_contexts(void) { } int main(void) { - test_computation(); + test_simple_computation(); - test_tensors_from_different_contexts(); + test_computation_on_tensors_from_different_contexts(); return 0; } diff --git a/tests/test_logit_calculation_skipping.c b/tests/test_logit_calculation_skipping.c index cb7239d..2765a19 100644 --- a/tests/test_logit_calculation_skipping.c +++ b/tests/test_logit_calculation_skipping.c @@ -1,132 +1,98 @@ // Tests that evaluation works when the logits parameter was set to NULL. -#include - #include #include #include +#include + +#include "assertions.inc" + #define TOKEN_COUNT 11 -static const unsigned char prompt[TOKEN_COUNT + 1] = "hello world"; +const char prompt[TOKEN_COUNT + 1] = "hello world"; -static int test_serial_mode() { - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); +void test_serial_mode(void) { + fprintf(stderr, "Testing serial mode\n"); - if (!ctx) { - enum rwkv_error_flags error = rwkv_get_last_error(NULL); - fprintf(stderr, "Unexpected error 0x%.8X\n", error); - return EXIT_FAILURE; - } + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); - float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); - float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); - if (!state || !logits) { - fprintf(stderr, "Failed to allocate state or logits\n"); - return EXIT_FAILURE; - } + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); - rwkv_eval(ctx, prompt[0], NULL, state, logits); + ASSERT(state != NULL, "Failed to allocate state"); + ASSERT(logits != NULL, "Failed to allocate logits"); - for (int i = 1; prompt[i] != 0; i++) { - rwkv_eval(ctx, prompt[i], state, state, logits); - } + rwkv_eval(ctx, prompt[0], NULL, state, logits); - float * expected_state = state; + for (size_t i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx, prompt[i], state, state, logits); + } - state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * expected_state = state; - if (!state) { - fprintf(stderr, "Failed to allocate state\n"); - return EXIT_FAILURE; - } + state = calloc(rwkv_get_state_len(ctx), sizeof(float)); - rwkv_eval(ctx, prompt[0], NULL, state, NULL); + ASSERT(state != NULL, "Failed to allocate state"); - for (int i = 1; prompt[i] != 0; i++) { - rwkv_eval(ctx, 
prompt[i], state, state, NULL); - } + rwkv_eval(ctx, prompt[0], NULL, state, NULL); - if (memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float))) { - fprintf(stderr, "Serial mode: results are not identical :(\n"); - return EXIT_FAILURE; - } else { - fprintf(stdout, "Serial mode: results are identical, success!\n"); - } + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx, prompt[i], state, state, NULL); + } - rwkv_free(ctx); + ASSERT(memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float)) == 0, "Results are not identical"); - free(logits); - free(state); - free(expected_state); + rwkv_free(ctx); - return EXIT_SUCCESS; + free(logits); + free(state); + free(expected_state); } -static int test_sequential_mode() { - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); +void test_sequential_mode(void) { + fprintf(stderr, "Testing sequential mode\n"); - if (!ctx) { - enum rwkv_error_flags error = rwkv_get_last_error(NULL); - fprintf(stderr, "Unexpected error 0x%.8X\n", error); - return EXIT_FAILURE; - } + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); - float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); - float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL)); - if (!state || !logits) { - fprintf(stderr, "Failed to allocate state or logits\n"); - return EXIT_FAILURE; - } + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); - uint32_t prompt_tokens[TOKEN_COUNT]; + ASSERT(state != NULL, "Failed to allocate state"); + ASSERT(logits != NULL, "Failed to allocate logits"); - for (int i = 0; i < TOKEN_COUNT; i++) { - prompt_tokens[i] = prompt[i]; - } + uint32_t prompt_tokens[TOKEN_COUNT]; - rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, logits); + for (int i = 0; i < TOKEN_COUNT; i++) { + prompt_tokens[i] = prompt[i]; + } - float * expected_state = state; + rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, logits); - state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * expected_state = state; - if (!state) { - fprintf(stderr, "Failed to allocate state\n"); - return EXIT_FAILURE; - } + state = calloc(rwkv_get_state_len(ctx), sizeof(float)); - rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, NULL); + ASSERT(state != NULL, "Failed to allocate state"); - if (memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float))) { - fprintf(stderr, "Sequential mode: results are not identical :(\n"); - return EXIT_FAILURE; - } else { - fprintf(stdout, "Sequential mode: results are identical, success!\n"); - } + rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, NULL); - rwkv_free(ctx); + ASSERT(memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float)) == 0, "Results are not identical"); - free(logits); - free(state); - free(expected_state); + rwkv_free(ctx); - return EXIT_SUCCESS; + free(logits); + free(state); + free(expected_state); } -int main() { - int result = test_serial_mode(); - - if (result != EXIT_SUCCESS) { - return result; - } - - result = test_sequential_mode(); +int main(void) { + test_serial_mode(); - if (result != EXIT_SUCCESS) { - return result; - } + test_sequential_mode(); - return EXIT_SUCCESS; + return 0; } diff --git a/tests/test_quantization_format_compatibility.c b/tests/test_quantization_format_compatibility.c index 
0c3c4e3..e652e6c 100644 --- a/tests/test_quantization_format_compatibility.c +++ b/tests/test_quantization_format_compatibility.c @@ -1,21 +1,19 @@ // Tests that existing Q5_0 & Q5_1 model files are still working. -#include - -#include #include -#include -#include +#include + +#include #include "logit_difference_validator.inc" int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); - float * expected_logits = malloc(sizeof(float) * N_VOCAB); + float * expected_logits = calloc(N_VOCAB, sizeof(float)); load_expected_logits(expected_logits); test_model("tiny-rwkv-660K-Q5_0.bin", expected_logits, -0.170404F); - test_model("tiny-rwkv-660K-Q5_1.bin", expected_logits, 0.278034F); + test_model("tiny-rwkv-660K-Q5_1.bin", expected_logits, +0.278034F); free(expected_logits); diff --git a/tests/test_quantized_matmul_on_gpu.c b/tests/test_quantized_matmul_on_gpu.c index cae8748..854980c 100644 --- a/tests/test_quantized_matmul_on_gpu.c +++ b/tests/test_quantized_matmul_on_gpu.c @@ -1,27 +1,21 @@ // Tests that quantized matmul on GPU works. -#include - -#include #include +#include + +#if defined(GGML_USE_CUBLAS) + #include -// TODO Move to inc -#define ASSERT(x, ...) {\ - if (!(x)) {\ - fprintf(stderr, "*** Assertion failed ***\n");\ - fprintf(stderr, __VA_ARGS__);\ - fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ - }\ - } +#include +#include "ggml/src/ggml-cuda.h" + +#include "assertions.inc" #define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value #define ELEMENT_COUNT 32 int main(void) { - #ifdef GGML_USE_CUBLAS - struct ggml_init_params params = { .mem_size = 16 * 1024, .mem_buffer = NULL, @@ -87,7 +81,15 @@ int main(void) { ggml_free(ctx); - #endif + return 0; +} + +#else + +int main(void) { + fprintf(stderr, "Skipping test_quantized_matmul_on_gpu.c: GGML_USE_CUBLAS is not defined\n"); return 0; } + +#endif diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index c8cbbd0..b3a45cc 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -1,43 +1,44 @@ // Tests that tiny RWKV outputs expected results in all data types. -#include - -#include #include -#include -#include +#include + +#include #include "logit_difference_validator.inc" int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); - float * expected_logits = malloc(sizeof(float) * N_VOCAB); + // Silences the overly verbose output during quantization. + rwkv_set_print_errors(NULL, false); + + float * expected_logits = calloc(N_VOCAB, sizeof(float)); load_expected_logits(expected_logits); - // Somehow when using cuBLAS the calculation of Q4_1 may different from cpu only + // Somehow when using cuBLAS the result of Q4_1 is different from CPU only. float expected_difference_sum[14] = { - 0.000000F, - -0.005320F, + +0.000000F, // FP32 + -0.005320F, // FP16 - -0.160030F, -#ifdef GGML_USE_CUBLAS - -0.547409F, + -0.160030F, // Q4_0 +#if defined(GGML_USE_CUBLAS) + -0.547409F, // Q4_1 #else - -0.370606F, + -0.370606F, // Q4_1 #endif - -0.170404F, - 0.278034F, - 0.071216F, + -0.170404F, // Q5_0 + +0.278034F, // Q5_1 + +0.071216F, // Q8_0 - 0.154614F, -#ifdef GGML_USE_CUBLAS - -0.539827F, + +0.154614F, // Q4_0 +#if defined(GGML_USE_CUBLAS) + -0.539827F, // Q4_1 #else - -0.372169F, + -0.372169F, // Q4_1 #endif - -0.170043F, - 0.294953F, - 0.065571F, + -0.170043F, // Q5_0 + +0.294953F, // Q5_1 + +0.065571F, // Q8_0 }; test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);