Skip to content

Commit

Permalink
Fix warnings, clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed May 6, 2023
1 parent 207e91a commit 1fef5d1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 52 deletions.
85 changes: 41 additions & 44 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,31 @@

// --- Utilities ---

// Checks that x is not false. If x is false, prints fancy message to stderr and returns 0.
#define RWKV_ASSERT_FALSE(x, ...) \
do { \
// Checks that x is not false. If x is false, prints fancy message to stderr and returns RET_VAL.
#define RWKV_ASSERT(RET_VAL, x, ...) \
{ \
if (!(x)) { \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
return false; \
return RET_VAL; \
} \
} while (0)
}

// Checks that x is not false. If x is false, prints fancy message to stderr and returns false.
#define RWKV_ASSERT_FALSE(x, ...) RWKV_ASSERT(false, x, __VA_ARGS__)

// Checks that x is not false. If x is false, prints fancy message to stderr and returns NULL.
#define RWKV_ASSERT_NULL(x, ...) \
do { \
if (!(x)) { \
fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
return NULL; \
} \
} while (0)
#define RWKV_ASSERT_NULL(x, ...) RWKV_ASSERT(NULL, x, __VA_ARGS__)

// Reads single int32 value from a file.
bool read_int32(FILE * file, int32_t * dest) {
RWKV_ASSERT_FALSE(fread(dest, 4, 1, file) == 1, "Failed to read an int32 value from a file");
RWKV_ASSERT_FALSE(fread(dest, sizeof(int32_t), 1, file) == 1, "Failed to read an int32 value from a file");
return true;
}

// Reads single uint32 value from a file.
bool read_uint32(FILE * file, uint32_t * dest) {
RWKV_ASSERT_FALSE(fread(dest, sizeof(uint32_t), 1, file) == 1, "Failed to read a uint32 value from a file");
return true;
}

Expand Down Expand Up @@ -98,9 +100,9 @@ struct rwkv_layer {
};

struct rwkv_model {
int32_t n_vocab;
int32_t n_layer;
int32_t n_embed;
uint32_t n_vocab;
uint32_t n_layer;
uint32_t n_embed;
// 0 for float32, 1 for float16.
int32_t data_type;

Expand All @@ -119,18 +121,18 @@ struct rwkv_model {

// Finds model parameter by key and sets it into dest.
// If the parameter was not found, returns false.
bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, char * key, struct ggml_tensor ** dest) {
bool set_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, std::string key, struct ggml_tensor ** dest) {
struct ggml_tensor * parameter = (*parameters)[key];
RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key);
RWKV_ASSERT_FALSE(parameter != NULL, "Parameter %s not found in model file", key.c_str());
*dest = parameter;
return true;
}

// Finds block parameter by block index and key and sets it into dest.
// If the parameter was not found, returns false.
bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, int32_t block_index, char * key, struct ggml_tensor ** dest) {
bool set_block_parameter(std::unordered_map<std::string, struct ggml_tensor *> * parameters, int32_t block_index, std::string key, struct ggml_tensor ** dest) {
char full_key[128];
sprintf(full_key, "blocks.%d.%s", block_index, key);
sprintf(full_key, "blocks.%d.%s", block_index, key.c_str());
return set_parameter(parameters, full_key, dest);
}

Expand Down Expand Up @@ -198,7 +200,7 @@ struct rwkv_context {
bool freed;
};

struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_threads) {
struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) {
FILE * file = fopen(file_path, "rb");
RWKV_ASSERT_NULL(file != NULL, "Failed to open file %s", file_path);

Expand All @@ -212,14 +214,9 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr

struct rwkv_model * model = (struct rwkv_model *) calloc(1, sizeof(struct rwkv_model));

read_int32(file, &(model->n_vocab));
RWKV_ASSERT_NULL(model->n_vocab > 0, "Non-positive n_vocab %d", model->n_vocab);

read_int32(file, &(model->n_embed));
RWKV_ASSERT_NULL(model->n_embed > 0, "Non-positive n_embed %d", model->n_embed);

read_int32(file, &(model->n_layer));
RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
read_uint32(file, &(model->n_vocab));
read_uint32(file, &(model->n_embed));
read_uint32(file, &(model->n_layer));

read_int32(file, &(model->data_type));
RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type);
Expand Down Expand Up @@ -321,7 +318,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
set_parameter(&parameters, "blocks.0.ln0.weight", &(model->ln0_weight));
set_parameter(&parameters, "blocks.0.ln0.bias", &(model->ln0_bias));

for (int i = 0; i < model->n_layer; i++) {
for (uint32_t i = 0; i < model->n_layer; i++) {
rwkv_layer layer = model->layers[i];

set_block_parameter(&parameters, i, "ln1.weight", &(layer.ln1_weight));
Expand Down Expand Up @@ -360,8 +357,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
RWKV_ASSERT_NULL(emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %lld", emb->ne[0]);
RWKV_ASSERT_NULL(emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %lld", emb->ne[1]);

int32_t n_embed = model->n_embed;
int32_t n_layer = model->n_layer;
uint32_t n_embed = model->n_embed;
uint32_t n_layer = model->n_layer;

// Build graph
struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed);
Expand All @@ -376,7 +373,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
// We collect parts of new state here. Each part is (n_embed) vector.
struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5];

for (int i = 0; i < n_layer; i++) {
for (uint32_t i = 0; i < n_layer; i++) {
auto layer = model->layers[i];

// RWKV/time mixing
Expand Down Expand Up @@ -533,7 +530,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr

*graph = ggml_build_forward(logits);

for (int i = 0; i < n_layer * 5; i++) {
for (uint32_t i = 0; i < n_layer * 5; i++) {
ggml_build_forward_expand(graph, state_parts[i]);
}

Expand All @@ -550,30 +547,30 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
return rwkv_ctx;
}

uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx) {
uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->model->n_layer * 5 * ctx->model->n_embed;
}

uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx) {
uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx) {
return ctx->model->n_vocab;
}

bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out) {
bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
RWKV_ASSERT_FALSE(state_out != NULL, "state_out is NULL");
RWKV_ASSERT_FALSE(logits_out != NULL, "logits_out is NULL");

int32_t n_layer = ctx->model->n_layer;
int32_t n_embed = ctx->model->n_embed;
int32_t n_vocab = ctx->model->n_vocab;
uint32_t n_layer = ctx->model->n_layer;
uint32_t n_embed = ctx->model->n_embed;
uint32_t n_vocab = ctx->model->n_vocab;

RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1);
RWKV_ASSERT_FALSE(token < (uint32_t) n_vocab, "Token is out of range 0..%d", n_vocab - 1);

ggml_set_i32_1d(ctx->token_index, 0, token);

if (state_in == NULL) {
ggml_set_f32(ctx->state, 0.0F);

for (int i = 0; i < n_layer; i++) {
for (uint32_t i = 0; i < n_layer; i++) {
// state[5 * i + 4] = -1e30
ggml_set_f32(
ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)),
Expand All @@ -586,7 +583,7 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float

ggml_graph_compute(ctx->ctx, ctx->graph);

for (size_t i = 0; i < size_t(n_layer * 5); i++) {
for (uint32_t i = 0; i < n_layer * 5; i++) {
struct ggml_tensor * part = ctx->state_parts[i];

memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float));
Expand Down
8 changes: 4 additions & 4 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@ extern "C" {
// Returns NULL on any error. Error messages would be printed to stderr.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads);
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);

// Evaluates the model for a single token.
// Returns false on any error. Error messages would be printed to stderr.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float * state_out, float * logits_out);
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);

// Returns count of FP32 elements in state buffer.
RWKV_API uint32_t rwkv_get_state_buffer_element_count(struct rwkv_context * ctx);
RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);

// Returns count of FP32 elements in logits buffer.
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(struct rwkv_context * ctx);
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);

// Frees all allocated memory and the context.
RWKV_API void rwkv_free(struct rwkv_context * ctx);
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ggml_basics.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

#define ASSERT_ELEMENT_F32(tensor, i, expected_value) {\
float actual = ((float *) tensor->data)[i];\
ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, expected_value, actual);\
ASSERT(fabsf(actual - expected_value) <= 0.0000001F, "At %s[%d]: expected %f, actual %f", #tensor, i, (double) expected_value, (double) actual);\
}

int main(int argc, const char ** argv) {
int main(void) {
struct ggml_init_params params = {
.mem_size = 16 * 1024,
.mem_buffer = NULL,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tiny_rwkv.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ void test_model(const char * model_path, const float * expected_logits, const fl
fprintf(stderr, "Difference sum: %f\n", diff_sum);

// When something breaks, difference would be way more than 10
ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", diff_sum, max_diff);
ASSERT(fabsf(diff_sum) <= fabsf(max_diff) + 0.01F, "Too big difference %f, expected no more than %f", (double) diff_sum, (double) max_diff);

rwkv_free(model);

free(state);
free(logits);
}

int main(int argc, const char ** argv) {
int main(void) {
fprintf(stderr, "System info: %s\n", rwkv_get_system_info_string());

float * expected_logits = malloc(sizeof(float) * N_VOCAB);
Expand Down

0 comments on commit 1fef5d1

Please sign in to comment.