From 84634c047a9831b16cdf1cc3f2626e0ef0b2373b Mon Sep 17 00:00:00 2001
From: LoganDark
Date: Tue, 27 Jun 2023 02:27:55 -0700
Subject: [PATCH] Elide logits if the logits pointer parameter is NULL (#107)

* Completely skip calculation of logits if nobody cares

This speeds up sequence mode evaluations by up to 20% if you ingest a large
prompt and then only retrieve the logits at the very end. Note that you must
pass a NULL pointer to the logits parameter in order to take advantage of
this optimization.

* logits_out=NULL documentation
---
 rwkv.cpp | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++----
 rwkv.h   |  4 ++++
 2 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/rwkv.cpp b/rwkv.cpp
index 7664698..c3b5443 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -683,6 +683,11 @@ struct rwkv_graph {
 
     // ggml_cgraph is so large that it can cause stack overflows if not stored on the heap
     std::unique_ptr<struct ggml_cgraph> cgraph;
+
+    size_t pre_logits_nodes;
+    size_t pre_logits_leafs;
+    size_t post_logits_nodes;
+    size_t post_logits_leafs;
 };
 
 // RWKV context for a specific instance.
@@ -1126,7 +1131,12 @@ bool rwkv_build_serial_graph(
     struct rwkv_layer_state * inputs,
     struct rwkv_layer_state * outputs,
     struct ggml_tensor * logits,
-    struct ggml_cgraph * cgraph
+    struct ggml_cgraph * cgraph,
+
+    size_t * const pre_logits_nodes,
+    size_t * const pre_logits_leafs,
+    size_t * const post_logits_nodes,
+    size_t * const post_logits_leafs
 ) {
     // x = self.w.emb.weight[token]
     struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens);
@@ -1149,12 +1159,18 @@ bool rwkv_build_serial_graph(
         ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
     }
 
+    *pre_logits_nodes = cgraph->n_nodes;
+    *pre_logits_leafs = cgraph->n_leafs;
+
     // x = self.layer_norm(x[-1,:], self.w.ln_out)
     x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);
 
     // x = (self.w.head.weight @ x).float()
     ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));
 
+    *post_logits_nodes = cgraph->n_nodes;
+    *post_logits_leafs = cgraph->n_leafs;
+
     return true;
 }
 
@@ -1239,7 +1255,12 @@ bool rwkv_build_sequence_graph(
     struct rwkv_layer_state * inputs,
     struct rwkv_layer_state * outputs,
     struct ggml_tensor * logits,
-    struct ggml_cgraph * cgraph
+    struct ggml_cgraph * cgraph,
+
+    size_t * const pre_logits_nodes,
+    size_t * const pre_logits_leafs,
+    size_t * const post_logits_nodes,
+    size_t * const post_logits_leafs
 ) {
     const uint32_t n_embed = model.header.n_embed;
     const size_t sequence_len = tokens->ne[0];
@@ -1278,12 +1299,18 @@ bool rwkv_build_sequence_graph(
         ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
     }
 
+    *pre_logits_nodes = cgraph->n_nodes;
+    *pre_logits_leafs = cgraph->n_leafs;
+
     // x = self.layer_norm(x[-1,:], self.w.ln_out)
     x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias);
 
     // x = (self.w.head.weight @ x).float()
     ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));
 
+    *post_logits_nodes = cgraph->n_nodes;
+    *post_logits_leafs = cgraph->n_leafs;
+
     return true;
 }
 
@@ -1473,7 +1500,13 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance>
     serial_graph.cgraph->n_threads = n_threads;
 
-    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get()));
+
+    RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
+        serial_graph.ctx.ctx, instance->model,
+        serial_graph.tokens, inputs.get(), outputs.get(), logits,
+        serial_graph.cgraph.get(),
+        &serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs
+    ));
 
     std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
     RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context");
@@ -1568,6 +1601,16 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st
     rwkv_set_inputs(ctx, state_in);
     ggml_set_i32(ctx->serial_graph.tokens, token);
+
+    // Short circuit computation of logits if nobody actually cares
+    if (!logits_out) {
+        ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes;
+        ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs;
+    } else {
+        ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes;
+        ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
+    }
+
     ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get());
     rwkv_get_outputs(ctx, state_out, logits_out);
@@ -1631,7 +1674,13 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
         sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
         RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
         sequence_graph.cgraph->n_threads = 1;
-        RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(sequence_graph.ctx.ctx, ctx->instance->model, sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, sequence_graph.cgraph.get()));
+
+        RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
+            sequence_graph.ctx.ctx, ctx->instance->model,
+            sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits,
+            sequence_graph.cgraph.get(),
+            &sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs
+        ));
 
         ctx->sequence_len = sequence_len;
         ctx->sequence_graph = std::move(sequence_graph);
@@ -1641,6 +1690,16 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
     if (sequence) {
         rwkv_set_inputs(ctx, state_in);
         memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t));
+
+        // Short circuit computation of logits if nobody actually cares
+        if (!logits_out) {
+            ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes;
+            ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs;
+        } else {
+            ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes;
+            ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
+        }
+
         ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get());
         rwkv_get_outputs(ctx, state_out, logits_out);
     }
diff --git a/rwkv.h b/rwkv.h
index 493dffe..b1ada36 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -105,6 +105,8 @@ extern "C" {
     // Evaluates the model for a single token.
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
     // Returns false on any error.
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms for each
+    // call in which logits are not calculated.
     // - token: next token index, in range 0 <= token < n_vocab.
     // - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
     // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
@@ -116,6 +118,8 @@ extern "C" {
     // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
     // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
     // Returns false on any error.
+    // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms for each
+    // call in which logits are not calculated.
    // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
    // - sequence_len: number of tokens to read from the array.
    // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
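
Caller-side usage, as described in the commit message: ingest the prompt with logits_out set to NULL so the graph is truncated at pre_logits_nodes/pre_logits_leafs, then request logits only for the final token. This is a minimal sketch, not part of the patch; it assumes the rwkv.h API at this revision (rwkv_init_from_file, rwkv_eval, rwkv_eval_sequence, rwkv_get_state_len, rwkv_free, plus rwkv_get_logits_len, assumed available for sizing the logits buffer), and the model path and token ids are placeholders.

#include <stdint.h>
#include <stdlib.h>

#include "rwkv.h"

int main(void) {
    // Placeholder model path and thread count.
    struct rwkv_context * ctx = rwkv_init_from_file("model-placeholder.bin", 4);
    if (!ctx) return 1;

    float * state  = calloc(rwkv_get_state_len(ctx),  sizeof(float));
    float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

    const uint32_t prompt[] = { 11, 22, 33, 44 }; // hypothetical token ids
    const size_t prompt_len = sizeof(prompt) / sizeof(prompt[0]);

    // Prompt ingestion: logits_out == NULL, so with this patch the graph stops
    // before the final layer norm and head matmul.
    rwkv_eval_sequence(ctx, prompt, prompt_len - 1, NULL, state, NULL);

    // Last prompt token: pass a real logits buffer so the full graph runs.
    // The same buffer is used as state_in and state_out here.
    rwkv_eval(ctx, prompt[prompt_len - 1], state, state, logits);

    free(logits);
    free(state);
    rwkv_free(ctx);
    return 0;
}

Error handling of the eval calls is omitted for brevity; both return false on any error.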