diff --git a/.gitmodules b/.gitmodules index 8527898..67eaf4a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "ggml"] path = ggml - url = https://github.com/ggerganov/ggml - branch = master + url = https://github.com/saharNooby/ggml + branch = increased-node-limit-2023-09-19 diff --git a/CMakeLists.txt b/CMakeLists.txt index 24cd1e1..2b8beaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ if (RWKV_CUBLAS) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") @@ -121,6 +122,17 @@ if (RWKV_CUBLAS) set(RWKV_EXTRA_LIBS ${RWKV_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) endif() + # Architecture set-up copy-pasted from https://github.com/ggerganov/llama.cpp/blob/111163e2463171891680feed94371eb9becd9817/CMakeLists.txt#L317 + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52: lowest CUDA 12 standard + # 60: f16 CUDA intrinsics + # 61: integer CUDA intrinsics + # 70: compute capability at which unrolling a loop in mul_mat_q kernels is faster + + # Lowest CUDA 12 standard + lowest for integer intrinsics. + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") else() message(WARNING "cuBLAS not found") endif() @@ -138,7 +150,7 @@ if (RWKV_CLBLAST) $ENV{OPENCL_HOME} $ENV{OPENCL_HOME}/include ${OPENCL_INCLUDE_SEARCH_PATHS} - ) + ) set(CLBLAST_INCLUDE_SEARCH_PATHS /usr/include @@ -146,7 +158,7 @@ if (RWKV_CLBLAST) $ENV{CLBLAST_HOME} $ENV{CLBLAST_HOME}/include ${CLBLAST_INCLUDE_SEARCH_PATHS} - ) + ) find_path(OPENCL_INC NAMES opencl.h PATHS ${OPENCL_INCLUDE_SEARCH_PATHS} PATH_SUFFIXES include/CL) find_library(OPENCL_LIB NAMES OpenCL PATHS ${OPENCL_INCLUDE_SEARCH_PATHS} PATH_SUFFIXES lib) @@ -286,6 +298,56 @@ else() message(STATUS "Unknown architecture") endif() +# +# POSIX conformance +# Section copy-pasted from https://github.com/ggerganov/llama.cpp/blob/8781013ef654270cbead3e0011e33a6d690fb168/CMakeLists.txt#L580C20-L580C20 +# + +# clock_gettime came in POSIX.1b (1993) +# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional +# posix_memalign came in POSIX.1-2001 / SUSv3 +# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) +add_compile_definitions(_XOPEN_SOURCE=600) + +# Somehow in OpenBSD whenever POSIX conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher. +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + remove_definitions(-D_XOPEN_SOURCE=600) + add_compile_definitions(_XOPEN_SOURCE=700) +endif() + +# Data types, macros and functions related to controlling CPU affinity and +# some memory allocation are available on Linux through GNU extensions in libc. +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions(_GNU_SOURCE) +endif() + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions. +# Similarly on DragonFly, enabling BSD extensions is necessary. +if ( + CMAKE_SYSTEM_NAME MATCHES "Darwin" OR + CMAKE_SYSTEM_NAME MATCHES "iOS" OR + CMAKE_SYSTEM_NAME MATCHES "tvOS" OR + CMAKE_SYSTEM_NAME MATCHES "DragonFly" +) + add_compile_definitions(_DARWIN_C_SOURCE) +endif() + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases. 
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") + add_compile_definitions(__BSD_VISIBLE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") + add_compile_definitions(_NETBSD_SOURCE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_BSD_SOURCE) +endif() + # # Build libraries # @@ -296,7 +358,9 @@ endif() add_library(ggml OBJECT ${CMAKE_SOURCE_DIR}/ggml/src/ggml.c + ${CMAKE_SOURCE_DIR}/ggml/src/ggml-alloc.c ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml.h + ${CMAKE_SOURCE_DIR}/ggml/include/ggml/ggml-alloc.h ${GGML_CUDA_SOURCES} ${GGML_OPENCL_SOURCES}) @@ -328,13 +392,6 @@ if (RWKV_BUILD_SHARED_LIBRARY) target_compile_definitions(rwkv PRIVATE RWKV_SHARED RWKV_BUILD) endif() -if (GGML_CUDA_SOURCES) - message(STATUS "GGML CUDA sources found, configuring CUDA architecture") - set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) - set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") - set_property(TARGET rwkv PROPERTY CUDA_ARCHITECTURES OFF) -endif() - if (NOT RWKV_STANDALONE) set_property(TARGET ggml PROPERTY GGML_STANDALONE OFF) enable_testing() diff --git a/README.md b/README.md index a9a8f18..281b7f0 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it. -RWKV is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformer with `O(n^2)` attention, RWKV requires only state from previous step to calculate logits. This makes RWKV very CPU-friendly on large context lenghts. +[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly at large context lengths. Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).
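For reference, the "only the state from the previous step" property described above is exactly what the serial graph rebuilt later in this patch (on top of `ggml-alloc` measurement) computes, and what the C API in [rwkv.h](rwkv.h) exposes: each evaluation consumes one token plus the previous state buffer and produces logits plus the next state buffer. A minimal usage sketch follows; it assumes the entry points `rwkv_init_from_file`, `rwkv_eval`, `rwkv_get_state_buffer_element_count`, `rwkv_get_logits_buffer_element_count` and `rwkv_free` with roughly these signatures (check rwkv.h for the exact ones), and the model path and token ids are placeholders.

```c
// Minimal sketch (not part of the patch): serial, token-by-token evaluation.
#include "rwkv.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    // Load a converted model; the second argument is the number of threads.
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4);
    if (ctx == NULL) {
        return 1;
    }

    // State and logits buffers are sized by the model, not by the context length.
    float * state  = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
    float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));

    const uint32_t tokens[] = { 123, 456, 789 }; // Placeholder token ids.
    const size_t token_count = sizeof(tokens) / sizeof(tokens[0]);

    // Each step needs only the previous state and one token, so the per-token
    // cost stays constant no matter how long the processed context already is.
    for (size_t i = 0; i < token_count; i++) {
        // Passing NULL as state_in on the first call starts from a fresh state.
        rwkv_eval(ctx, tokens[i], i == 0 ? NULL : state, state, logits);
    }

    printf("First logit: %f\n", logits[0]);

    free(state);
    free(logits);
    rwkv_free(ctx);

    return 0;
}
```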
diff --git a/ggml b/ggml index a1d0ea7..d925ed7 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit a1d0ea7c2abd44f56822ffdfcfe0a0fcf7170885 +Subproject commit d925ed7a96767192d422a97645f08ad86d5cc6f0 diff --git a/rwkv.cpp b/rwkv.cpp index 8ace275..19c394d 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -1,5 +1,6 @@ #include "rwkv.h" #include "ggml.h" +#include "ggml-alloc.h" #ifdef GGML_USE_CUBLAS #include "ggml/src/ggml-cuda.h" @@ -40,8 +41,8 @@ #endif #endif -static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); -static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2GB"); +static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB"); +static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB"); // --- Error handling --- @@ -140,6 +141,31 @@ inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_err // --- Utilities --- +size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t width, const int64_t height) { + return (ggml_type_size(type) * width * height) / ggml_blck_size(type); +} + +// For some reason, ggml_nbytes calculates the size in a way incompatible with rwkv.cpp +size_t rwkv_tensor_nbytes(const struct ggml_tensor * tensor) { + return rwkv_tensor_nbytes(tensor->type, tensor->ne[0], tensor->ne[1]); +} + +size_t rwkv_ggml_overhead() { + return ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead(); +} + +struct ggml_context * rwkv_init_ggml_context(const size_t memory_size, const bool no_alloc) { + struct ggml_init_params init_params = { + memory_size, + NULL, + no_alloc + }; + + return ggml_init(init_params); +} + +// --- IO utilities --- + // Reads a single uint32 value from a file. 
bool rwkv_fread_uint32(FILE * file, uint32_t & dest) { return fread((void *) &dest, sizeof(uint32_t), 1, file) == 1; @@ -290,9 +316,13 @@ struct rwkv_tensor_header { uint32_t width; uint32_t height; - const size_t size() const; + size_t size() const; }; +size_t rwkv_tensor_header::size() const { + return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height); +} + struct rwkv_tensor { struct rwkv_tensor_header header; std::string name; @@ -323,13 +353,17 @@ bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & he return true; } -bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) { +bool rwkv_fskip_tensor_name_and_data(FILE * file, const struct rwkv_tensor_header & header) { return fseek(file, header.key_length + header.size(), SEEK_CUR) == 0; } +bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) { + return fseek(file, header.size(), SEEK_CUR) == 0; +} + bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) { RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, header)); - RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, rwkv_fskip_tensor_data(file, header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, rwkv_fskip_tensor_name_and_data(file, header)); return true; } @@ -341,7 +375,7 @@ bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buf RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, data_size, buffer)); } else { output.data = NULL; - RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_data(file, output.header)); + RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_name_and_data(file, output.header)); } return true; @@ -366,7 +400,13 @@ bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); ggml_set_name(tensor, name.c_str()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + // Tensor data may be NULL if no_alloc is true + if (tensor->data != NULL) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_tensor_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + } else { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_data(file, header), "Failed to skip tensor data from %s", name.c_str()); + } + return true; } @@ -383,7 +423,7 @@ bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { return true; } -// --- Model definition --- +// --- Model loading --- struct rwkv_layer { struct ggml_tensor * ln1_weight; @@ -411,7 +451,13 @@ struct rwkv_layer { struct ggml_tensor * ffn_receptance; }; +// The model holds all parameter tensors and the ggml context containing them. +// Each tensor has data and can be used in computations happening in other contexts. struct rwkv_model { + // This context holds all parameter tensors. + // It must not be used for computations. + struct ggml_context * ggml_ctx; + struct rwkv_file_header header; struct ggml_tensor * emb; @@ -425,248 +471,245 @@ struct rwkv_model { struct ggml_tensor * ln_out_bias; struct ggml_tensor * head; + + // How many layers were offloaded to the GPU. + size_t offloaded_layer_count; + + // How many RWKV contexts reference this model. 
+ int reference_count; }; -// --- Operators --- +struct rwkv_file { + FILE * file; -void rwkv_exp_impl(const int n_cols, float * dest, const float * src) { - for (int i = 0; i < n_cols; i++) { - dest[i] = expf(src[i]); - } -} + rwkv_file(FILE * file): file(file) {} -void rwkv_1_minus_x_impl(const int n_cols, float * dest, const float * src) { - for (int i = 0; i < n_cols; i++) { - dest[i] = 1.0F - src[i]; + ~rwkv_file() { + if (file) { + fclose(file); + } } -} +}; -void rwkv_sigmoid_impl(const int n_cols, float * dest, const float * src) { - for (int i = 0; i < n_cols; i++) { - dest[i] = 1.0F / (1.0F + expf(-src[i])); - } -} +// https://stackoverflow.com/a/6458689 +template +bool rwkv_set_params(struct rwkv_model & model, F callback) { + RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); + RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); -void rwkv_max_impl(const int n_cols, float * dest, const float * src0, const float * src1) { - for (int i = 0; i < n_cols; i++) { - dest[i] = fmaxf(src0[i], src1[i]); - } -} + uint32_t n_layer = model.header.n_layer; + std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); + model.layers = std::move(layers); -struct ggml_tensor * rwkv_exp(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_unary_f32(ctx, x, rwkv_exp_impl); -} + for (uint32_t i = 0; i < n_layer; i++) { + char buffer[128]; + size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); -struct ggml_tensor * rwkv_1_minus_x(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_unary_f32(ctx, x, rwkv_1_minus_x_impl); -} + rwkv_layer & layer = model.layers[i]; + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); -struct ggml_tensor * rwkv_sigmoid(ggml_context * ctx, struct ggml_tensor * x) { - return ggml_map_unary_f32(ctx, x, rwkv_sigmoid_impl); -} + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); -struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) { - return ggml_map_binary_f32(ctx, x, y, rwkv_max_impl); -} + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); -struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct 
ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { - // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` - // Looks like ggml_norm does the first part, we only need to apply weight & bias. - return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x), weight), bias); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + } + + RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); + RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); + RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); + + return true; } -// --- Implementation --- +// Creates a ggml context and loads all parameter tensors from a model file. +bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model & model) { + struct stat file_stat; + + std::unordered_map parameters; -// Used as a helper during rwkv_ctx_size calculation. -struct rwkv_future_tensor; + rwkv_file file(fopen(file_path, "rb")); -// Used to calculate the memory usage of ggml contexts before allocating them. -// Since ggml uses an internal bump allocator that can't be grown at runtime, we need to ensure we have enough space, -// while at the same time not using more memory than necessary. -struct rwkv_future_ctx { - size_t objects_count = 0; - size_t memory_size = 0; - size_t scratch_size = 0; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path); + // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header"); - // Align to GGML_MEM_ALIGN, which can currently be up to 16 - static const size_t align(const size_t size) { - return ((size + 15) & ~15); - } + model.ggml_ctx = rwkv_init_ggml_context( + // ggml tensors must be aligned; assuming here that overhead of parameter headers, included in the file size, will account for that. 
+ file_stat.st_size + rwkv_ggml_overhead(), + false + ); - void add_objects(const size_t size, const size_t count = 1) { - this->objects_count += count; + std::string name; - if (size && count) { - this->add_memory(size, count); - } - } + struct ggml_tensor * tensor; - void add_memory(const size_t size, const size_t count = 1) { - this->memory_size += this->align(size) * count; - } + while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, model.ggml_ctx, name, tensor), "Failed to read a model parameter"); - void add_scratch(const size_t size, const size_t count = 1) { - this->scratch_size += this->align(size) * count; + parameters[std::move(name)] = tensor; } - void add_data(const bool use_scratch, const size_t size, const size_t count = 1) { - if (use_scratch) { - this->add_scratch(size, count); - } else { - this->add_memory(size, count); - } - } + std::unordered_map & parameters_ref = parameters; + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { + struct ggml_tensor * tensor = parameters_ref[key]; + RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); + dest = tensor; + return true; + })); - struct rwkv_future_tensor declare(const enum ggml_type type, const uint64_t width, const uint64_t height = 1); + // Verify order of dimensions + struct ggml_tensor * emb = model.emb; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - struct rwkv_future_tensor alloc(const enum ggml_type type, const uint64_t width, const uint64_t height = 1, const bool use_scratch = true); -}; + return true; +} -struct rwkv_future_tensor { - enum ggml_type type = GGML_TYPE_COUNT; - uint64_t width = 0; - uint64_t height = 0; - - static const size_t size(const enum ggml_type type, const uint64_t width, const uint64_t height) { - struct ggml_tensor decoy {}; - decoy.type = type; - decoy.ne[0] = width; - decoy.ne[1] = height; - decoy.ne[2] = 1; - decoy.ne[3] = 1; - return ggml_nbytes(&decoy); - } +// --- Operators --- - rwkv_future_tensor() {} - rwkv_future_tensor(const enum ggml_type type, const uint64_t width, const uint64_t height = 1): type(type), width(width), height(height) {} - rwkv_future_tensor(const struct ggml_tensor * ref): type(ref->type), width(ref->ne[0]), height(ref->ne[1]) {} +void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_are_same_shape(src, dest)); - struct rwkv_future_tensor alloc(struct rwkv_future_ctx & ctx, const bool use_scratch = true) const { - ctx.add_objects(sizeof(struct ggml_tensor)); - ctx.add_data(use_scratch, rwkv_future_tensor::size(type, width, height)); - return *this; - } + // Assuming 2D tensors. 
+ int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; - struct rwkv_future_tensor view(struct rwkv_future_ctx & ctx) const { - ctx.add_objects(sizeof(struct ggml_tensor)); - return *this; + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = expf(src_data[i]); } - struct rwkv_future_tensor subview(struct rwkv_future_ctx & ctx, const uint32_t width, const uint32_t height = 1) const { - ctx.add_objects(sizeof(struct ggml_tensor), 2); - ctx.add_memory(sizeof(uint32_t) * 2); - return rwkv_future_tensor(type, width, height); - } + // Suppress warnings for unused parameters. + (void) ith; + (void) nth; + (void) userdata; +} - struct rwkv_future_tensor dup(struct rwkv_future_ctx & ctx) const { - return this->alloc(ctx); - } +void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_are_same_shape(src, dest)); - struct rwkv_future_tensor layer_norm(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & weight, const struct rwkv_future_tensor & bias) const { - return this->dup(ctx).view(ctx).view(ctx); - } + // Assuming 2D tensors. + int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; - struct rwkv_future_tensor repeat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor reference) const { - return reference.dup(ctx); + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = 1.0F - src_data[i]; } - struct rwkv_future_tensor set_inplace(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor src) { - ctx.add_objects(sizeof(struct ggml_tensor)); - ctx.add_memory(sizeof(uint32_t) * 5); - return this->view(ctx); - } + // Suppress warnings for unused parameters. + (void) ith; + (void) nth; + (void) userdata; +} - struct rwkv_future_tensor consume(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) { - return this->view(ctx); - } +void rwkv_sigmoid_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_are_same_shape(src, dest)); - struct rwkv_future_tensor combine(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { - return this->dup(ctx); - } + // Assuming 2D tensors. + int64_t element_count = src->ne[0] * src->ne[1]; + float * src_data = (float *) src->data; + float * dest_data = (float *) dest->data; - struct rwkv_future_tensor fn(struct rwkv_future_ctx & ctx) const { - ctx.add_objects(sizeof(struct ggml_tensor)); - ctx.add_memory(sizeof(void *) / sizeof(uint32_t)); - return this->dup(ctx); + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = 1.0F / (1.0F + expf(-src_data[i])); } - struct rwkv_future_tensor mul_mat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { - return ctx.alloc(GGML_TYPE_F32, this->height, other.height); - } + // Suppress warnings for unused parameters. 
+ (void) ith; + (void) nth; + (void) userdata; +} - struct rwkv_future_tensor get_rows(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { - return ctx.alloc(GGML_TYPE_F32, this->width, other.width); +void rwkv_max_impl( + struct ggml_tensor * dest, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + int ith, + int nth, + void * userdata +) { + GGML_ASSERT(dest->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dest)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_are_same_shape(src0, dest)); + GGML_ASSERT(ggml_are_same_shape(src1, dest)); + + // Assuming 2D tensors. + int64_t element_count = src0->ne[0] * src0->ne[1]; + float * src0_data = (float *) src0->data; + float * src1_data = (float *) src1->data; + float * dest_data = (float *) dest->data; + + for (int64_t i = 0; i < element_count; i++) { + dest_data[i] = fmaxf(src0_data[i], src1_data[i]); } -}; -const size_t rwkv_tensor_header::size() const { - return rwkv_future_tensor::size(rwkv_type_to_ggml[this->data_type], this->width, this->height); + // Suppress warnings for unused parameters. + (void) ith; + (void) nth; + (void) userdata; } -struct rwkv_future_tensor rwkv_future_ctx::declare(const enum ggml_type type, const uint64_t width, const uint64_t height) { - return rwkv_future_tensor(type, width, height); +struct ggml_tensor * rwkv_exp(ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_exp_impl, 1, NULL); } -struct rwkv_future_tensor rwkv_future_ctx::alloc(const enum ggml_type type, const uint64_t width, const uint64_t height, const bool use_scratch) { - return this->declare(type, width, height).alloc(*this, use_scratch); +struct ggml_tensor * rwkv_1_minus_x(ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_1_minus_x_impl, 1, NULL); } -struct rwkv_ggml_context { - std::unique_ptr scratch; - struct ggml_context * ctx; - - rwkv_ggml_context(): ctx(NULL) {} - - rwkv_ggml_context(const struct rwkv_future_ctx future_ctx): ctx(NULL) { - scratch.reset(new(std::nothrow) uint8_t[future_ctx.scratch_size]); - - if (!scratch) { - return; - } - - ctx = ggml_init({ future_ctx.objects_count * GGML_OBJECT_SIZE + future_ctx.memory_size, NULL, false}); - - if (!ctx) { - return; - } - - ggml_set_scratch(ctx, { 0, future_ctx.scratch_size, scratch.get() }); - } +struct ggml_tensor * rwkv_sigmoid(ggml_context * ctx, struct ggml_tensor * x) { + return ggml_map_custom1(ctx, x, rwkv_sigmoid_impl, 1, NULL); +} - struct rwkv_ggml_context & operator=(struct rwkv_ggml_context && source) { - scratch.reset(source.scratch.release()); - std::swap(ctx, source.ctx); - return *this; - } +struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) { + return ggml_map_custom2(ctx, x, y, rwkv_max_impl, 1, NULL); +} - ~rwkv_ggml_context() { - if (ctx) { - ggml_free(ctx); - } - } -}; +struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { + // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` + // Looks like ggml_norm does the first part, we only need to apply weight & bias. + return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias); +} -// An instance of an RWKV model loaded into memory. 
-// Contains all the model weights. -// Shared by one or more contexts. -struct rwkv_instance { - struct rwkv_ggml_context ctx; - struct rwkv_model model; - - // TODO Come up with a better solution to estimate "work tensor" size - // The ggml_cgraph allocates a "work tensor" the first time it is used. - // Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV. - // Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor". - // However, if ggml changes its implementation, or rwkv.cpp changes its own implementation, at any point, - // this may become outdated. We need to find a way not to hardcode a specific tensor, but to calculate accurately. - // This may come out of a ggml issue: https://github.com/ggerganov/ggml/issues/214 - size_t ffn_key_size; -}; -// The hidden state of a single RWKV layer. -// These are mostly used for dividing up the input state, and writing portions of the output state. -// But they're also used in building the computation graphs to represent the operations -// used from input->output (operating "in place" on a rwkv_layer_state). +// View tensors of a state of a single layer. struct rwkv_layer_state { struct ggml_tensor * ffn_xx; struct ggml_tensor * att_xx; @@ -675,118 +718,62 @@ struct rwkv_layer_state { struct ggml_tensor * att_pp; }; -// Holds a single computation graph and its ggml context. -// Graphs each have their own context so that they can be individually freed and rebuilt. -// Graphs read hidden state from the rwkv_context and then write it back to the rwkv_context. -// (see rwkv_context.input_layers and rwkv_context.output_layers) -struct rwkv_graph { - struct rwkv_ggml_context ctx; - struct ggml_tensor * tokens; - - // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap +// The computation graph holds the ggml context and the ggml cgraph. +// It can be either a serial or a sequential graph. +struct rwkv_computation_graph { + struct ggml_context * ggml_ctx; + // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap. std::unique_ptr<struct ggml_cgraph> cgraph; - size_t pre_logits_nodes; - size_t pre_logits_leafs; - size_t post_logits_nodes; - size_t post_logits_leafs; -}; - -// RWKV context for a specific instance. -// Contains computation graphs and is used for inference. -struct rwkv_context { - std::shared_ptr<struct rwkv_instance> instance; - - // Reused by all graphs. - struct rwkv_ggml_context ctx; + // Input tensors. struct ggml_tensor * tokens; struct ggml_tensor * input_state; std::unique_ptr<struct rwkv_layer_state[]> input_layers; + + // Output tensors. struct ggml_tensor * output_state; std::unique_ptr<struct rwkv_layer_state[]> output_layers; struct ggml_tensor * logits; - uint32_t n_threads; + // ggml graph counters before the graph was extended with logits tensor. + int pre_logits_nodes; + int pre_logits_leafs; + // ggml graph counters after the graph was extended with logits tensor. + int post_logits_nodes; + int post_logits_leafs; +}; - // The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode). - struct rwkv_graph serial_graph; +// The context holds the model and both serial and sequential computation graphs. +struct rwkv_context { + struct rwkv_model * model; + // The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode). 
+ struct rwkv_computation_graph serial_graph; // The sequence graph implements the "sequence mode" (or transformer/GPT mode) that processes multiple tokens at a time. // This can be an order of magnitude or so faster than serial execution if used properly. - size_t sequence_len; - struct rwkv_graph sequence_graph; + struct rwkv_computation_graph sequential_graph; + size_t last_used_sequence_length; + + uint32_t n_threads; enum rwkv_error_flags last_error; bool print_errors; - - size_t gpu_layers; }; -// https://stackoverflow.com/a/6458689 -template -bool rwkv_set_params(struct rwkv_model & model, F callback) { - RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb)); - RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight)); - RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); - - uint32_t n_layer = model.header.n_layer; - std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); - model.layers = std::move(layers); - - for (uint32_t i = 0; i < n_layer; i++) { - char buffer[128]; - size_t offset = sprintf(buffer, "blocks.%" PRId32 ".", i); - - rwkv_layer & layer = model.layers[i]; - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); - } - - RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); - RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); - RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); - return true; +void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { + bool * ptr = ctx ? 
&ctx->print_errors : &global_print_errors; + *ptr = print_errors; } -void rwkv_future_carry_x(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor weight, - const struct rwkv_future_tensor bias, - struct rwkv_future_tensor & x, - struct rwkv_future_tensor & x_prev, - struct rwkv_future_tensor & carry -) { - if (x.height == 1) { - x = x.layer_norm(ctx, weight, bias); - x_prev = carry; - carry = x; - } else { - x = x.layer_norm(ctx, weight.repeat(ctx, x), bias.repeat(ctx, x)); - - x_prev = x.dup(ctx) - .set_inplace(ctx, carry) - .set_inplace(ctx, x.subview(ctx, x.width, x.height - 1)); +bool rwkv_get_print_errors(struct rwkv_context * ctx) { + return ctx ? ctx->print_errors : global_print_errors; +} - carry = x.subview(ctx, x.width); - } +enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { + enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error; + enum rwkv_error_flags value = *ptr; + *ptr = RWKV_ERROR_NONE; + return value; } void rwkv_carry_x(struct ggml_context * ctx, @@ -822,28 +809,6 @@ void rwkv_carry_x(struct ggml_context * ctx, } } -void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor time_mix_k, - const struct rwkv_future_tensor time_mix_v, - const struct rwkv_future_tensor time_mix_r, - const struct rwkv_future_tensor x, - const struct rwkv_future_tensor x_prev, - const struct rwkv_future_tensor att_r, - const struct rwkv_future_tensor att_k, - const struct rwkv_future_tensor att_v, - struct rwkv_future_tensor & r, - struct rwkv_future_tensor & k, - struct rwkv_future_tensor & v -) { - const struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); - const struct rwkv_future_tensor xv = x.combine(ctx, time_mix_v).consume(ctx, x_prev.combine(ctx, time_mix_v.fn(ctx))); - const struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - - r = att_r.mul_mat(ctx, xr).fn(ctx); - k = att_k.mul_mat(ctx, xk); - v = att_v.mul_mat(ctx, xv); -} - void rwkv_att_rkv( struct ggml_context * ctx, struct rwkv_layer layer, @@ -879,37 +844,6 @@ void rwkv_att_rkv( v = ggml_mul_mat(ctx, layer.att_value, xv); } -struct rwkv_future_tensor rwkv_future_att_wkv(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor time_first, - const struct rwkv_future_tensor time_decay, - struct rwkv_future_tensor & aa, - struct rwkv_future_tensor & bb, - struct rwkv_future_tensor & pp, - const struct rwkv_future_tensor k, - const struct rwkv_future_tensor v -) { - struct rwkv_future_tensor ww = time_first.combine(ctx, k); - struct rwkv_future_tensor qq = pp.fn(ctx); - struct rwkv_future_tensor e1 = pp.combine(ctx, qq).fn(ctx); - struct rwkv_future_tensor e2 = ww.combine(ctx, qq).fn(ctx); - - struct rwkv_future_tensor a = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); - struct rwkv_future_tensor b = e1.combine(ctx, bb).consume(ctx, e2); - - ww = pp.combine(ctx, time_decay); - qq = ww.fn(ctx); - e1 = ww.combine(ctx, qq).fn(ctx); - e2 = k.combine(ctx, qq).fn(ctx); - - // aa, bb - aa = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); - bb = e1.combine(ctx, bb).consume(ctx, e2); - pp = qq; - - // wkv - return a.combine(ctx, b); -} - struct ggml_tensor * rwkv_att_wkv( struct ggml_context * ctx, struct ggml_tensor * att_time_first, @@ -954,36 +888,6 @@ struct ggml_tensor * rwkv_att_wkv( return ggml_div(ctx, a, b); } - -struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx, - const struct 
rwkv_future_tensor ln1_weight, - const struct rwkv_future_tensor ln1_bias, - const struct rwkv_future_tensor time_mix_k, - const struct rwkv_future_tensor time_mix_v, - const struct rwkv_future_tensor time_mix_r, - const struct rwkv_future_tensor time_first, - const struct rwkv_future_tensor time_decay, - const struct rwkv_future_tensor att_r, - const struct rwkv_future_tensor att_k, - const struct rwkv_future_tensor att_v, - const struct rwkv_future_tensor att_output, - struct rwkv_future_tensor x, - struct rwkv_future_tensor & att_xx, - struct rwkv_future_tensor & att_aa, - struct rwkv_future_tensor & att_bb, - struct rwkv_future_tensor & att_pp -) { - struct rwkv_future_tensor x_prev; - rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x, x_prev, att_xx); - - struct rwkv_future_tensor r, k, v; - rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x, x_prev, att_r, att_k, att_v, r, k, v); - - struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, time_first, time_decay, att_aa, att_bb, att_pp, k, v); - - return att_output.mul_mat(ctx, r.combine(ctx, wkv)); -} - struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); @@ -997,29 +901,6 @@ struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); } -struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor ln2_weight, - const struct rwkv_future_tensor ln2_bias, - const struct rwkv_future_tensor time_mix_k, - const struct rwkv_future_tensor time_mix_r, - const struct rwkv_future_tensor ffn_k, - const struct rwkv_future_tensor ffn_v, - const struct rwkv_future_tensor ffn_r, - struct rwkv_future_tensor x, - struct rwkv_future_tensor & ffn_xx -) { - struct rwkv_future_tensor x_prev; - rwkv_future_carry_x(ctx, ln2_weight, ln2_bias, x, x_prev, ffn_xx); - - struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); - struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - - struct rwkv_future_tensor r = ffn_r.mul_mat(ctx, xr).fn(ctx); - struct rwkv_future_tensor k = ffn_k.mul_mat(ctx, xk).view(ctx).view(ctx); - - return r.consume(ctx, ffn_v.mul_mat(ctx, k)); -} - struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); @@ -1049,99 +930,63 @@ struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } -struct rwkv_future_tensor rwkv_future_graph_work(struct rwkv_future_ctx & ctx, - const enum ggml_type type, - const size_t ffn_key_height, - const size_t n_threads, - const size_t sequence_len = 1 -) { -#ifdef GGML_USE_CUBLAS - enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; -#else - enum ggml_type mul_mat_type = ggml_is_quantized(type) ? 
GGML_TYPE_Q8_1 : type; -#endif - return ctx.alloc(GGML_TYPE_I8, rwkv_future_tensor::size(mul_mat_type, ffn_key_height, sequence_len) * n_threads + 64 * (n_threads - 1)); -} - -struct rwkv_future_tensor rwkv_future_serial_graph(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor tokens, - const size_t n_threads, - - const struct rwkv_future_tensor emb, - const struct rwkv_future_tensor ln0_weight, - const struct rwkv_future_tensor ln0_bias, - - const size_t n_layer, - - const struct rwkv_future_tensor ln1_weight, - const struct rwkv_future_tensor ln1_bias, - const struct rwkv_future_tensor att_time_mix_k, - const struct rwkv_future_tensor att_time_mix_v, - const struct rwkv_future_tensor att_time_mix_r, - const struct rwkv_future_tensor att_time_first, - const struct rwkv_future_tensor att_time_decay, - const struct rwkv_future_tensor att_r, - const struct rwkv_future_tensor att_k, - const struct rwkv_future_tensor att_v, - const struct rwkv_future_tensor att_output, - struct rwkv_future_tensor & att_xx, - struct rwkv_future_tensor & att_aa, - struct rwkv_future_tensor & att_bb, - struct rwkv_future_tensor & att_pp, - - const struct rwkv_future_tensor ln2_weight, - const struct rwkv_future_tensor ln2_bias, - const struct rwkv_future_tensor ffn_time_mix_k, - const struct rwkv_future_tensor ffn_time_mix_r, - const struct rwkv_future_tensor ffn_k, - const struct rwkv_future_tensor ffn_v, - const struct rwkv_future_tensor ffn_r, - struct rwkv_future_tensor & ffn_xx, - - const struct rwkv_future_tensor ln_out_weight, - const struct rwkv_future_tensor ln_out_bias, - const struct rwkv_future_tensor head +void rwkv_create_input_and_output_views( + struct rwkv_layer_state * inputs, + struct rwkv_layer_state * outputs, + struct ggml_tensor * input, + struct ggml_tensor * output, + struct ggml_context * ctx, + size_t n_layer, + size_t n_embed ) { - struct rwkv_future_tensor x = emb.get_rows(ctx, tokens).layer_norm(ctx, ln0_weight, ln0_bias); - for (size_t i = 0; i < n_layer; i++) { - x = x.consume(ctx, rwkv_future_att(ctx, - ln1_weight, ln1_bias, att_time_mix_k, att_time_mix_v, att_time_mix_r, att_time_first, att_time_decay, - att_r, att_k, att_v, att_output, x, att_xx, att_aa, att_bb, att_pp)); - - x = x.consume(ctx, rwkv_future_ffn(ctx, - ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); - - ffn_xx.view(ctx); - att_xx.view(ctx); - att_aa.view(ctx); - att_bb.view(ctx); - att_pp.view(ctx); + struct rwkv_layer_state & input_state = inputs[i]; + input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + input_state.att_aa = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + + struct rwkv_layer_state & output_state = outputs[i]; + output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); + output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); + output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); + output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); + output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * 
sizeof(float)); } +} - x = x.layer_norm(ctx, ln_out_weight, ln_out_bias); +// Creates and sets the input and output ggml tensors, builds the computation graph. +bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) { + graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, tokens.width); + struct rwkv_file_header & header = model.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; - return head.mul_mat(ctx, x).view(ctx); -} + struct ggml_context * ctx = graph.ggml_ctx; -bool rwkv_build_serial_graph( - struct ggml_context * ctx, - struct rwkv_model & model, - struct ggml_tensor * tokens, - struct rwkv_layer_state * inputs, - struct rwkv_layer_state * outputs, - struct ggml_tensor * logits, - struct ggml_cgraph * cgraph, + // Creates a 1-element tensor. + graph.tokens = ggml_new_i32(ctx, 0); + + struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); + + // We collect parts of output state here. Each part is (n_embed) vector. + std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); + + rwkv_create_input_and_output_views(inputs.get(), outputs.get(), input, output, ctx, n_layer, n_embed); + + graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); - size_t * const pre_logits_nodes, - size_t * const pre_logits_leafs, - size_t * const post_logits_nodes, - size_t * const post_logits_leafs -) { // x = self.w.emb.weight[token] - struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); // x = self.layer_norm(x, self.w.blocks[0].ln0) x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias); @@ -1153,123 +998,108 @@ bool rwkv_build_serial_graph( x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); - struct rwkv_layer_state & output = outputs[i]; - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.ffn_xx, output.ffn_xx)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_xx, output.att_xx)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_aa, output.att_aa)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_bb, output.att_bb)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); + struct rwkv_layer_state & output_state = outputs[i]; + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); } - *pre_logits_nodes = cgraph->n_nodes; - *pre_logits_leafs = cgraph->n_leafs; 
+ graph.pre_logits_nodes = graph.cgraph->n_nodes; + graph.pre_logits_leafs = graph.cgraph->n_leafs; // x = self.layer_norm(x[-1,:], self.w.ln_out) x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); - *post_logits_nodes = cgraph->n_nodes; - *post_logits_leafs = cgraph->n_leafs; + graph.post_logits_nodes = graph.cgraph->n_nodes; + graph.post_logits_leafs = graph.cgraph->n_leafs; + + graph.input_state = input; + graph.input_layers = std::move(inputs); + + graph.output_state = output; + graph.output_layers = std::move(outputs); return true; } -struct rwkv_future_tensor rwkv_future_sequence_graph(struct rwkv_future_ctx & ctx, - const struct rwkv_future_tensor tokens, - const size_t n_threads, - - const struct rwkv_future_tensor emb, - const struct rwkv_future_tensor ln0_weight, - const struct rwkv_future_tensor ln0_bias, - - const size_t n_layer, - - const struct rwkv_future_tensor ln1_weight, - const struct rwkv_future_tensor ln1_bias, - const struct rwkv_future_tensor att_time_mix_k, - const struct rwkv_future_tensor att_time_mix_v, - const struct rwkv_future_tensor att_time_mix_r, - const struct rwkv_future_tensor att_time_first, - const struct rwkv_future_tensor att_time_decay, - const struct rwkv_future_tensor att_r, - const struct rwkv_future_tensor att_k, - const struct rwkv_future_tensor att_v, - const struct rwkv_future_tensor att_output, - struct rwkv_future_tensor & att_xx, - struct rwkv_future_tensor & att_aa, - struct rwkv_future_tensor & att_bb, - struct rwkv_future_tensor & att_pp, - - const struct rwkv_future_tensor ln2_weight, - const struct rwkv_future_tensor ln2_bias, - const struct rwkv_future_tensor ffn_time_mix_k, - const struct rwkv_future_tensor ffn_time_mix_r, - const struct rwkv_future_tensor ffn_k, - const struct rwkv_future_tensor ffn_v, - const struct rwkv_future_tensor ffn_r, - struct rwkv_future_tensor & ffn_xx, - - const struct rwkv_future_tensor ln_out_weight, - const struct rwkv_future_tensor ln_out_bias, - const struct rwkv_future_tensor head -) { - struct rwkv_future_tensor x = emb.get_rows(ctx, tokens); - x = x.layer_norm(ctx, ln0_weight.repeat(ctx, x), ln0_bias.repeat(ctx, x)); - - for (size_t i = 0; i < n_layer; i++) { - struct rwkv_future_tensor x0 = x, x_prev; - rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x0, x_prev, att_xx); - - struct rwkv_future_tensor r, k, v; - rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, x_prev, att_r, att_k, att_v, r, k, v); - - for (size_t i = 0; i < tokens.width; i++) { - struct rwkv_future_tensor kt = k.subview(ctx, emb.width); - struct rwkv_future_tensor vt = v.subview(ctx, emb.width); - struct rwkv_future_tensor xt = x_prev.subview(ctx, emb.width); - struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, att_time_first, att_time_decay, att_aa, att_bb, att_pp, k, v); - wkv.view(ctx); - } +// Stolen from llama.cpp. +static const size_t tensor_alignment = 32; - x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, x_prev))); - x = x.consume(ctx, rwkv_future_ffn(ctx, ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); +// Prepares the computation graph for inference, measuring and allocating all input and output tensors. 
+bool rwkv_measure_and_build_serial_context(struct rwkv_model & model, struct rwkv_computation_graph & graph) { + if (graph.ggml_ctx) { + ggml_free(graph.ggml_ctx); - ffn_xx.view(ctx); - att_xx.view(ctx); - att_aa.view(ctx); - att_bb.view(ctx); - att_pp.view(ctx); + graph.ggml_ctx = NULL; } - x = x.subview(ctx, emb.width).layer_norm(ctx, ln_out_weight, ln_out_bias); + // 1. Measure the space required for the ggml context. + graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); - rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, tokens.width); + RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); - return head.mul_mat(ctx, x).view(ctx); + struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); + + size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + + rwkv_ggml_overhead() + + tensor_alignment + // For some reason, calculation above does not result in enough memory allocated. + // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. + // 64 MB seems to be enough for Raven 14B model. + + size_t(64) * 1024 * 1024; + + ggml_allocr_free(allocator); + ggml_free(graph.ggml_ctx); + + // 2. Create the real ggml context. + graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); + + RWKV_ENSURE_OR_FALSE(rwkv_build_serial_graph(model, graph)); + + return true; } -bool rwkv_build_sequence_graph( - struct ggml_context * ctx, - struct rwkv_model & model, - struct ggml_tensor * tokens, - struct rwkv_layer_state * inputs, - struct rwkv_layer_state * outputs, - struct ggml_tensor * logits, - struct ggml_cgraph * cgraph, +// --- - size_t * const pre_logits_nodes, - size_t * const pre_logits_leafs, - size_t * const post_logits_nodes, - size_t * const post_logits_leafs -) { - const uint32_t n_embed = model.header.n_embed; - const size_t sequence_len = tokens->ne[0]; +// Creates and sets the input and output ggml tensors, builds the computation graph. +bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { + graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + + struct rwkv_file_header & header = model.header; + const size_t n_vocab = header.n_vocab; + const size_t n_embed = header.n_embed; + const size_t n_layer = header.n_layer; + + struct ggml_context * ctx = graph.ggml_ctx; + + graph.tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sequence_length); + + struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + + // We collect parts of input state here. Each part is (n_embed) vector. + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); + + // We collect parts of output state here. Each part is (n_embed) vector. 
+ std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); + + rwkv_create_input_and_output_views(inputs.get(), outputs.get(), input, output, ctx, n_layer, n_embed); + + graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab); + + // x = self.w.emb.weight[token] + struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens); - struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); + // x = self.layer_norm(x, self.w.blocks[0].ln0) x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); - + for (size_t i = 0; i < model.header.n_layer; i++) { struct rwkv_layer & layer = model.layers[i]; struct rwkv_layer_state state = inputs[i]; @@ -1280,275 +1110,123 @@ bool rwkv_build_sequence_graph( struct ggml_tensor * r, * k, * v; rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); - ggml_build_forward_expand(cgraph, r); + ggml_build_forward_expand(graph.cgraph.get(), r); - for (uint32_t t = 0; t < sequence_len; t++) { + for (uint32_t t = 0; t < sequence_length; t++) { struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, wkv, xt)); } x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); - struct rwkv_layer_state & output = outputs[i]; - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.ffn_xx, output.ffn_xx)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_xx, output.att_xx)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_aa, output.att_aa)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_bb, output.att_bb)); - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp)); + struct rwkv_layer_state & output_state = outputs[i]; + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.ffn_xx, output_state.ffn_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_xx, output_state.att_xx)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_aa, output_state.att_aa)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_bb, output_state.att_bb)); + ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, state.att_pp, output_state.att_pp)); } - *pre_logits_nodes = cgraph->n_nodes; - *pre_logits_leafs = cgraph->n_leafs; + graph.pre_logits_nodes = graph.cgraph->n_nodes; + graph.pre_logits_leafs = graph.cgraph->n_leafs; // x = self.layer_norm(x[-1,:], self.w.ln_out) - x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias); + x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_length - 1)), model.ln_out_weight, model.ln_out_bias); // x = (self.w.head.weight @ x).float() - ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits)); + 
ggml_build_forward_expand(graph.cgraph.get(), ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), graph.logits)); - *post_logits_nodes = cgraph->n_nodes; - *post_logits_leafs = cgraph->n_leafs; - - return true; -} + graph.post_logits_nodes = graph.cgraph->n_nodes; + graph.post_logits_leafs = graph.cgraph->n_leafs; -void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { - bool * ptr = ctx ? &ctx->print_errors : &global_print_errors; - *ptr = print_errors; -} + graph.input_state = input; + graph.input_layers = std::move(inputs); -bool rwkv_get_print_errors(struct rwkv_context * ctx) { - return ctx ? ctx->print_errors : global_print_errors; -} + graph.output_state = output; + graph.output_layers = std::move(outputs); -enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { - enum rwkv_error_flags * ptr = ctx ? &ctx->last_error : &global_last_error; - enum rwkv_error_flags value = *ptr; - *ptr = RWKV_ERROR_NONE; - return value; + return true; } -struct rwkv_file { - FILE * file; - - rwkv_file(FILE * file): file(file) {} +// Prepares the computation graph for inference, measuring and allocating all input and output tensors. +bool rwkv_measure_and_build_sequential_context(struct rwkv_model & model, struct rwkv_computation_graph & graph, const size_t sequence_length) { + if (graph.ggml_ctx) { + ggml_free(graph.ggml_ctx); - ~rwkv_file() { - if (file) { - fclose(file); - } + graph.ggml_ctx = NULL; } -}; - -bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { - struct stat file_stat; - struct rwkv_model model; - struct rwkv_ggml_context ctx; - size_t ffn_key_size = 0; - - std::unordered_map parameters; - - { - rwkv_file file(fopen(file_path, "rb")); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path); - // Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length. - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header"); - - struct rwkv_tensor_header tensor_header; - std::string name; - struct rwkv_future_ctx future_ctx; - - while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file.file, tensor_header), "Invalid tensor header"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, tensor_header.size(), SEEK_CUR) == 0, "Failed to read tensor data"); - - future_ctx.alloc(rwkv_type_to_ggml[tensor_header.data_type], tensor_header.width, tensor_header.height); - - if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { - ffn_key_size = tensor_header.height; - } - } + // 1. Measure the space required for the ggml context. 
+ graph.ggml_ctx = rwkv_init_ggml_context(rwkv_ggml_overhead(), true); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); + RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); - ctx = future_ctx; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); + struct ggml_allocr * allocator = ggml_allocr_new_measure(tensor_alignment); - struct ggml_tensor * tensor; + size_t required_context_size = ggml_allocr_alloc_graph(allocator, graph.cgraph.get()) + + + rwkv_ggml_overhead() + + tensor_alignment + // For some reason, calculation above does not result in enough memory allocated. + // Instead of diving deep into ggml internals to debug this issue, I will just add some padding. + // 64 MB per token seems to be enough for Raven 14B model. It works for sequence_length = 5; not tested on larger lengths. + + sequence_length * 64 * 1024 * 1024; - while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params"); - parameters[std::move(name)] = tensor; - } - } + ggml_allocr_free(allocator); + ggml_free(graph.ggml_ctx); - std::unordered_map & parameters_ref = parameters; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { - struct ggml_tensor * tensor = parameters_ref[key]; - RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); - dest = tensor; - return true; - })); + // 2. Create the real ggml context. 
+ graph.ggml_ctx = rwkv_init_ggml_context(required_context_size, false); - // Verify order of dimensions - struct ggml_tensor * emb = model.emb; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); + RWKV_ENSURE_OR_FALSE(rwkv_build_sequential_graph(model, graph, sequence_length)); - instance.ctx = std::move(ctx); - instance.model = std::move(model); - instance.ffn_key_size = ffn_key_size; return true; } -struct rwkv_context * rwkv_new_context_impl(std::shared_ptr instance, const uint32_t n_threads) { - global_last_error = RWKV_ERROR_NONE; +// --- - struct rwkv_file_header & header = instance->model.header; - const size_t n_vocab = header.n_vocab; - const size_t n_embed = header.n_embed; - const size_t n_layer = header.n_layer; - - struct rwkv_future_ctx future_ctx; - const struct rwkv_future_tensor future_input = future_ctx.alloc(GGML_TYPE_F32, n_embed * 5 * n_layer); - const struct rwkv_future_tensor future_output = future_ctx.alloc(GGML_TYPE_F32, n_embed * 5 * n_layer); - const struct rwkv_future_tensor future_logits = future_ctx.alloc(GGML_TYPE_F32, n_vocab); - - for (size_t i = 0; i < n_layer; i++) { - /* ffn_xx */ future_input.subview(future_ctx, n_embed); future_output.subview(future_ctx, n_embed); - /* att_xx */ future_input.subview(future_ctx, n_embed); future_output.subview(future_ctx, n_embed); - /* att_aa */ future_input.subview(future_ctx, n_embed); future_output.subview(future_ctx, n_embed); - /* att_bb */ future_input.subview(future_ctx, n_embed); future_output.subview(future_ctx, n_embed); - /* att_pp */ future_input.subview(future_ctx, n_embed); future_output.subview(future_ctx, n_embed); - } - - struct rwkv_ggml_context ctx(future_ctx); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); +struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { + global_last_error = RWKV_ERROR_NONE; - struct ggml_tensor * input = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); - struct ggml_tensor * output = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_embed * 5 * n_layer); + std::unique_ptr ctx(new(std::nothrow) struct rwkv_context()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to allocate rwkv_context"); - // We collect parts of input state here. Each part is (n_embed) vector. - std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); + ctx->model = new(std::nothrow) struct rwkv_model(); + ctx->model->reference_count++; + RWKV_ENSURE_OR_NULL(rwkv_load_model_from_file(file_path, *ctx->model)); - // We collect parts of output state here. Each part is (n_embed) vector. 
- std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); + ctx->n_threads = n_threads; - for (size_t i = 0; i < n_layer; i++) { - struct rwkv_layer_state & input_state = inputs[i]; - input_state.ffn_xx = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - input_state.att_xx = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - input_state.att_aa = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - input_state.att_bb = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - input_state.att_pp = ggml_view_1d(ctx.ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); + RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*ctx->model, ctx->serial_graph)); - struct rwkv_layer_state & output_state = outputs[i]; - output_state.ffn_xx = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float)); - output_state.att_xx = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float)); - output_state.att_aa = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float)); - output_state.att_bb = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float)); - output_state.att_pp = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float)); - } + return ctx.release(); +} - struct ggml_tensor * logits = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_vocab); - - struct rwkv_future_ctx graph_future_ctx; - const struct rwkv_future_tensor future_token = graph_future_ctx.alloc(GGML_TYPE_I32, 1, 1, false); - - const struct rwkv_model & model = instance->model; - const struct rwkv_layer & layer = model.layers[0]; - const struct rwkv_layer_state & state = inputs[0]; - struct rwkv_future_tensor ffn_xx = state.ffn_xx; - struct rwkv_future_tensor att_xx = state.att_xx; - struct rwkv_future_tensor att_aa = state.att_aa; - struct rwkv_future_tensor att_bb = state.att_bb; - struct rwkv_future_tensor att_pp = state.att_pp; - - const struct rwkv_future_tensor future_graph = rwkv_future_serial_graph(graph_future_ctx, future_token, n_threads, - model.emb, - model.ln0_weight, model.ln0_bias, - - n_layer, - layer.ln1_weight, layer.ln1_bias, - layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, - layer.att_time_first, layer.att_time_decay, - layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, - att_xx, att_aa, att_bb, att_pp, - - layer.ln2_weight, layer.ln2_bias, - layer.ffn_time_mix_k, layer.ffn_time_mix_r, - layer.ffn_key, layer.ffn_value, layer.ffn_receptance, - ffn_xx, - - model.ln_out_weight, model.ln_out_weight, - model.head - ); +struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) { + std::unique_ptr clone(new(std::nothrow) struct rwkv_context()); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, clone, "Failed to allocate rwkv_context"); - struct rwkv_graph serial_graph; - serial_graph.ctx = graph_future_ctx; - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, serial_graph.ctx.ctx, "Failed to allocate serial graph context"); - serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0); - serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph"); - serial_graph.cgraph->n_threads = n_threads; - - 
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph( - serial_graph.ctx.ctx, instance->model, - serial_graph.tokens, inputs.get(), outputs.get(), logits, - serial_graph.cgraph.get(), - &serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs - )); - - std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context"); - rwkv_ctx->instance = std::move(instance); - rwkv_ctx->ctx = std::move(ctx); - rwkv_ctx->input_state = input; - rwkv_ctx->input_layers = std::move(inputs); - rwkv_ctx->output_state = output; - rwkv_ctx->output_layers = std::move(outputs); - rwkv_ctx->logits = logits; - rwkv_ctx->n_threads = n_threads; - rwkv_ctx->serial_graph = std::move(serial_graph); - rwkv_ctx->last_error = RWKV_ERROR_NONE; - rwkv_ctx->print_errors = global_print_errors; - return rwkv_ctx.release(); -} + clone->model = ctx->model; + clone->model->reference_count++; -struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { - global_last_error = RWKV_ERROR_NONE; + clone->n_threads = n_threads; - std::shared_ptr instance(new(std::nothrow) struct rwkv_instance()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance"); - RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get())); - return rwkv_new_context_impl(instance, n_threads); -} + RWKV_ENSURE_OR_NULL(rwkv_measure_and_build_serial_context(*clone->model, clone->serial_graph)); -struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) { - struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads); + clone->last_used_sequence_length = 0; - if (clone) { - clone->print_errors = ctx->print_errors; - } + clone->print_errors = ctx->print_errors; - return clone; + return clone.release(); } bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const auto offload = [&](struct ggml_tensor * tensor) { - // TODO support multi-GPU + // TODO Support multi-GPU tensor->backend = GGML_BACKEND_GPU; #ifdef GGML_USE_CUBLAS ggml_cuda_transform_tensor(tensor->data, tensor); @@ -1557,13 +1235,13 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) #endif }; - const size_t n_gpu = std::min(n_layers, ctx->instance->model.header.n_layer); + const size_t n_gpu = std::min(n_layers, ctx->model->header.n_layer); - if (ctx->gpu_layers < n_gpu) { - for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) { - const struct rwkv_layer & layer = ctx->instance->model.layers[i]; + if (ctx->model->offloaded_layer_count < n_gpu) { + for (size_t & i = ctx->model->offloaded_layer_count; i < n_gpu; i++) { + const struct rwkv_layer & layer = ctx->model->layers[i]; - // TODO also offload other operations to GPU with ggml_cuda_assign_buffers + // TODO Also offload other operations to GPU with ggml_cuda_assign_buffers offload(layer.att_key); offload(layer.att_value); offload(layer.att_receptance); @@ -1580,134 +1258,97 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) return false; } -void rwkv_set_inputs(const struct rwkv_context * ctx, const float * state_in) { +void rwkv_set_inputs(const struct rwkv_context * ctx, const struct rwkv_computation_graph & graph, const float * state_in) { if (state_in) { - 
memcpy(ctx->input_state->data, state_in, ggml_nbytes(ctx->input_state)); + memcpy(graph.input_state->data, state_in, rwkv_tensor_nbytes(graph.input_state)); } else { - rwkv_init_state(ctx, (float *) ctx->input_state->data); + rwkv_init_state(ctx, (float *) graph.input_state->data); } } -void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float * logits_out) { +void rwkv_get_outputs(const struct rwkv_computation_graph & graph, float * state_out, float * logits_out) { if (state_out) { - memcpy(state_out, ctx->output_state->data, ggml_nbytes(ctx->output_state)); + memcpy(state_out, graph.output_state->data, rwkv_tensor_nbytes(graph.output_state)); } if (logits_out) { - memcpy(logits_out, ctx->logits->data, ggml_nbytes(ctx->logits)); + memcpy(logits_out, graph.logits->data, rwkv_tensor_nbytes(graph.logits)); } } +void rwkv_eval_graph(struct rwkv_computation_graph & graph, const uint32_t n_threads, const bool compute_logits) { + // Short circuit computation of logits if they are not needed. + if (!compute_logits) { + graph.cgraph->n_nodes = graph.pre_logits_nodes; + graph.cgraph->n_leafs = graph.pre_logits_leafs; + } else { + graph.cgraph->n_nodes = graph.post_logits_nodes; + graph.cgraph->n_leafs = graph.post_logits_leafs; + } + + struct ggml_cplan * plan = ggml_graph_plan(graph.cgraph.get(), n_threads); + + std::unique_ptr work_data{ new(std::nothrow) uint8_t[plan->work_size] }; + plan->work_data = work_data.get(); + + ggml_graph_compute(graph.cgraph.get(), plan); + + free(plan); +} + bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) { ctx->last_error = RWKV_ERROR_NONE; - const struct rwkv_file_header & header = ctx->instance->model.header; + const struct rwkv_file_header & header = ctx->model->header; const size_t n_vocab = header.n_vocab; RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token (%" PRId32 ") is out of range (0 .. 
%zu)", token, n_vocab - 1); - rwkv_set_inputs(ctx, state_in); + rwkv_set_inputs(ctx, ctx->serial_graph, state_in); ggml_set_i32(ctx->serial_graph.tokens, token); - // Short circuit computation of logits if nobody actually cares - if (!logits_out) { - ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes; - ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs; - } else { - ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes; - ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs; - } + rwkv_eval_graph(ctx->serial_graph, ctx->n_threads, logits_out != NULL); - ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get()); - rwkv_get_outputs(ctx, state_out, logits_out); + rwkv_get_outputs(ctx->serial_graph, state_out, logits_out); return true; } -bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) { +bool rwkv_eval_sequence( + struct rwkv_context * ctx, + const uint32_t * sequence, + const size_t sequence_len, + const float * state_in, + float * state_out, + float * logits_out +) { ctx->last_error = RWKV_ERROR_NONE; - const struct rwkv_file_header & header = ctx->instance->model.header; - const size_t n_vocab = header.n_vocab; - const size_t n_embed = header.n_embed; - const size_t n_layer = header.n_layer; + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0"); + + const size_t n_vocab = ctx->model->header.n_vocab; if (sequence) { for (size_t i = 0; i < sequence_len; i++) { const uint32_t token = sequence[i]; + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < n_vocab, "Token at index %zu (%" PRId32 ") is out of range (0 .. 
%zu)", i, token, n_vocab - 1); } } - if (ctx->sequence_len != sequence_len) { - // Build new sequence graph - - struct rwkv_future_ctx graph_future_ctx; - const struct rwkv_future_tensor future_tokens = graph_future_ctx.alloc(GGML_TYPE_I32, sequence_len); - - const struct rwkv_model & model = ctx->instance->model; - const struct rwkv_layer & layer = model.layers[0]; - const struct rwkv_layer_state & state = ctx->input_layers[0]; - struct rwkv_future_tensor ffn_xx = state.ffn_xx; - struct rwkv_future_tensor att_xx = state.att_xx; - struct rwkv_future_tensor att_aa = state.att_aa; - struct rwkv_future_tensor att_bb = state.att_bb; - struct rwkv_future_tensor att_pp = state.att_pp; - - const struct rwkv_future_tensor future_graph = rwkv_future_sequence_graph(graph_future_ctx, future_tokens, 1, - model.emb, - model.ln0_weight, model.ln0_bias, - - n_layer, - layer.ln1_weight, layer.ln1_bias, - layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, - layer.att_time_first, layer.att_time_decay, - layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, - att_xx, att_aa, att_bb, att_pp, - - layer.ln2_weight, layer.ln2_bias, - layer.ffn_time_mix_k, layer.ffn_time_mix_r, - layer.ffn_key, layer.ffn_value, layer.ffn_receptance, - ffn_xx, - - model.ln_out_weight, model.ln_out_weight, - model.head - ); + if (ctx->last_used_sequence_length != sequence_len) { + RWKV_ENSURE_OR_FALSE(rwkv_measure_and_build_sequential_context(*ctx->model, ctx->sequential_graph, sequence_len)); - struct rwkv_graph sequence_graph; - sequence_graph.ctx = graph_future_ctx; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, sequence_graph.ctx.ctx, "Failed to allocate sequence graph context"); - sequence_graph.tokens = ggml_new_tensor_1d(sequence_graph.ctx.ctx, GGML_TYPE_I32, sequence_len); - sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph"); - sequence_graph.cgraph->n_threads = 1; - - RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph( - sequence_graph.ctx.ctx, ctx->instance->model, - sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, - sequence_graph.cgraph.get(), - &sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs - )); - - ctx->sequence_len = sequence_len; - ctx->sequence_graph = std::move(sequence_graph); + ctx->last_used_sequence_length = sequence_len; } // Allow building the sequence graph without actually evaluating, by specifying sequence = NULL. 
if (sequence) { - rwkv_set_inputs(ctx, state_in); - memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t)); + rwkv_set_inputs(ctx, ctx->sequential_graph, state_in); + memcpy(ctx->sequential_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t)); - // Short circuit computation of logits if nobody actually cares - if (!logits_out) { - ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes; - ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs; - } else { - ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes; - ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs; - } + rwkv_eval_graph(ctx->sequential_graph, ctx->n_threads, logits_out != NULL); - ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get()); - rwkv_get_outputs(ctx, state_out, logits_out); + rwkv_get_outputs(ctx->sequential_graph, state_out, logits_out); } return true; @@ -1724,28 +1365,29 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r } extern "C" RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { - return (size_t) ctx->instance->model.header.n_vocab; + return (size_t) ctx->model->header.n_vocab; } extern "C" RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { - return (size_t) ctx->instance->model.header.n_embed; + return (size_t) ctx->model->header.n_embed; } extern "C" RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { - return (size_t) ctx->instance->model.header.n_layer; + return (size_t) ctx->model->header.n_layer; } size_t rwkv_get_state_len(const struct rwkv_context * ctx) { - const struct rwkv_file_header & header = ctx->instance->model.header; + const struct rwkv_file_header & header = ctx->model->header; + return (size_t) header.n_embed * 5 * (size_t) header.n_layer; } size_t rwkv_get_logits_len(const struct rwkv_context * ctx) { - return (size_t) ctx->instance->model.header.n_vocab; + return (size_t) ctx->model->header.n_vocab; } void rwkv_init_state(const struct rwkv_context * ctx, float * state) { - const struct rwkv_file_header & header = ctx->instance->model.header; + const struct rwkv_file_header & header = ctx->model->header; const size_t layer_size = (size_t) header.n_embed * 5; const size_t layer_zero = (size_t) header.n_embed * 4; const size_t layers_size = (size_t) header.n_layer * layer_size; @@ -1762,6 +1404,18 @@ void rwkv_init_state(const struct rwkv_context * ctx, float * state) { } void rwkv_free(struct rwkv_context * ctx) { + if (--ctx->model->reference_count == 0) { + ggml_free(ctx->model->ggml_ctx); + + delete ctx->model; + } + + ggml_free(ctx->serial_graph.ggml_ctx); + + if (ctx->last_used_sequence_length > 0) { + ggml_free(ctx->sequential_graph.ggml_ctx); + } + std::unique_ptr rwkv_ctx(ctx); } @@ -1828,14 +1482,14 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const max_out_size = in_size; } - size_t f32_size = rwkv_future_tensor::size(GGML_TYPE_F32, header.width, header.height); + size_t f32_size = rwkv_tensor_nbytes(GGML_TYPE_F32, header.width, header.height); if (f32_size > max_in_size) { max_in_size = f32_size; } } - size_t out_size = rwkv_future_tensor::size(out_type, header.width, header.height); + size_t out_size = rwkv_tensor_nbytes(out_type, header.width, header.height); if (out_size > max_out_size) { max_out_size = out_size; diff --git a/rwkv.h b/rwkv.h index b1ada36..87bca55 100644 --- a/rwkv.h +++ b/rwkv.h @@ -103,10 
+103,9 @@ extern "C" {
     RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);
     // Evaluates the model for a single token.
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
     // Returns false on any error.
-    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
-    // that you do not calculate logits.
     // - token: next token index, in range 0 <= token < n_vocab.
     // - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
@@ -114,12 +113,23 @@ extern "C" {
     RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
     // Evaluates the model for a sequence of tokens.
-    // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
+    // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
     // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
+    //
+    // NOTE ON GGML NODE LIMIT
+    //
+    // ggml has a hard-coded limit on the maximum number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+    // this limit when using large models and/or large sequence lengths.
+    // Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 for 14B models.
+    //
+    // If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+    // To get rid of the assertion failure, reduce the model size and/or sequence length.
+    //
+    // TODO When Metal (MPS) support is implemented, check that large sequence lengths work
+    //
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
     // Returns false on any error.
-    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
-    // that you do not calculate logits.
     // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
     // - sequence_len: number of tokens to read from the array.
     // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
diff --git a/rwkv/quantize.py b/rwkv/quantize.py
index 305573b..fe45da6 100644
--- a/rwkv/quantize.py
+++ b/rwkv/quantize.py
@@ -1,6 +1,6 @@
 # Quantizes rwkv.cpp model file from FP32 or FP16.
 # Available format names are in rwkv_cpp_shared_library.QUANTIZED_FORMAT_NAMES
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
+# Usage: python quantize.py C:\rwkv.cpp-169M-FP32.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
 import argparse
 import rwkv_cpp_shared_library
diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py
index 0a8a842..e612dd1 100644
--- a/rwkv/rwkv_cpp_model.py
+++ b/rwkv/rwkv_cpp_model.py
@@ -131,6 +131,16 @@ def eval_sequence(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Evaluates the model for a sequence of tokens.
+
+        NOTE ON GGML NODE LIMIT
+
+        ggml has a hard-coded limit on the maximum number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+        this limit when using large models and/or large sequence lengths.
+        Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 for 14B models.
+
+        If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+        To get rid of the assertion failure, reduce the model size and/or sequence length.
+
         In case of any error, this method will throw an exception.
         Parameters
diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py
index 718b697..edc4736 100644
--- a/rwkv/rwkv_cpp_shared_library.py
+++ b/rwkv/rwkv_cpp_shared_library.py
@@ -139,6 +139,16 @@ def rwkv_eval(
     ) -> None:
         """
         Evaluates the model for a single token.
+
+        NOTE ON GGML NODE LIMIT
+
+        ggml has a hard-coded limit on the maximum number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
+        this limit when using large models and/or large sequence lengths.
+        Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 for 14B models.
+
+        If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
+        To get rid of the assertion failure, reduce the model size and/or sequence length.
+
         Throws an exception in case of any error. Error messages would be printed to stderr.
         Parameters
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index a48dda2..8705d94 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -15,8 +15,13 @@ endfunction()
 file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-660K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+file(COPY tiny-rwkv-660K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
 rwkv_add_test(test_ggml_basics.c)
+rwkv_add_test(test_quantized_matmul_on_gpu.c)
 rwkv_add_test(test_tiny_rwkv.c)
+rwkv_add_test(test_quantization_format_compatibility.c)
+rwkv_add_test(test_logit_calculation_skipping.c)
 rwkv_add_test(test_context_cloning.c)
diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc
new file mode 100644
index 0000000..d449004
--- /dev/null
+++ b/tests/logit_difference_validator.inc
@@ -0,0 +1,82 @@
+// TODO Move to inc
+#define ASSERT(x, ...)
{\ + if (!(x)) {\ + fprintf(stderr, "*** Assertion failed ***\n");\ + fprintf(stderr, __VA_ARGS__);\ + fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ + abort();\ + }\ + } + +// RWKV Tiny is a byte-level model +#define N_VOCAB 256 +// Also test multithreading +#define N_THREADS 2 + +void load_expected_logits(float * expected_logits) { + FILE * file = fopen("expected_logits.bin", "rb"); + ASSERT(file != NULL, "Failed to open expected_logits.bin"); + size_t elements_read = fread(expected_logits, sizeof(float), N_VOCAB, file); + ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read); + fclose(file); +} + +void test_model(const char * model_path, const float * expected_logits, const float max_diff) { + fprintf(stderr, "Testing %s\n", model_path); + + struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + ASSERT(error == 0, "Unexpected error %d", error); + +#ifdef GGML_USE_CUBLAS + ASSERT(rwkv_gpu_offload_layers(model, rwkv_get_n_layer(model)), "Failed to offload layers to GPU"); +#endif + + const size_t n_vocab = rwkv_get_logits_len(model); + + ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model"); + + float * state = malloc(sizeof(float) * rwkv_get_state_len(model)); + float * logits = malloc(sizeof(float) * n_vocab); + + char * prompt = "\"in"; + uint32_t prompt_seq[] = { '"', 'i', 'n' }; + + const size_t prompt_length = strlen(prompt); + + rwkv_init_state(model, state); + + for (size_t i = 0; i < prompt_length; i++) { + rwkv_eval(model, prompt[i], state, state, logits); + } + + float diff_sum = 0.0F; + + for (uint32_t i = 0; i < n_vocab; i++) { + diff_sum += logits[i] - expected_logits[i]; + } + + fprintf(stderr, "Difference sum: %f\n", diff_sum); + + // When something breaks, difference would be way more than 10 + ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + + rwkv_init_state(model, state); + rwkv_eval_sequence(model, prompt_seq, prompt_length, state, state, logits); + + diff_sum = 0.0F; + + for (uint32_t i = 0; i < n_vocab; i++) { + diff_sum += logits[i] - expected_logits[i]; + } + + fprintf(stderr, "Sequence difference sum: %f\n", diff_sum); + + // When something breaks, difference would be way more than 10 + ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); + + rwkv_free(model); + + free(state); + free(logits); +} diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index eb0f7c4..e911f98 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -1,3 +1,4 @@ +// Tests that evaluation works after the context was cloned. 
#include #include @@ -17,7 +18,7 @@ int main() { float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); if (!state || !logits) { - fprintf(stderr, "Failed to allocate state/logits\n"); + fprintf(stderr, "Failed to allocate state or logits\n"); return EXIT_FAILURE; } @@ -34,26 +35,33 @@ int main() { logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); if (!logits) { - fprintf(stderr, "Failed to allocate state/logits\n"); + fprintf(stderr, "Failed to allocate logits\n"); return EXIT_FAILURE; } struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); - rwkv_eval(ctx, prompt[0], NULL, state, logits); + if (ctx == ctx2) { + fprintf(stderr, "Same context was returned\n"); + return EXIT_FAILURE; + } + + // The cloned context should work fine after the original context was freed. + rwkv_free(ctx); + + rwkv_eval(ctx2, prompt[0], NULL, state, logits); for (int i = 1; prompt[i] != 0; i++) { - rwkv_eval(ctx, prompt[i], state, state, logits); + rwkv_eval(ctx2, prompt[i], state, state, logits); } - if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx) * sizeof(float))) { - fprintf(stderr, "Results not identical :(\n"); + if (memcmp(expected_logits, logits, rwkv_get_logits_len(ctx2) * sizeof(float))) { + fprintf(stderr, "Results are not identical :(\n"); return EXIT_FAILURE; } else { - fprintf(stdout, "Results identical, success!\n"); + fprintf(stdout, "Results are identical, success!\n"); } - rwkv_free(ctx); rwkv_free(ctx2); free(expected_logits); @@ -61,4 +69,4 @@ int main() { free(state); return EXIT_SUCCESS; -} \ No newline at end of file +} diff --git a/tests/test_ggml_basics.c b/tests/test_ggml_basics.c index a31687a..d99cab2 100644 --- a/tests/test_ggml_basics.c +++ b/tests/test_ggml_basics.c @@ -1,6 +1,5 @@ // Tests that ggml basics work. - -#include "ggml.h" +#include #include #include @@ -8,6 +7,7 @@ #define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value +// TODO Move to inc #define ASSERT(x, ...) {\ if (!(x)) {\ fprintf(stderr, "*** Assertion failed ***\n");\ @@ -22,7 +22,8 @@ ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, (double) expected_value, (double) actual);\ } -int main(void) { +// Tests simple computation in a single context. +static void test_computation(void) { struct ggml_init_params params = { .mem_size = 16 * 1024, .mem_buffer = NULL, @@ -45,9 +46,13 @@ int main(void) { struct ggml_tensor * sum = ggml_add(ctx, x, y); - struct ggml_cgraph graph = ggml_build_forward(sum); - graph.n_threads = 2; - ggml_graph_compute(ctx, &graph); + // Allocation on heap instead of stack avoids SegFault when GGML_MAX_NODES is set to a large value. + struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph)); + ggml_build_forward_expand(graph, sum); + struct ggml_cplan * plan = ggml_graph_plan(graph, 2); + ggml_graph_compute(graph, plan); + free(plan); + free(graph); ASSERT_ELEMENT_F32(sum, 0, -9.0F); ASSERT_ELEMENT_F32(sum, 1, 2.0F); @@ -57,6 +62,51 @@ int main(void) { ggml_print_objects(ctx); ggml_free(ctx); +} + +// Tests that operations on tensors from different contexts work. +// RWKV model loading code depends on this behavior. 
+static void test_tensors_from_different_contexts(void) { + struct ggml_init_params params = { + .mem_size = 16 * 1024, + .mem_buffer = NULL, + .no_alloc = false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context * ctx1 = ggml_init(params); + struct ggml_context * ctx2 = ggml_init(params); + + struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 4); + SET_ELEMENT_F32(x, 0, -10.0F); + SET_ELEMENT_F32(x, 1, 0.0F); + + struct ggml_tensor * y = ggml_new_tensor_1d(ctx1, GGML_TYPE_F32, 4); + SET_ELEMENT_F32(y, 0, 1.0F); + SET_ELEMENT_F32(y, 1, 2.0F); + + struct ggml_tensor * sum = ggml_add(ctx2, x, y); + + // Allocation on heap instead of stack avoids SegFault when GGML_MAX_NODES is set to a large value. + struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph)); + ggml_build_forward_expand(graph, sum); + struct ggml_cplan * plan = ggml_graph_plan(graph, 2); + ggml_graph_compute(graph, plan); + free(plan); + free(graph); + + ASSERT_ELEMENT_F32(sum, 0, -9.0F); + ASSERT_ELEMENT_F32(sum, 1, 2.0F); + + ggml_free(ctx0); + ggml_free(ctx1); + ggml_free(ctx2); +} + +int main(void) { + test_computation(); + + test_tensors_from_different_contexts(); return 0; } diff --git a/tests/test_logit_calculation_skipping.c b/tests/test_logit_calculation_skipping.c new file mode 100644 index 0000000..cb7239d --- /dev/null +++ b/tests/test_logit_calculation_skipping.c @@ -0,0 +1,132 @@ +// Tests that evaluation works when the logits parameter was set to NULL. +#include + +#include +#include +#include + +#define TOKEN_COUNT 11 + +static const unsigned char prompt[TOKEN_COUNT + 1] = "hello world"; + +static int test_serial_mode() { + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + + if (!ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + + if (!state || !logits) { + fprintf(stderr, "Failed to allocate state or logits\n"); + return EXIT_FAILURE; + } + + rwkv_eval(ctx, prompt[0], NULL, state, logits); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx, prompt[i], state, state, logits); + } + + float * expected_state = state; + + state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + + if (!state) { + fprintf(stderr, "Failed to allocate state\n"); + return EXIT_FAILURE; + } + + rwkv_eval(ctx, prompt[0], NULL, state, NULL); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(ctx, prompt[i], state, state, NULL); + } + + if (memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float))) { + fprintf(stderr, "Serial mode: results are not identical :(\n"); + return EXIT_FAILURE; + } else { + fprintf(stdout, "Serial mode: results are identical, success!\n"); + } + + rwkv_free(ctx); + + free(logits); + free(state); + free(expected_state); + + return EXIT_SUCCESS; +} + +static int test_sequential_mode() { + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + + if (!ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + float * state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float)); + + if (!state || !logits) { + fprintf(stderr, "Failed to allocate state or logits\n"); + return 
EXIT_FAILURE; + } + + uint32_t prompt_tokens[TOKEN_COUNT]; + + for (int i = 0; i < TOKEN_COUNT; i++) { + prompt_tokens[i] = prompt[i]; + } + + rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, logits); + + float * expected_state = state; + + state = calloc(rwkv_get_state_len(ctx), sizeof(float)); + + if (!state) { + fprintf(stderr, "Failed to allocate state\n"); + return EXIT_FAILURE; + } + + rwkv_eval_sequence(ctx, prompt_tokens, TOKEN_COUNT, NULL, state, NULL); + + if (memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float))) { + fprintf(stderr, "Sequential mode: results are not identical :(\n"); + return EXIT_FAILURE; + } else { + fprintf(stdout, "Sequential mode: results are identical, success!\n"); + } + + rwkv_free(ctx); + + free(logits); + free(state); + free(expected_state); + + return EXIT_SUCCESS; +} + +int main() { + int result = test_serial_mode(); + + if (result != EXIT_SUCCESS) { + return result; + } + + result = test_sequential_mode(); + + if (result != EXIT_SUCCESS) { + return result; + } + + return EXIT_SUCCESS; +} diff --git a/tests/test_quantization_format_compatibility.c b/tests/test_quantization_format_compatibility.c new file mode 100644 index 0000000..0c3c4e3 --- /dev/null +++ b/tests/test_quantization_format_compatibility.c @@ -0,0 +1,23 @@ +// Tests that existing Q5_0 & Q5_1 model files are still working. +#include + +#include +#include +#include +#include + +#include "logit_difference_validator.inc" + +int main(void) { + fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); + + float * expected_logits = malloc(sizeof(float) * N_VOCAB); + load_expected_logits(expected_logits); + + test_model("tiny-rwkv-660K-Q5_0.bin", expected_logits, -0.170404F); + test_model("tiny-rwkv-660K-Q5_1.bin", expected_logits, 0.278034F); + + free(expected_logits); + + return 0; +} diff --git a/tests/test_quantized_matmul_on_gpu.c b/tests/test_quantized_matmul_on_gpu.c new file mode 100644 index 0000000..cae8748 --- /dev/null +++ b/tests/test_quantized_matmul_on_gpu.c @@ -0,0 +1,93 @@ +// Tests that quantized matmul on GPU works. +#include + +#include +#include +#include + +// TODO Move to inc +#define ASSERT(x, ...) 
{\ + if (!(x)) {\ + fprintf(stderr, "*** Assertion failed ***\n");\ + fprintf(stderr, __VA_ARGS__);\ + fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ + abort();\ + }\ + } + +#define SET_ELEMENT_F32(tensor, i, value) ((float *) tensor->data)[i] = value + +#define ELEMENT_COUNT 32 + +int main(void) { + #ifdef GGML_USE_CUBLAS + + struct ggml_init_params params = { + .mem_size = 16 * 1024, + .mem_buffer = NULL, + .no_alloc = false, + }; + + struct ggml_context * ctx = ggml_init(params); + + // --- + + struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ELEMENT_COUNT, 1); + + for (int i = 0; i < ELEMENT_COUNT; i++) { + SET_ELEMENT_F32(x, i, 1.0F * i); + } + + // --- + + struct ggml_tensor * x_quantized = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, ELEMENT_COUNT, 1); + + int64_t hist[16]; + ggml_quantize_chunk(x_quantized->type, (const float *) x->data, x_quantized->data, 0, ELEMENT_COUNT, hist); + + x_quantized->backend = GGML_BACKEND_GPU; + ggml_cuda_transform_tensor(x_quantized->data, x_quantized); + + // --- + + struct ggml_tensor * y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ELEMENT_COUNT); + + for (int i = 0; i < ELEMENT_COUNT; i++) { + SET_ELEMENT_F32(y, i, 1.0F * i); + } + + // --- + + struct ggml_tensor * mul0 = ggml_mul_mat(ctx, x, y); + struct ggml_tensor * mul1 = ggml_mul_mat(ctx, x_quantized, y); + + // Allocation on heap instead of stack avoids SegFault when GGML_MAX_NODES is set to a large value. + struct ggml_cgraph * graph = (struct ggml_cgraph *) calloc(1, sizeof(struct ggml_cgraph)); + ggml_build_forward_expand(graph, mul0); + ggml_build_forward_expand(graph, mul1); + + struct ggml_cplan * plan = ggml_graph_plan(graph, 2); + + uint8_t * work_data = (uint8_t *) malloc(plan->work_size); + plan->work_data = work_data; + + ggml_graph_compute(graph, plan); + + free(plan); + free(graph); + free(work_data); + + float result0 = ((float *) mul0->data)[0]; + float result1 = ((float *) mul1->data)[0]; + + fprintf(stderr, "FP32 CPU result = %f\n", result0); + fprintf(stderr, "Q4_0 GPU result = %f\n", result1); + + ASSERT(fabsf(result0 - result1) <= 100.0F, "Results differ too much"); + + ggml_free(ctx); + + #endif + + return 0; +} diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 2a12329..c8cbbd0 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -1,96 +1,18 @@ // Tests that tiny RWKV outputs expected results in all data types. - -#include "ggml.h" -#include "rwkv.h" +#include #include #include #include #include -#define ASSERT(x, ...) 
{\ - if (!(x)) {\ - fprintf(stderr, "*** Assertion failed ***\n");\ - fprintf(stderr, __VA_ARGS__);\ - fprintf(stderr, "\n%s:%d\n", __FILE__, __LINE__);\ - abort();\ - }\ - } - -// --- - -#define N_VOCAB 256 -#define N_THREADS 2 - -void test_model(const char * model_path, const float * expected_logits, const float max_diff) { - fprintf(stderr, "Testing %s\n", model_path); - - struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); - enum rwkv_error_flags error = rwkv_get_last_error(NULL); - ASSERT(error == 0, "Unexpected error %d", error); - -#ifdef GGML_USE_CUBLAS - ASSERT(rwkv_gpu_offload_layers(model, rwkv_get_n_layer(model)), "Failed to offload layers to GPU"); -#endif - - const size_t n_vocab = rwkv_get_logits_len(model); - - ASSERT(n_vocab == N_VOCAB, "Unexpected n_vocab in the model"); - - float * state = malloc(sizeof(float) * rwkv_get_state_len(model)); - float * logits = malloc(sizeof(float) * n_vocab); - - char * prompt = "\"in"; - uint32_t prompt_seq[] = { '"', 'i', 'n' }; - - const size_t prompt_length = strlen(prompt); - - rwkv_init_state(model, state); - - for (size_t i = 0; i < prompt_length; i++) { - rwkv_eval(model, prompt[i], state, state, logits); - } - - float diff_sum = 0.0F; - - for (uint32_t i = 0; i < n_vocab; i++) { - diff_sum += logits[i] - expected_logits[i]; - } - - fprintf(stderr, "Difference sum: %f\n", diff_sum); - - // When something breaks, difference would be way more than 10 - ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); - - rwkv_init_state(model, state); - rwkv_eval_sequence(model, prompt_seq, prompt_length, state, state, logits); - - diff_sum = 0.0F; - - for (uint32_t i = 0; i < n_vocab; i++) { - diff_sum += logits[i] - expected_logits[i]; - } - - fprintf(stderr, "Sequence difference sum: %f\n", diff_sum); - - // When something breaks, difference would be way more than 10 - ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big sequence difference %f, expected no more than %f", (double) diff_sum, (double) max_diff); - - rwkv_free(model); - - free(state); - free(logits); -} +#include "logit_difference_validator.inc" int main(void) { fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string()); float * expected_logits = malloc(sizeof(float) * N_VOCAB); - FILE * file = fopen("expected_logits.bin", "rb"); - ASSERT(file != NULL, "Failed to open expected_logits.bin"); - size_t elements_read = fread(expected_logits, sizeof(float), N_VOCAB, file); - ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read); - fclose(file); + load_expected_logits(expected_logits); // Somehow when using cuBLAS the calculation of Q4_1 may different from cpu only float expected_difference_sum[14] = { @@ -99,7 +21,7 @@ int main(void) { -0.160030F, #ifdef GGML_USE_CUBLAS - -0.412408F, + -0.547409F, #else -0.370606F, #endif @@ -109,7 +31,7 @@ int main(void) { 0.154614F, #ifdef GGML_USE_CUBLAS - -0.405527F, + -0.539827F, #else -0.372169F, #endif diff --git a/tests/tiny-rwkv-660K-Q5_0.bin b/tests/tiny-rwkv-660K-Q5_0.bin new file mode 100644 index 0000000..682ebaa Binary files /dev/null and b/tests/tiny-rwkv-660K-Q5_0.bin differ diff --git a/tests/tiny-rwkv-660K-Q5_1.bin b/tests/tiny-rwkv-660K-Q5_1.bin new file mode 100644 index 0000000..6bcbfc4 Binary files /dev/null and b/tests/tiny-rwkv-660K-Q5_1.bin differ
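
For reference, a minimal usage sketch of the public API touched by this change: context cloning, passing NULL for logits_out to skip logit calculation, and sequence evaluation. This is not part of the patch; the model path "model.bin" and the thread counts are placeholders, and error handling is condensed.

// Sketch only, assuming a valid rwkv.cpp model file at a placeholder path.
#include "rwkv.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4);

    if (!ctx) {
        fprintf(stderr, "Unexpected error 0x%.8X\n", rwkv_get_last_error(NULL));
        return EXIT_FAILURE;
    }

    float * state = calloc(rwkv_get_state_len(ctx), sizeof(float));
    float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

    // Serial mode: feed the prompt token by token; pass NULL for logits_out
    // to skip logit calculation until the last token.
    const uint32_t prompt[3] = { '"', 'i', 'n' };
    rwkv_eval(ctx, prompt[0], NULL, state, NULL);
    rwkv_eval(ctx, prompt[1], state, state, NULL);
    rwkv_eval(ctx, prompt[2], state, state, logits);

    // Sequential mode: evaluate the whole prompt in one call. The graph is built
    // on the first call for this sequence length and cached for subsequent calls.
    rwkv_eval_sequence(ctx, prompt, 3, NULL, state, logits);

    // A cloned context shares the model weights and can be used from another thread.
    struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 4);
    rwkv_eval(ctx2, prompt[0], NULL, state, logits);

    rwkv_free(ctx2);
    rwkv_free(ctx);
    free(state);
    free(logits);

    return EXIT_SUCCESS;
}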