Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Elide logits if the logits pointer parameter is NULL #107

Merged
merged 2 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
67 changes: 63 additions & 4 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,11 @@ struct rwkv_graph {

// ggml_cgraph is so large that it can cause stack overflows if not stored on the heap
std::unique_ptr<struct ggml_cgraph> cgraph;

size_t pre_logits_nodes;
size_t pre_logits_leafs;
size_t post_logits_nodes;
size_t post_logits_leafs;
};

// RWKV context for a specific instance.
Expand Down Expand Up @@ -1126,7 +1131,12 @@ bool rwkv_build_serial_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
// x = self.w.emb.weight[token]
struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens);
Expand All @@ -1149,12 +1159,18 @@ bool rwkv_build_serial_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, x, model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1239,7 +1255,12 @@ bool rwkv_build_sequence_graph(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * logits,
struct ggml_cgraph * cgraph
struct ggml_cgraph * cgraph,

size_t * const pre_logits_nodes,
size_t * const pre_logits_leafs,
size_t * const post_logits_nodes,
size_t * const post_logits_leafs
) {
const uint32_t n_embed = model.header.n_embed;
const size_t sequence_len = tokens->ne[0];
Expand Down Expand Up @@ -1278,12 +1299,18 @@ bool rwkv_build_sequence_graph(
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, state.att_pp, output.att_pp));
}

*pre_logits_nodes = cgraph->n_nodes;
*pre_logits_leafs = cgraph->n_leafs;

// x = self.layer_norm(x[-1,:], self.w.ln_out)
x = rwkv_layer_norm(ctx, ggml_view_1d(ctx, x, n_embed, n_embed * sizeof(float) * (sequence_len - 1)), model.ln_out_weight, model.ln_out_bias);

// x = (self.w.head.weight @ x).float()
ggml_build_forward_expand(cgraph, ggml_cpy(ctx, ggml_mul_mat(ctx, model.head, x), logits));

*post_logits_nodes = cgraph->n_nodes;
*post_logits_leafs = cgraph->n_leafs;

return true;
}

Expand Down Expand Up @@ -1473,7 +1500,13 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph");
serial_graph.cgraph->n_threads = n_threads;
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(serial_graph.ctx.ctx, instance->model, serial_graph.tokens, inputs.get(), outputs.get(), logits, serial_graph.cgraph.get()));

RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
serial_graph.ctx.ctx, instance->model,
serial_graph.tokens, inputs.get(), outputs.get(), logits,
serial_graph.cgraph.get(),
&serial_graph.pre_logits_nodes, &serial_graph.pre_logits_leafs, &serial_graph.post_logits_nodes, &serial_graph.post_logits_leafs
));

std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, rwkv_ctx, "Failed to allocate rwkv_context");
Expand Down Expand Up @@ -1568,6 +1601,16 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st

rwkv_set_inputs(ctx, state_in);
ggml_set_i32(ctx->serial_graph.tokens, token);

// Short circuit computation of logits if nobody actually cares
LoganDark marked this conversation as resolved.
Show resolved Hide resolved
if (!logits_out) {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.pre_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.pre_logits_leafs;
} else {
ctx->serial_graph.cgraph->n_nodes = ctx->serial_graph.post_logits_nodes;
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);

Expand Down Expand Up @@ -1631,7 +1674,13 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
sequence_graph.cgraph->n_threads = 1;
RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(sequence_graph.ctx.ctx, ctx->instance->model, sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, sequence_graph.cgraph.get()));

RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
sequence_graph.ctx.ctx, ctx->instance->model,
sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits,
sequence_graph.cgraph.get(),
&sequence_graph.pre_logits_nodes, &sequence_graph.pre_logits_leafs, &sequence_graph.post_logits_nodes, &sequence_graph.post_logits_leafs
));

ctx->sequence_len = sequence_len;
ctx->sequence_graph = std::move(sequence_graph);
Expand All @@ -1641,6 +1690,16 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
if (sequence) {
rwkv_set_inputs(ctx, state_in);
memcpy(ctx->sequence_graph.tokens->data, sequence, sequence_len * sizeof(uint32_t));

// Short circuit computation of logits if nobody actually cares
if (!logits_out) {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.pre_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.pre_logits_leafs;
} else {
ctx->sequence_graph.cgraph->n_nodes = ctx->sequence_graph.post_logits_nodes;
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
}

ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get());
rwkv_get_outputs(ctx, state_out, logits_out);
}
Expand Down
4 changes: 4 additions & 0 deletions rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ extern "C" {
// Evaluates the model for a single token.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// that you do not calculate logits.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
Expand All @@ -116,6 +118,8 @@ extern "C" {
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
// that you do not calculate logits.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
Expand Down