From a3178b20ea0a1600f4d5d9dc06f67800d8bbb62a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 8 May 2023 14:28:54 +0500 Subject: [PATCH] Various improvements (#52) * Update ggml * Add link to pre-quantized models in README * Enable W4 for MSVC * Fix warnings, clean up code * Fix LoRA merge script --- CMakeLists.txt | 7 ++- README.md | 20 +++++++-- ggml | 2 +- rwkv.cpp | 85 +++++++++++++++++------------------- rwkv.h | 8 ++-- rwkv/merge_lora_into_ggml.py | 33 +++++++------- tests/test_ggml_basics.c | 4 +- tests/test_tiny_rwkv.c | 4 +- 8 files changed, 90 insertions(+), 73 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ba5818..7ff772f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,7 +119,12 @@ if (RWKV_ALL_WARNINGS) -Wno-multichar ) else() - # TODO [llama.cpp]: msvc + set(c_flags + -W4 + ) + set(cxx_flags + -W4 + ) endif() add_compile_options( diff --git a/README.md b/README.md index feee4f7..b5c09f4 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,8 @@ On Windows: to check whether your CPU supports AVX2 or AVX-512, [use CPU-Z](http #### Option 2.2. Build the library yourself +This option is recommended for maximum performance, because the library would be built specifically for your CPU and OS. + ##### Windows **Requirements**: [CMake](https://cmake.org/download/) or [CMake from anaconda](https://anaconda.org/conda-forge/cmake), MSVC compiler. @@ -75,10 +77,22 @@ cmake --build . --config Release If everything went OK, `librwkv.so` (Linux) or `librwkv.dylib` (MacOS) file should appear in the base repo folder. -### 3. Download an RWKV model from [Hugging Face](https://huggingface.co/BlinkDL) like [this one](https://huggingface.co/BlinkDL/rwkv-4-pile-169m/blob/main/RWKV-4-Pile-169M-20220807-8023.pth) and convert it into `ggml` format +### 3. Get an RWKV model + +#### Option 3.1. Download pre-quantized Raven model + +There are pre-quantized Raven models available on [Hugging Face](https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main). Check that you are downloading `.bin` file, NOT `.pth`. + +#### Option 3.2. Convert and quantize PyTorch model **Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/). +This option would require a little more manual work, but you can use it with any RWKV model and any target format. + +**First**, download a model from [Hugging Face](https://huggingface.co/BlinkDL) like [this one](https://huggingface.co/BlinkDL/rwkv-4-pile-169m/blob/main/RWKV-4-Pile-169M-20220807-8023.pth). + +**Second**, convert it into `rwkv.cpp` format using following commands: + ```commandline # Windows python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16 @@ -87,9 +101,7 @@ python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\ python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16 ``` -#### 3.1. Optionally, quantize the model - -To convert the model into one of quantized formats from the table above, run: +**Optionally**, quantize the model into one of quantized formats from the table above: ```commandline # Windows diff --git a/ggml b/ggml index 9d7974c..ff6e03c 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 9d7974c3cf1284b4ddb926d94552e9fe4c4ad483 +Subproject commit ff6e03cbcd9bf6e9fa41d49f2495c042efae4dc6 diff --git a/rwkv.cpp b/rwkv.cpp index 9ba7786..a7c2ee4 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -15,29 +15,31 @@ // --- Utilities --- -// Checks that x is not false. If x is false, prints fancy message to stderr and returns 0. -#define RWKV_ASSERT_FALSE(x, ...) \ - do { \ +// Checks that x is not false. If x is false, prints fancy message to stderr and returns RET_VAL. +#define RWKV_ASSERT(RET_VAL, x, ...) \ + { \ if (!(x)) { \ fprintf(stderr, __VA_ARGS__); \ fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - return false; \ + return RET_VAL; \ } \ - } while (0) + } + +// Checks that x is not false. If x is false, prints fancy message to stderr and returns false. +#define RWKV_ASSERT_FALSE(x, ...) RWKV_ASSERT(false, x, __VA_ARGS__) // Checks that x is not false. If x is false, prints fancy message to stderr and returns NULL. -#define RWKV_ASSERT_NULL(x, ...) \ - do { \ - if (!(x)) { \ - fprintf(stderr, __VA_ARGS__); \ - fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \ - return NULL; \ - } \ - } while (0) +#define RWKV_ASSERT_NULL(x, ...) RWKV_ASSERT(NULL, x, __VA_ARGS__) // Reads single int32 value from a file. bool read_int32(FILE * file, int32_t * dest) { - RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file"); + RWKV_ASSERT_FALSE(fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file"); + return true; +} + +// Reads single uint32 value from a file. +bool read_uint32(FILE * file, uint32_t * dest) { + RWKV_ASSERT_FALSE(fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file"); return true; } @@ -98,9 +100,9 @@ struct rwkv_layer { }; struct rwkv_model { - int32_t n_vocab; - int32_t n_layer; - int32_t n_embed; + uint32_t n_vocab; + uint32_t n_layer; + uint32_t n_embed; // 0 for float32, 1 for float16. int32_t data_type; @@ -119,18 +121,18 @@ struct rwkv_model { // Finds model parameter by key and sets it into dest. // If the parameter was not found, returns false. -bool set_parameter(std::unordered_map * parameters, char * key, struct ggml_tensor ** dest) { +bool set_parameter(std::unordered_map * parameters, std::string key, struct ggml_tensor ** dest) { struct ggml_tensor * parameter = (*parameters)[key]; - RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key); + RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key.c_str()); *dest = parameter; return true; } // Finds block parameter by block index and key and sets it into dest. // If the parameter was not found, returns false. -bool set_block_parameter(std::unordered_map * parameters, int32_t block_index, char * key, struct ggml_tensor ** dest) { +bool set_block_parameter(std::unordered_map * parameters, int32_t block_index, std::string key, struct ggml_tensor ** dest) { char full_key[128]; - sprintf(full_key, "blocks.%d.%s", block_index, key); + sprintf(full_key, "blocks.%d.%s", block_index, key.c_str()); return set_parameter(parameters, full_key, dest); } @@ -198,7 +200,7 @@ struct rwkv_context { bool freed; }; -struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) { +struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { FILE * file = fopen(file_path, "rb"); RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path); @@ -212,14 +214,9 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr struct rwkv_model * model = (struct rwkv_model *) calloc(1, sizeof(struct rwkv_model)); - read_int32(file, &(model->n_vocab)); - RWKV_ASSERT_NULL(model->n_vocab > 0, "Non-positive n_vocab %d", model->n_vocab); - - read_int32(file, &(model->n_embed)); - RWKV_ASSERT_NULL(model->n_embed > 0, "Non-positive n_embed %d", model->n_embed); - - read_int32(file, &(model->n_layer)); - RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer); + read_uint32(file, &(model->n_vocab)); + read_uint32(file, &(model->n_embed)); + read_uint32(file, &(model->n_layer)); read_int32(file, &(model->data_type)); RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); @@ -321,7 +318,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr set_parameter(¶meters, "blocks.0.ln0.weight", &(model->ln0_weight)); set_parameter(¶meters, "blocks.0.ln0.bias", &(model->ln0_bias)); - for (int i = 0; i < model->n_layer; i++) { + for (uint32_t i = 0; i < model->n_layer; i++) { rwkv_layer layer = model->layers[i]; set_block_parameter(¶meters, i, "ln1.weight", &(layer.ln1_weight)); @@ -360,8 +357,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]); RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]); - int32_t n_embed = model->n_embed; - int32_t n_layer = model->n_layer; + uint32_t n_embed = model->n_embed; + uint32_t n_layer = model->n_layer; // Build graph struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); @@ -376,7 +373,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr // We collect parts of new state here. Each part is (n_embed) vector. struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5]; - for (int i = 0; i < n_layer; i++) { + for (uint32_t i = 0; i < n_layer; i++) { auto layer = model->layers[i]; // RWKV/time mixing @@ -533,7 +530,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr *graph = ggml_build_forward(logits); - for (int i = 0; i < n_layer * 5; i++) { + for (uint32_t i = 0; i < n_layer * 5; i++) { ggml_build_forward_expand(graph, state_parts[i]); } @@ -550,30 +547,30 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr return rwkv_ctx; } -uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) { return ctx->model->n_layer * 5 * ctx->model->n_embed; } -uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) { +uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) { return ctx->model->n_vocab; } -bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) { +bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL"); RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL"); - int32_t n_layer = ctx->model->n_layer; - int32_t n_embed = ctx->model->n_embed; - int32_t n_vocab = ctx->model->n_vocab; + uint32_t n_layer = ctx->model->n_layer; + uint32_t n_embed = ctx->model->n_embed; + uint32_t n_vocab = ctx->model->n_vocab; - RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1); + RWKV_ASSERT_FALSE(token < (uint32_t) n_vocab, "Token is out of range 0..%d", n_vocab - 1); ggml_set_i32_1d(ctx->token_index, 0, token); if (state_in == NULL) { ggml_set_f32(ctx->state, 0.0F); - for (int i = 0; i < n_layer; i++) { + for (uint32_t i = 0; i < n_layer; i++) { // state[5 * i + 4] = -1e30 ggml_set_f32( ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), @@ -586,7 +583,7 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float ggml_graph_compute(ctx->ctx, ctx->graph); - for (size_t i = 0; i < size_t(n_layer * 5); i++) { + for (uint32_t i = 0; i < n_layer * 5; i++) { struct ggml_tensor * part = ctx->state_parts[i]; memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); diff --git a/rwkv.h b/rwkv.h index 3a90c73..46abb61 100644 --- a/rwkv.h +++ b/rwkv.h @@ -33,7 +33,7 @@ extern "C" { // Returns NULL on any error. Error messages would be printed to stderr. // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads); + RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); // Evaluates the model for a single token. // Returns false on any error. Error messages would be printed to stderr. @@ -41,13 +41,13 @@ extern "C" { // - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass. // - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. // - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. - RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out); + RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out); // Returns count of FP32 elements in state buffer. - RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx); + RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx); // Returns count of FP32 elements in logits buffer. - RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx); + RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx); // Frees all allocated memory and the context. RWKV_API void rwkv_free(struct rwkv_context * ctx); diff --git a/rwkv/merge_lora_into_ggml.py b/rwkv/merge_lora_into_ggml.py index e5c7d3a..e7d9523 100644 --- a/rwkv/merge_lora_into_ggml.py +++ b/rwkv/merge_lora_into_ggml.py @@ -113,29 +113,32 @@ def main() -> None: del lora_state_dict[key] - lora_A_key: str = key.replace('.weight', '') + '.lora_A.weight' - lora_B_key: str = key.replace('.weight', '') + '.lora_B.weight' + for suffix in ['.weight', '']: + lora_A_key: str = key.replace('.weight', '') + '.lora_A' + suffix + lora_B_key: str = key.replace('.weight', '') + '.lora_B' + suffix - if lora_A_key in lora_state_dict: - lora_A: torch.Tensor = lora_state_dict[lora_A_key] - lora_B: torch.Tensor = lora_state_dict[lora_B_key] + if lora_A_key in lora_state_dict: + lora_A: torch.Tensor = lora_state_dict[lora_A_key] + lora_B: torch.Tensor = lora_state_dict[lora_B_key] - assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \ - f'{lora_A.shape}, {lora_B.shape}' + assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \ + f'{lora_A.shape}, {lora_B.shape}' - lora_R: int = lora_B.shape[1] + lora_R: int = lora_B.shape[1] - replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R) + replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R) - if parameter.dtype == torch.float16: - replacement = replacement.half() + if parameter.dtype == torch.float16: + replacement = replacement.half() - parameter = replacement + parameter = replacement + + print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}') - print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}') + del lora_state_dict[lora_A_key] + del lora_state_dict[lora_B_key] - del lora_state_dict[lora_A_key] - del lora_state_dict[lora_B_key] + break write_parameter(out_file, key, parameter) diff --git a/tests/test_ggml_basics.c b/tests/test_ggml_basics.c index d14f85a..a31687a 100644 --- a/tests/test_ggml_basics.c +++ b/tests/test_ggml_basics.c @@ -19,10 +19,10 @@ #define ASSERT_ELEMENT_F32(tensor, i, expected_value) {\ float actual = ((float *) tensor->data)[i];\ - ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, expected_value, actual);\ + ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, (double) expected_value, (double) actual);\ } -int main(int argc, const char ** argv) { +int main(void) { struct ggml_init_params params = { .mem_size = 16 * 1024, .mem_buffer = NULL, diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index bfdc356..bfb726e 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -51,7 +51,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl fprintf(stderr, "Difference sum: %f\n", diff_sum); // When something breaks, difference would be way more than 10 - ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", diff_sum, max_diff); + ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); rwkv_free(model); @@ -59,7 +59,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl free(logits); } -int main(int argc, const char ** argv) { +int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); float * expected_logits = malloc(sizeof(float) * N_VOCAB);