Skip to content

Commit

Permalink
Move logit tensor allocation below graph reset
Browse files: browse the repository at this point in the history
  • Loading branch information
saharNooby committed Jun 10, 2023
1 parent 022b6dc commit 30749ba
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions rwkv.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1181,8 +1181,6 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
output_state.att_pp = ggml_view_1d(ctx.ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float));
}

struct ggml_tensor * logits = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_vocab);

struct rwkv_ctx_size graph_ctx_size;
/* token */ rwkv_ctx_size_add_objects(graph_ctx_size, 1, sizeof(struct ggml_tensor) + sizeof(uint32_t));
/* graph */ rwkv_ctx_size_add(graph_ctx_size, 1, rwkv_ser_graph_size(n_vocab, n_embed, n_layer, instance->ffn_key_size));
Expand All @@ -1195,6 +1193,9 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, graph.cgraph, "Failed to allocate serial graph");
graph.cgraph->n_threads = n_threads;

struct ggml_tensor * logits = ggml_new_tensor_1d(ctx.ctx, GGML_TYPE_F32, n_vocab);

RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_ser_graph(graph.ctx.ctx, instance->model, graph.tokens, inputs.get(), outputs.get(), logits, graph.cgraph.get()));

std::unique_ptr<struct rwkv_context> rwkv_ctx(new(std::nothrow) struct rwkv_context());
Expand Down

0 comments on commit 30749ba

Please sign in to comment.