diff --git a/rwkv.cpp b/rwkv.cpp index 77a3f95..27055fc 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -244,22 +244,23 @@ struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) { // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias` // Looks like ggml_norm does the first part, we only need to apply weight & bias. - x = ggml_norm(ctx, x); - x = ggml_mul(ctx, x, weight); - x = ggml_add_inplace(ctx, x, bias); - return x; + return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias); } // --- Implementation --- -struct rwkv_context { - std::unique_ptr model; - struct ggml_tensor * token_index; +struct rwkv_graph { struct ggml_tensor * state; - struct ggml_tensor ** state_parts; + std::unique_ptr state_parts; + struct ggml_tensor * token_index; struct ggml_tensor * logits; + std::unique_ptr cgraph; +}; + +struct rwkv_context { + std::unique_ptr model; struct ggml_context * ctx; - std::unique_ptr graph; + struct rwkv_graph graph; enum rwkv_error_flags last_error; bool print_errors; }; @@ -280,6 +281,164 @@ enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) { return value; } +bool rwkv_build_graph(struct ggml_context * ctx, struct rwkv_model * model, const uint32_t n_threads, struct rwkv_graph * out) { + std::unique_ptr cgraph(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, cgraph.get(), "Failed to allocate graph"); + cgraph->n_threads = n_threads; + + size_t n_embed = model->n_embed, n_layer = model->n_layer; + struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); + + // We collect parts of new state here. Each part is (n_embed) vector. + std::unique_ptr state_parts(new(std::nothrow) ggml_tensor * [n_layer * 5]); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, state_parts.get(), "Failed to allocate state parts"); + + // x = self.w.emb.weight[token] + struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); + + // x = self.layer_norm(x, self.w.blocks[0].ln0) + x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); + + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_layer layer = model->layers[i]; + size_t part_index = i * 5; + size_t state_part_size = n_embed * sizeof(float); + + // RWKV/time mixing + { + // self.layer_norm(x, self.w.blocks[i].ln1) + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + + // x0 = state[5 * i + 1] + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (part_index + 1) * state_part_size); + // aa = state[5 * i + 2] + struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (part_index + 2) * state_part_size); + // bb = state[5 * i + 3] + struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (part_index + 3) * state_part_size); + // pp = state[5 * i + 4] + struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (part_index + 4) * state_part_size); + + // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ); + + // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + struct ggml_tensor * xv = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ); + + // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace(ctx, + ggml_mul(ctx, x0, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr)); + // k = kw @ xk + struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); + // v = vw @ xv + struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); + + // ww = time_first + k + struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); + // qq = torch.maximum(pp, ww) + struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); + // e1 = torch.exp(pp - qq) + struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); + // e2 = torch.exp(ww - qq) + struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + + // a = e1 * aa + e2 * v + struct ggml_tensor * a = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + // b = e1 * bb + e2 + struct ggml_tensor * b = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + + // ww = pp + time_decay + ww = ggml_add_inplace(ctx, pp, layer.att_time_decay); + // qq = torch.maximum(ww, k) + qq = rwkv_max(ctx, ww, k); + // e1 = torch.exp(ww - qq) + e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); + // e2 = torch.exp(k - qq) + e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); + + // state[5 * i + 1] = x0 + // state[5 * i + 2] = e1 * aa + e2 * v + // state[5 * i + 3] = e1 * bb + e2 + // state[5 * i + 4] = qq + + state_parts[part_index + 1] = x0; + state_parts[part_index + 2] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, aa), ggml_mul(ctx, e2, v)); + state_parts[part_index + 3] = ggml_add_inplace(ctx, ggml_mul(ctx, e1, bb), e2); + state_parts[part_index + 4] = qq; + + // wkv = a / b + struct ggml_tensor * wkv = ggml_div(ctx, a, b); + + // ow @ (r * wkv) + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv))); + } + + // FFN/channel mixing + { + // self.layer_norm(x, self.w.blocks[i].ln2) + struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); + + // x_prev = state[5 * i + 0] + struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, part_index * state_part_size); + + // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ); + + // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, x0, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ); + + // state[5 * i + 0] = x + state_parts[part_index] = x0; + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + x = ggml_add_inplace(ctx, x, ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k))); + } + } + + // x = self.layer_norm(x, self.w.ln_out) + x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); + + // x = (self.w.head.weight @ x).float() + struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); + + ggml_build_forward_expand(cgraph.get(), logits); + + for (uint32_t i = 0; i < n_layer * 5; i++) + ggml_build_forward_expand(cgraph.get(), state_parts[i]); + + out->state = state; + out->state_parts = std::move(state_parts); + out->token_index = token_index; + out->logits = logits; + out->cgraph = std::move(cgraph); + return true; +} + struct rwkv_file_guard { FILE * file; ~rwkv_file_guard() { if (file) fclose(file); } @@ -293,7 +452,7 @@ struct rwkv_ggml_guard { struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t n_threads) { global_last_error = RWKV_ERROR_NONE; - FILE* file = fopen(file_path, "rb"); + FILE * file = fopen(file_path, "rb"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file, "Failed to open file %s", file_path); rwkv_file_guard file_guard { file }; @@ -316,7 +475,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_uint32(file, &model->n_layer, "n_layer")); RWKV_ASSERT_NULL(RWKV_ERROR_MODEL, read_int32(file, &model->data_type, "data_type")); - const char* unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"; + const char * unsupported_msg = "Models in %s format cannot be loaded anymore because the format was removed. You need to quantize the model into another format"; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_DATA_TYPE, model->data_type >= 0 && model->data_type < FORMAT_TYPE_COUNT, "Unsupported model data type %d", model->data_type); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 4, unsupported_msg, "Q4_1_O"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL | RWKV_ERROR_UNSUPPORTED, model->data_type != 6, unsupported_msg, "Q4_3"); @@ -418,192 +577,16 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, const uint32_t RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model->n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model->n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]); - uint32_t n_embed = model->n_embed; - uint32_t n_layer = model->n_layer; + size_t n_embed = model->n_embed; + size_t n_layer = model->n_layer; // Build graph - struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_layer * 5 * n_embed); - - // x = self.w.emb.weight[token] - struct ggml_tensor * token_index = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - struct ggml_tensor * x = ggml_get_rows(ctx, model->emb, token_index); - - // x = self.layer_norm(x, self.w.blocks[0].ln0) - x = rwkv_layer_norm(ctx, x, model->ln0_weight, model->ln0_bias); - - // We collect parts of new state here. Each part is (n_embed) vector. - struct ggml_tensor ** state_parts = new ggml_tensor * [n_layer * 5]; - - for (uint32_t i = 0; i < n_layer; i++) { - auto layer = model->layers[i]; - - // RWKV/time mixing - { - // self.layer_norm(x, self.w.blocks[i].ln1) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - // state[5 * i + 1] - struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 1) * n_embed * sizeof(float)); - // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) - // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) - // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) - ); - struct ggml_tensor * xv = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) - ); - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) - ); - // state[5 * i + 1] = x - state_parts[5 * i + 1] = x0; - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid( - ctx, - ggml_mul_mat(ctx, layer.att_receptance, xr) - ); - // k = kw @ xk - struct ggml_tensor * k = ggml_mul_mat(ctx, layer.att_key, xk); - // v = vw @ xv - struct ggml_tensor * v = ggml_mul_mat(ctx, layer.att_value, xv); - - // aa = state[5 * i + 2] - // bb = state[5 * i + 3] - // pp = state[5 * i + 4] - struct ggml_tensor * aa = ggml_view_1d(ctx, state, n_embed, (5 * i + 2) * n_embed * sizeof(float)); - struct ggml_tensor * bb = ggml_view_1d(ctx, state, n_embed, (5 * i + 3) * n_embed * sizeof(float)); - struct ggml_tensor * pp = ggml_view_1d(ctx, state, n_embed, (5 * i + 4) * n_embed * sizeof(float)); - - // ww = time_first + k - struct ggml_tensor * ww = ggml_add(ctx, layer.att_time_first, k); - // qq = torch.maximum(pp, ww) - struct ggml_tensor * qq = rwkv_max(ctx, pp, ww); - // e1 = torch.exp(pp - qq) - struct ggml_tensor * e1 = rwkv_exp(ctx, ggml_sub(ctx, pp, qq)); - // e2 = torch.exp(ww - qq) - struct ggml_tensor * e2 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - // a = e1 * aa + e2 * v - struct ggml_tensor * a = ggml_add_inplace( - ctx, - ggml_mul(ctx, e1, aa), - ggml_mul(ctx, e2, v) - ); - // b = e1 * bb + e2 - struct ggml_tensor * b = ggml_add_inplace( - ctx, - ggml_mul(ctx, e1, bb), - e2 - ); - // wkv = a / b - struct ggml_tensor * wkv = ggml_div(ctx, a, b); - // ww = pp + time_decay - ww = ggml_add(ctx, pp, layer.att_time_decay); - // qq = torch.maximum(ww, k) - qq = rwkv_max(ctx, ww, k); - // e1 = torch.exp(ww - qq) - e1 = rwkv_exp(ctx, ggml_sub(ctx, ww, qq)); - // e2 = torch.exp(k - qq) - e2 = rwkv_exp(ctx, ggml_sub(ctx, k, qq)); - // state[5 * i + 2] = e1 * aa + e2 * v - state_parts[5 * i + 2] = ggml_add_inplace( - ctx, - ggml_mul(ctx, e1, aa), - ggml_mul(ctx, e2, v) - ); - // state[5 * i + 3] = e1 * bb + e2 - state_parts[5 * i + 3] = ggml_add_inplace( - ctx, - ggml_mul(ctx, e1, bb), - e2 - ); - // state[5 * i + 4] = qq - state_parts[5 * i + 4] = qq; - // ow @ (r * wkv) - x = ggml_add_inplace( - ctx, - x, - ggml_mul_mat( - ctx, - layer.att_output, - ggml_mul(ctx, r, wkv) - ) - ); - } - - // FFN/channel mixing - { - // self.layer_norm(x, self.w.blocks[i].ln2) - struct ggml_tensor * x0 = rwkv_layer_norm(ctx, x, layer.ln2_weight, layer.ln2_bias); - // state[5 * i + 0] - struct ggml_tensor * x_prev = ggml_view_1d(ctx, state, n_embed, (5 * i + 0) * n_embed * sizeof(float)); - // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) - // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) - struct ggml_tensor * xk = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) - ); - struct ggml_tensor * xr = ggml_add_inplace( - ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) - ); - // state[5 * i + 0] = x - state_parts[5 * i + 0] = x0; - - // r = torch.sigmoid(rw @ xr) - struct ggml_tensor * r = rwkv_sigmoid( - ctx, - ggml_mul_mat(ctx, layer.ffn_receptance, xr) - ); - // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu( - ctx, - ggml_mul_mat(ctx, layer.ffn_key, xk) - )); - // r * (vw @ k) - x = ggml_add_inplace( - ctx, - x, - ggml_mul( - ctx, - r, - ggml_mul_mat(ctx, layer.ffn_value, k) - ) - ); - } - } - - // x = self.layer_norm(x, self.w.ln_out) - x = rwkv_layer_norm(ctx, x, model->ln_out_weight, model->ln_out_bias); - - // x = (self.w.head.weight @ x).float() - struct ggml_tensor * logits = ggml_mul_mat(ctx, model->head, x); - - std::unique_ptr graph(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_GRAPH | RWKV_ERROR_ALLOC, graph.get(), "Failed to allocate graph"); - - ggml_build_forward_expand(graph.get(), logits); - - for (uint32_t i = 0; i < n_layer * 5; i++) - ggml_build_forward_expand(graph.get(), state_parts[i]); - - graph->n_threads = n_threads; + struct rwkv_graph graph; + RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_graph(ctx, model.get(), n_threads, &graph)); std::unique_ptr rwkv_ctx(new(std::nothrow) struct rwkv_context()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx.get(), "Failed to allocate context"); rwkv_ctx->model = std::move(model); - rwkv_ctx->token_index = token_index; - rwkv_ctx->state = state; - rwkv_ctx->state_parts = state_parts; - rwkv_ctx->logits = logits; rwkv_ctx->ctx = ctx; rwkv_ctx->graph = std::move(graph); rwkv_ctx->last_error = RWKV_ERROR_NONE; @@ -627,40 +610,40 @@ bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const floa RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, logits_out != NULL, "logits_out is NULL"); RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, token < ctx->model->n_vocab, "Token is out of range 0..%d", ctx->model->n_vocab - 1); - uint32_t n_layer = ctx->model->n_layer; - uint32_t n_embed = ctx->model->n_embed; + const struct rwkv_graph * graph = &ctx->graph; + size_t n_layer = ctx->model->n_layer; + size_t n_embed = ctx->model->n_embed; - ggml_set_i32_1d(ctx->token_index, 0, token); + ggml_set_i32_1d(graph->token_index, 0, token); if (state_in == NULL) { - ggml_set_f32(ctx->state, 0.0F); + ggml_set_f32(graph->state, 0.0F); - for (uint64_t i = 0; i < n_layer; i++) { + for (size_t i = 0; i < n_layer; i++) { // state[5 * i + 4] = -1e30 ggml_set_f32( - ggml_view_1d(ctx->ctx, ctx->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), + ggml_view_1d(ctx->ctx, graph->state, n_embed, (5 * i + 4) * n_embed * sizeof(float)), -1e30F ); } } else { - memcpy(ctx->state->data, state_in, ctx->state->ne[0] * sizeof(float)); + memcpy(graph->state->data, state_in, graph->state->ne[0] * sizeof(float)); } - ggml_graph_compute(ctx->ctx, ctx->graph.get()); + ggml_graph_compute(ctx->ctx, graph->cgraph.get()); - for (uint32_t i = 0; i < n_layer * 5; i++) { - struct ggml_tensor * part = ctx->state_parts[i]; + for (size_t i = 0; i < n_layer * 5; i++) { + struct ggml_tensor * part = graph->state_parts[i]; memcpy(state_out + i * n_embed, part->data, part->ne[0] * sizeof(float)); } - memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * sizeof(float)); + memcpy(logits_out, graph->logits->data, graph->logits->ne[0] * sizeof(float)); return true; } void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); - delete[] ctx->state_parts; ggml_free(ctx->ctx); } @@ -803,7 +786,7 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode // This is a histogramm of some values. If it shows single 1.0, then all 0.0, something went very wrong! std::vector hist_cur(1 << 4, 0); - size_t (*f)(const float* src, void* dst, int n, int k, int64_t* hist) = + size_t (*f)(const float * src, void * dst, int n, int k, int64_t * hist) = format_ggml_type == GGML_TYPE_Q4_0 ? ggml_quantize_q4_0 : format_ggml_type == GGML_TYPE_Q4_1 ? ggml_quantize_q4_1 : format_ggml_type == GGML_TYPE_Q4_2 ? ggml_quantize_q4_2 :