Skip to content

Commit

Permalink
Completely skip calculation of logits if nobody cares
Browse files Browse the repository at this point in the history
This speeds up sequence mode evaluations by up to 20% if you ingest
a large prompt and then only retrieve the logits at the very end.

Note that you must pass a NULL pointer to the logits parameter in
order to take advantage of this optimization.
  • Loading branch information
LoganDark committed Jun 25, 2023
1 parent b72abcd commit 185243a
Showing 1 changed file with 63 additions and 4 deletions.
67 changes: 63 additions & 4 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,11 @@ struct rwkv_graph {

// ggml_cgraph is so large that it can cause stack overflows if not stored on the heap
std::unique_ptr<struct ggml_cgraph> cgraph;

size_t pre_logits_nodes;
size_t pre_logits_leafs;
size_t post_logits_nodes;
size_t post_logits_leafs;
};

// RWKV context for a specific instance.
Expand Down Expand Up @@ -1126,7 +1131,12 @@ bool rwkv_build_serial_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
// x = self.w.emb.weight[token]
struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens);
Expand All @@ -1149,12 +1159,18 @@ bool rwkv_build_serial_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1239,7 +1255,12 @@ bool rwkv_build_sequence_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
const uint32_t n_embed = model.header.n_embed;
const size_t sequence_len = tokens->ne[0];
Expand Down Expand Up @@ -1278,12 +1299,18 @@ bool rwkv_build_sequence_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1473,7 +1500,13 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph");
serial_graph.cgraph->n_threads = n_threads;
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get()));

RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
serial_graph.ctx.ctx, instance->model,
serial_graph.tokens, inputs.get(), outputs.get(), logits,
serial_graph.cgraph.get(),
&serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs
));

std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context");
Expand Down Expand Up @@ -1568,6 +1601,16 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st

rwkv_set_inputs(ctx, state_in);
ggml_set_i32(ctx->serial_graph.tokens, token);

// Short circuit computation of logits if nobody actually cares
if (!logits_out) {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs;
} else {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);

Expand Down Expand Up @@ -1631,7 +1674,13 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
sequence_graph.cgraph->n_threads = 1;
RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(sequence_graph.ctx.ctx, ctx->instance->model, sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, sequence_graph.cgraph.get()));

RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
sequence_graph.ctx.ctx, ctx->instance->model,
sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits,
sequence_graph.cgraph.get(),
&sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs
));

ctx->sequence_len = sequence_len;
ctx->sequence_graph = std::move(sequence_graph);
Expand All @@ -1641,6 +1690,16 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
if (sequence) {
rwkv_set_inputs(ctx, state_in);
memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t));

// Short circuit computation of logits if nobody actually cares
if (!logits_out) {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs;
} else {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);
}
Expand Down

0 comments on commit 185243a

Please sign in to comment.