Skip to content

Commit

Permalink
Elide logits if the logits pointer parameter is NULL (#107)
Browse files Browse the repository at this point in the history
* Completely skip calculation of logits if nobody cares

This speeds up sequence mode evaluations by up to 20% if you ingest
a large prompt and then only retrieve the logits at the very end.

Note that you must pass a NULL pointer to the logits parameter in
order to take advantage of this optimization.

* logits_out=NULL documentation
  • Loading branch information
LoganDark committed Jun 27, 2023
1 parent ffc085c commit 84634c0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
67 changes: 63 additions & 4 deletions rwkv.cpp
Expand Up @@ -683,6 +683,11 @@ struct rwkv_graph {

// ggml_cgraph is so large that it can cause stack overflows if not stored on the heap
std::unique_ptr<struct ggml_cgraph> cgraph;

size_t pre_logits_nodes;
size_t pre_logits_leafs;
size_t post_logits_nodes;
size_t post_logits_leafs;
};

// RWKV context for a specific instance.
Expand Down Expand Up @@ -1126,7 +1131,12 @@ bool rwkv_build_serial_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
// x = self.w.emb.weight[token]
struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens);
Expand All @@ -1149,12 +1159,18 @@ bool rwkv_build_serial_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1239,7 +1255,12 @@ bool rwkv_build_sequence_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
const uint32_t n_embed = model.header.n_embed;
const size_t sequence_len = tokens->ne[0];
Expand Down Expand Up @@ -1278,12 +1299,18 @@ bool rwkv_build_sequence_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1473,7 +1500,13 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph");
serial_graph.cgraph->n_threads = n_threads;
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get()));

RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
serial_graph.ctx.ctx, instance->model,
serial_graph.tokens, inputs.get(), outputs.get(), logits,
serial_graph.cgraph.get(),
&serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs
));

std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context");
Expand Down Expand Up @@ -1568,6 +1601,16 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st

rwkv_set_inputs(ctx, state_in);
ggml_set_i32(ctx->serial_graph.tokens, token);

// Short circuit computation of logits if nobody actually cares
if (!logits_out) {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs;
} else {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);

Expand Down Expand Up @@ -1631,7 +1674,13 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
sequence_graph.cgraph->n_threads = 1;
RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(sequence_graph.ctx.ctx, ctx->instance->model, sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, sequence_graph.cgraph.get()));

RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
sequence_graph.ctx.ctx, ctx->instance->model,
sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits,
sequence_graph.cgraph.get(),
&sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs
));

ctx->sequence_len = sequence_len;
ctx->sequence_graph = std::move(sequence_graph);
Expand All @@ -1641,6 +1690,16 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
if (sequence) {
rwkv_set_inputs(ctx, state_in);
memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t));

// Short circuit computation of logits if nobody actually cares
if (!logits_out) {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs;
} else {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);
}
Expand Down
4 changes: 4 additions & 0 deletions rwkv.h
Expand Up @@ -105,6 +105,8 @@ extern "C" {
// Evaluates the model for a single token.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// in which logits are not calculated.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
Expand All @@ -116,6 +118,8 @@ extern "C" {
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// in which logits are not calculated.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
Expand Down

0 comments on commit 84634c0

Please sign in to comment.