-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow creating multiple contexts per model (#83)
* Allow creating multiple contexts per model

  This allows for parallel inference and I am preparing to support sequence mode using a method similar to this

* Fix cuBLAS
* Update rwkv.h

  Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Update rwkv.cpp

  Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Inherit print_errors from parent ctx when cloning
* Add context cloning test
* Free
* Free ggml context when last rwkv_context is freed
* Free before exit
* int main
* add explanation of ffn_key_size
* Update rwkv_instance and rwkv_context comments
* Thread safety notes

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
- Loading branch information
1 parent
363dfb1
commit 3f8bb2c
Showing
4 changed files
with
169 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#include <rwkv.h> | ||
|
||
#include <stdlib.h> | ||
#include <stdio.h> | ||
#include <string.h> | ||
|
||
int main() { | ||
struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); | ||
|
||
if (!ctx) { | ||
enum rwkv_error_flags error = rwkv_get_last_error(NULL); | ||
fprintf(stderr, "Unexpected error 0x%.8X\n", error); | ||
return EXIT_FAILURE; | ||
} | ||
|
||
float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float)); | ||
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); | ||
|
||
if (!state || !logits) { | ||
fprintf(stderr, "Failed to allocate state/logits\n"); | ||
return EXIT_FAILURE; | ||
} | ||
|
||
// 0xd1 or 209 is space (0x20 or \u0120 in tokenizer) | ||
const unsigned char * prompt = "hello\xd1world"; | ||
|
||
rwkv_eval(ctx, prompt[0], NULL, state, logits); | ||
|
||
for (const unsigned char * token = prompt + 1; *token != 0; token++) { | ||
rwkv_eval(ctx, *token, state, state, logits); | ||
} | ||
|
||
float * expected_logits = logits; | ||
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float)); | ||
|
||
if (!logits) { | ||
fprintf(stderr, "Failed to allocate state/logits\n"); | ||
return EXIT_FAILURE; | ||
} | ||
|
||
struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2); | ||
|
||
rwkv_eval(ctx, prompt[0], NULL, state, logits); | ||
|
||
for (const unsigned char * token = prompt + 1; *token != 0; token++) { | ||
rwkv_eval(ctx, *token, state, state, logits); | ||
} | ||
|
||
if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) { | ||
fprintf(stderr, "results not identical :(\n"); | ||
return EXIT_FAILURE; | ||
} else { | ||
fprintf(stdout, "Results identical, success!\n"); | ||
} | ||
|
||
rwkv_free(ctx); | ||
rwkv_free(ctx2); | ||
|
||
free(expected_logits); | ||
free(logits); | ||
free(state); | ||
|
||
return EXIT_SUCCESS; | ||
} |