Skip to content

Commit

Permalink
Add more debug
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Jun 10, 2023
1 parent 2dede61 commit ca1e3e2
Showing 1 changed file with 58 additions and 56 deletions.
114 changes: 58 additions & 56 deletions tests/test_context_cloning.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,63 @@
#include <stdio.h>
#include <string.h>

int main() {
struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);

if (!ctx) {
enum rwkv_error_flags error = rwkv_get_last_error(NULL);
fprintf(stderr, "Unexpected error 0x%.8X\n", error);
return EXIT_FAILURE;
}

float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));

if (!state || !logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
return EXIT_FAILURE;
}

// 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
const unsigned char * prompt = "hello\xd1world";

rwkv_eval(ctx, prompt[0], NULL, state, logits);

for (const unsigned char * token = prompt + 1; *token != 0; token++) {
rwkv_eval(ctx, *token, state, state, logits);
}

float * expected_logits = logits;
logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
#define L() { fprintf(stderr, "L%d\n", __LINE__); }

if (!logits) {
fprintf(stderr, "Failed to allocate state/logits\n");
return EXIT_FAILURE;
}

struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);

rwkv_eval(ctx, prompt[0], NULL, state, logits);

for (const unsigned char * token = prompt + 1; *token != 0; token++) {
rwkv_eval(ctx, *token, state, state, logits);
}

if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
fprintf(stderr, "results not identical :(\n");
return EXIT_FAILURE;
} else {
fprintf(stdout, "Results identical, success!\n");
}

rwkv_free(ctx);
rwkv_free(ctx2);

free(expected_logits);
free(logits);
free(state);

return EXIT_SUCCESS;
int main() {
L(); struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
L();
L(); if (!ctx) {
L(); enum rwkv_error_flags error = rwkv_get_last_error(NULL);
L(); fprintf(stderr, "Unexpected error 0x%.8X\n", error);
L(); return EXIT_FAILURE;
L(); }
L();
L(); float * state = calloc(rwkv_get_state_buffer_element_count(ctx), sizeof(float));
L(); float * logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
L();
L(); if (!state || !logits) {
L(); fprintf(stderr, "Failed to allocate state/logits\n");
L(); return EXIT_FAILURE;
L(); }
L();
L(); // 0xd1 or 209 is space (0x20 or \u0120 in tokenizer)
L(); const unsigned char * prompt = "hello\xd1world";
L();
L(); rwkv_eval(ctx, prompt[0], NULL, state, logits);
L();
L(); for (const unsigned char * token = prompt + 1; *token != 0; token++) {
L(); rwkv_eval(ctx, *token, state, state, logits);
L(); }
L();
L(); float * expected_logits = logits;
L(); logits = calloc(rwkv_get_logits_buffer_element_count(ctx), sizeof(float));
L();
L(); if (!logits) {
L(); fprintf(stderr, "Failed to allocate state/logits\n");
L(); return EXIT_FAILURE;
L(); }
L();
L(); struct rwkv_context * ctx2 = rwkv_clone_context(ctx, 2);
L();
L(); rwkv_eval(ctx, prompt[0], NULL, state, logits);
L();
L(); for (const unsigned char * token = prompt + 1; *token != 0; token++) {
L(); rwkv_eval(ctx, *token, state, state, logits);
L(); }
L();
L(); if (memcmp(expected_logits, logits, rwkv_get_logits_buffer_element_count(ctx) * sizeof(float))) {
L(); fprintf(stderr, "results not identical :(\n");
L(); return EXIT_FAILURE;
L(); } else {
L(); fprintf(stdout, "Results identical, success!\n");
L(); }
L();
L(); rwkv_free(ctx);
L(); rwkv_free(ctx2);
L();
L(); free(expected_logits);
L(); free(logits);
L(); free(state);
L();
L(); return EXIT_SUCCESS;
}

0 comments on commit ca1e3e2

Please sign in to comment.