Skip to content

Commit

Permalink
Allow creating multiple contexts per model (#83)
Browse files Browse the repository at this point in the history
* Allow creating multiple contexts per model

This allows for parallel inference and I am preparing to support
sequence mode using a method similar to this

* Fix cuBLAS

* Update rwkv.h

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Update rwkv.cpp

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Inherit print_errors from parent ctx when cloning

* Add context cloning test

* Free

* Free ggml context when last rwkv_context is freed

* Free before exit

* int main

* add explanation of ffn_key_size

* Update rwkv_instance and rwkv_context comments

* Thread safety notes

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
  • Loading branch information
LoganDark and saharNooby committed Jun 3, 2023
1 parent 363dfb1 commit 3f8bb2c
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 29 deletions.
119 changes: 90 additions & 29 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,33 @@ struct rwkv_graph {
std::unique_ptr<struct ggml_cgraph> cgraph;
};

struct rwkv_context {
struct rwkv_ggml_guard {
struct ggml_context * ctx;
~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
};

// An instance of an RWKV model loaded into memory:
// Contains all the model weights.
// Shared by one or more contexts.
struct rwkv_instance {
struct rwkv_model model;
struct rwkv_ggml_guard ctx;
std::unique_ptr<uint8_t []> scratch;

// TODO come up with a better solution to estimate "work tensor" size.
// The ggml_cgraph allocates a "work tensor" the first time it is used.
// Currently, the height of blocks.0.ffn.key.weight is the bottleneck in our implementation of RWKV.
// Since it is the largest dimension used in any matrix multiply, it is the size used for the "work tensor".
// However, if ggml changes its implementation, or rwkv.cpp changes its own implementation, at any point,
// this may become outdated. We need to find a way not to hardcode a specific tensor, but to calculate accurately.
// This may come out of a ggml issue: https://github.com/ggerganov/ggml/issues/214
size_t ffn_key_size;
};

// RWKV context for a specific instance.
// Contains the computation graph and is used for inference.
struct rwkv_context {
std::shared_ptr<struct rwkv_instance> instance;
struct ggml_context * ctx;
std::unique_ptr<uint8_t []> scratch;
struct rwkv_graph graph;
Expand Down Expand Up @@ -860,11 +885,6 @@ struct rwkv_file_guard {
~rwkv_file_guard() { if (file) { fclose(file); } }
};

struct rwkv_ggml_guard {
struct ggml_context * ctx;
~rwkv_ggml_guard() { if (ctx) { ggml_free(ctx); } }
};

void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
bool * ptr = ctx ? &ctx->print_errors : &global_print_errors;
*ptr = print_errors;
Expand All @@ -881,14 +901,12 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
return value;
}

struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
global_last_error = RWKV_ERROR_NONE;

bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) {
FILE * file = fopen(file_path, "rb");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path);
rwkv_file_guard file_guard { file };

// Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to the get file length.
// Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
struct stat file_stat;
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file), &file_stat) == 0, "Failed to stat file %s", file_path);

Expand All @@ -897,28 +915,23 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t

size_t tensors_start = ftell(file);
struct rwkv_ctx_size ctx_size;
size_t ffn_key = 0;

std::string name;
instance.ffn_key_size = 0;

while ((size_t) ftell(file) < (size_t) file_stat.st_size) {
struct rwkv_tensor_header header;
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file, header), "Invalid tensor header");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, rwkv_tensor_size(header), SEEK_CUR) == 0, "Failed to read tensor data");
rwkv_ctx_size_add_tensor(ctx_size, 1, 0, header);

if (ffn_key == 0 && name == "blocks.0.ffn.key.weight") {
ffn_key = header.height;
if (instance.ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") {
instance.ffn_key_size = header.height;
}
}

RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key, "Model is missing parameter blocks.0.ffn.key.weight");

rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, ffn_key));
// And finally to end it all off: the graph work tensor
enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
rwkv_ctx_size_add_objects(ctx_size, 1, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key) * n_threads + 64 * (n_threads - 1)));

RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, instance.ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight");
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file, tensors_start, SEEK_SET) == 0, "Failed to seek in file");

std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
Expand Down Expand Up @@ -957,16 +970,46 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);

// Don't free ggml context now
ggml_guard.ctx = NULL;
// Attach ggml context to instance
instance.ctx.ctx = ctx;
instance.model = std::move(model);
instance.scratch = std::move(scratch);
return true;
}

struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance> instance, const uint32_t n_threads) {
global_last_error = RWKV_ERROR_NONE;

struct rwkv_file_header & header = instance->model.header;

rwkv_ctx_size ctx_size;
rwkv_ctx_size_add(ctx_size, 1, rwkv_single_graph_size(header.n_vocab, header.n_embed, header.n_layer, instance->ffn_key_size));
// And finally to end it all off: the graph work tensor
enum ggml_type mul_mat_type = ggml_is_quantized(rwkv_type_to_ggml[header.data_type]) ? GGML_TYPE_Q8_1 : rwkv_type_to_ggml[header.data_type];
rwkv_ctx_size_add(ctx_size, 1, rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, instance->ffn_key_size) * n_threads + 64 * (n_threads - 1)));

std::unique_ptr<uint8_t []> scratch(new(std::nothrow) uint8_t [ctx_size.scratch_size]);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate graph scratch space (%d)", ctx_size.scratch_size);

struct ggml_context * ctx = ggml_init({ ctx_size.objects_size + ctx_size.objects_count * GGML_OBJECT_SIZE, NULL, false});
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx, "Failed to create GGML context");
rwkv_ggml_guard ggml_guard { ctx };

ggml_set_scratch(ctx, { 0, ctx_size.scratch_size, scratch.get() });

// Build graph
struct rwkv_graph graph;
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, model, n_threads, graph));
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_single_graph(ctx, instance->model, n_threads, graph));

std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context");

// Don't free ggml context
ggml_guard.ctx = NULL;
rwkv_ctx->model = std::move(model);

rwkv_ctx->instance = std::move(instance);
rwkv_ctx->ctx = ctx;
rwkv_ctx->scratch = std::move(scratch);
rwkv_ctx->graph = std::move(graph);
Expand All @@ -975,21 +1018,39 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t
rwkv_ctx->gpu_layers = 0;
rwkv_ctx->vram_total = 0;

ggml_set_scratch(ctx, { 0, 0, NULL });

return rwkv_ctx.release();
}

struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
global_last_error = RWKV_ERROR_NONE;

std::shared_ptr<struct rwkv_instance> instance(new(std::nothrow) struct rwkv_instance);
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance.get(), "Failed to allocate instance");
RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get()));

return rwkv_new_context_impl(instance, n_threads);
}

struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads) {
struct rwkv_context * clone = rwkv_new_context_impl(ctx->instance, n_threads);

if (clone) {
clone->print_errors = ctx->print_errors;
}

return clone;
}

bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
#ifdef GGML_USE_CUBLAS
{
size_t n_gpu = std::min(n_gpu_layers, ctx->model.header.n_layer);
size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);

size_t gpu_layers = ((struct rwkv_context *) ctx)->gpu_layers;
size_t vram_total = ((struct rwkv_context *) ctx)->vram_total;

for (size_t i = 0; i < n_gpu; i++) {
const struct rwkv_layer & layer = ctx->model.layers[i];
const struct rwkv_layer & layer = ctx->instance->model.layers[i];

// Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
ggml_cuda_transform_tensor(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
Expand All @@ -1012,7 +1073,7 @@ bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_g
bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
((struct rwkv_context *) ctx)->last_error = RWKV_ERROR_NONE;

const struct rwkv_file_header & header = ctx->model.header;
const struct rwkv_file_header & header = ctx->instance->model.header;
RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < header.n_vocab, "Token is out of range 0..%d", header.n_vocab - 1);

const struct rwkv_graph & graph = ctx->graph;
Expand Down Expand Up @@ -1055,11 +1116,11 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa
}

uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->model.header.n_layer * 5 * ctx->model.header.n_embed;
return ctx->instance->model.header.n_layer * 5 * ctx->instance->model.header.n_embed;
}

uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->model.header.n_vocab;
return ctx->instance->model.header.n_vocab;
}

void rwkv_free(struct rwkv_context * ctx) {
Expand Down
14 changes: 14 additions & 0 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ extern "C" {
RWKV_ERROR_PARAM_MISSING = 14
};

// RWKV context that can be used for inference.
// All functions that operate on rwkv_context are thread-safe.
// rwkv_context can be sent to different threads between calls to rwkv_eval.
// There is no requirement for rwkv_context to be freed on the creating thread.
struct rwkv_context;

// Sets whether errors are automatically printed to stderr.
Expand All @@ -85,11 +89,20 @@ extern "C" {
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);

// Creates a new context from an existing one.
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
// Each rwkv_context can have one eval running at a time.
// Every rwkv_context must be freed using rwkv_free.
// - ctx: context to be cloned.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);

// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);

// Evaluates the model for a single token.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error. Error messages would be printed to stderr.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
Expand All @@ -104,6 +117,7 @@ extern "C" {
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);

// Frees all allocated memory and the context.
// Does not need to be the same thread that created the rwkv_context.
RWKV_API void rwkv_free(struct rwkv_context * ctx);

// Quantizes FP32 or FP16 model to one of quantized formats.
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR})

rwkv_add_test(test_ggml_basics.c)
rwkv_add_test(test_tiny_rwkv.c)
rwkv_add_test(test_context_cloning.c)
64 changes: 64 additions & 0 deletions tests/test_context_cloning.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include <rwkv.h>

#include <stdlib.h>
#include <stdio.h>
#include <string.h>

int main() {
struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);

if (!ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));

if (!state || !logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
return EXIT_FAILURE;
}

// 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
const unsigned char * prompt = "hello\xd1world";

rwkv_eval(ctx, prompt[0], NULL, state, logits);

for (const unsigned char * token = prompt + 1; *token != 0; token++) {
rwkv_eval(ctx, *token, state, state, logits);
}

float * expected_logits = logits;
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));

if (!logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
return EXIT_FAILURE;
}

struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);

rwkv_eval(ctx, prompt[0], NULL, state, logits);

for (const unsigned char * token = prompt + 1; *token != 0; token++) {
rwkv_eval(ctx, *token, state, state, logits);
}

if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
fprintf(stderr, "results not identical :(\n");
return EXIT_FAILURE;
} else {
fprintf(stdout, "Results identical, success!\n");
}

rwkv_free(ctx);
rwkv_free(ctx2);

free(expected_logits);
free(logits);
free(state);

return EXIT_SUCCESS;
}

0 comments on commit 3f8bb2c

Please sign in to comment.