-
Notifications
You must be signed in to change notification settings - Fork 76
/
rwkv.cpp
1593 lines (1244 loc) · 61.5 KB
/
rwkv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "rwkv.h"
// Fix build on Linux.
// https://stackoverflow.com/questions/8518264/where-is-the-declaration-of-cpu-alloc
#if defined(__linux__)
#define _GNU_SOURCE
#include <sched.h>
#endif
#include "ggml.h"
#include "ggml-alloc.h"
#ifdef GGML_USE_CUBLAS
#include "ggml/src/ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST)
#include "ggml/src/ggml-opencl.h"
#endif
#include <string>
#include <vector>
#include <cstring>
#include <cinttypes>
#include <cmath>
#include <fstream>
#include <unordered_map>
#include <memory>
#include <utility>
// Request 64-bit file offsets (off_t) on platforms where this is a feature-test macro.
// NOTE(review): feature-test macros only take effect when defined before the first system header is included;
// here it follows rwkv.h and the <string>/<vector>/... includes. The static_asserts at the end of this block
// catch the case where 64-bit offsets are not actually in effect.
#define _FILE_OFFSET_BITS 64

// Puts an optional break point, if debug is enabled.
// Expands to nothing by default; Windows debug builds redefine it below to __debugbreak().
#define RWKV_MAYBE_BREAK

#include <sys/stat.h>

#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__)
// Map the names used in this file to the 64-bit variants of the MSVC CRT.
#define stat _stat64
#define fstat _fstat64
#define ftell _ftelli64
#define fseek _fseeki64

#ifndef NDEBUG
#include <intrin.h>
// Redefining an object-like macro with a different replacement list is ill-formed (MSVC reports C4005),
// so the default empty definition is removed first.
#undef RWKV_MAYBE_BREAK
#define RWKV_MAYBE_BREAK __debugbreak()
#endif
#else
#if !defined(__APPLE__)
// Use the off_t-based variants so that offsets are 64-bit when large-file support is enabled.
#define ftell ftello
#define fseek fseeko
#endif
#endif

static_assert(sizeof(stat::st_size) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB");
static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit or else rwkv.cpp will not be able to load model files over 2 GB");
// --- Error handling ---
// Error flags accumulated by the RWKV_ASSERT* macros when there is no context to record them in.
// Read and cleared by rwkv_get_last_error(NULL). Thread-local, so every thread has its own value.
thread_local enum rwkv_error_flags global_last_error = RWKV_ERROR_NONE;
// Whether the RWKV_MSG-based macros print to stderr. Thread-local; changed by rwkv_set_print_errors(NULL, ...).
thread_local bool global_print_errors = true;
inline enum rwkv_error_flags operator|(enum rwkv_error_flags a, enum rwkv_error_flags b) {
return static_cast<enum rwkv_error_flags>(static_cast<int>(a) | static_cast<int>(b));
}
inline enum rwkv_error_flags operator|=(enum rwkv_error_flags & a, enum rwkv_error_flags b) {
return a = a | b;
}
// Prints a printf-style message to stderr if error printing is enabled for the current thread (global_print_errors).
#define RWKV_MSG(...) do { if (global_print_errors) fprintf(stderr, __VA_ARGS__); } while (0)
// Prints a printf-style message to stderr if error printing is enabled on `ctx` (ctx->print_errors).
#define RWKV_CTX_MSG(ctx, ...) do { if (ctx->print_errors) fprintf(stderr, __VA_ARGS__); } while (0)
// If the condition x is false, adds ERR_VAL to the thread's last error, prints the failed expression with
// its file and line to stderr (if enabled), and returns RET_VAL from the enclosing function.
#define RWKV_ASSERT(ERR_VAL, RET_VAL, x) do { \
if (!(x)) { \
global_last_error |= ERR_VAL; \
RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, adds ERR_VAL to the thread's last error, prints the given message and then the
// failed expression with its file and line to stderr (if enabled), and returns RET_VAL.
#define RWKV_ASSERT_MSG(ERR_VAL, RET_VAL, x, ...) do { \
if (!(x)) { \
global_last_error |= ERR_VAL; \
RWKV_MSG(__VA_ARGS__); \
RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, adds ERR_VAL to the ctx's last error, prints the given message and then the
// failed expression to stderr (if enabled on ctx), and returns RET_VAL. `ctx` is cast to rwkv_context * for the
// error update.
#define RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, RET_VAL, x, ...) do { \
if (!(x)) { \
((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
RWKV_CTX_MSG(ctx, __VA_ARGS__); \
RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, adds ERR_VAL to the ctx's last error, prints the failed expression to stderr
// (if enabled on ctx), and returns RET_VAL.
#define RWKV_CTX_ASSERT(ctx, ERR_VAL, RET_VAL, x) do { \
if (!(x)) { \
((struct rwkv_context *) ctx)->last_error |= ERR_VAL; \
RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, prints the failed expression to stderr (if enabled) and returns RET_VAL.
// Unlike RWKV_ASSERT, it does not modify the error flags.
#define RWKV_ENSURE(RET_VAL, x) do { \
if (!(x)) { \
RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, prints the given message and then the failed expression to stderr (if enabled),
// and returns RET_VAL. Does not modify the error flags.
#define RWKV_ENSURE_MSG(RET_VAL, x, ...) do { \
if (!(x)) { \
RWKV_MSG(__VA_ARGS__); \
RWKV_MSG("\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
RWKV_MAYBE_BREAK; \
return RET_VAL; \
} } while (0)
// If the condition x is false, prints the given message and then the failed expression to stderr (if error
// printing is enabled on ctx), and returns RET_VAL.
// Like the other ENSURE macros this does not modify the error flags: ERR_VAL is not a parameter of this macro,
// so it must not be referenced here. Use RWKV_CTX_ASSERT_MSG to also record an error on ctx.
#define RWKV_CTX_ENSURE_MSG(ctx, RET_VAL, x, ...) do { \
        if (!(x)) { \
            RWKV_CTX_MSG(ctx, __VA_ARGS__); \
            RWKV_CTX_MSG(ctx, "\n%s:%d: %s\n", __FILE__, __LINE__, #x); \
            RWKV_MAYBE_BREAK; \
            return RET_VAL; \
        } } while (0)
// Shorthands for the macros above with a fixed return value: *_FALSE variants return false, *_NULL variants return NULL.
#define RWKV_ASSERT_FALSE_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, false, x, __VA_ARGS__)
#define RWKV_ASSERT_NULL_MSG(ERR_VAL, x, ...) RWKV_ASSERT_MSG(ERR_VAL, NULL, x, __VA_ARGS__)
#define RWKV_CTX_ASSERT_FALSE_MSG(ctx, ERR_VAL, x, ...) RWKV_CTX_ASSERT_MSG(ctx, ERR_VAL, false, x, __VA_ARGS__)
#define RWKV_ASSERT_FALSE(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, false, x)
#define RWKV_ASSERT_NULL(ERR_VAL, x) RWKV_ASSERT(ERR_VAL, NULL, x)
#define RWKV_CTX_ASSERT_FALSE(ctx, ERR_VAL, x) RWKV_CTX_ASSERT(ctx, ERR_VAL, false, x)
#define RWKV_ENSURE_OR_FALSE(x) RWKV_ENSURE(false, x)
#define RWKV_ENSURE_OR_NULL(x) RWKV_ENSURE(NULL, x)
#define RWKV_ENSURE_OR_FALSE_MSG(x, ...) RWKV_ENSURE_MSG(false, x, __VA_ARGS__)
// --- Utilities ---
// Size in bytes of a width x height tensor of the given type.
// ggml_type_size() is the byte size of one block of ggml_blck_size() elements, hence the final division.
size_t rwkv_tensor_nbytes(const enum ggml_type type, const int64_t width, const int64_t height) {
    const auto size_times_elements = ggml_type_size(type) * width * height;
    return size_times_elements / ggml_blck_size(type);
}

// For some reason, ggml_nbytes calculates the size in a way incompatible with rwkv.cpp
// Same computation as above, using the tensor's type and its first two dimensions.
size_t rwkv_tensor_nbytes(const struct ggml_tensor * tensor) {
    const int64_t width = tensor->ne[0];
    const int64_t height = tensor->ne[1];
    return rwkv_tensor_nbytes(tensor->type, width, height);
}
size_t rwkv_ggml_overhead() {
return ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead();
}
// Creates a ggml context of `memory_size` bytes, passing NULL for the memory buffer field.
// With `no_alloc`, tensors created in the context get no data memory (their data pointer stays NULL).
struct ggml_context * rwkv_init_ggml_context(const size_t memory_size, const bool no_alloc) {
    return ggml_init({ memory_size, NULL, no_alloc });
}
// --- IO utilities ---
// Reads a single uint32 value (native byte order) from a file. Returns false on a short read.
bool rwkv_fread_uint32(FILE * file, uint32_t & dest) {
    return fread(&dest, sizeof(uint32_t), 1, file) == 1;
}

// Reads exactly `length` bytes from a file into `dest`, replacing its contents. Returns false on a short read.
bool rwkv_fread_string(FILE * file, size_t length, std::string & dest) {
    dest.resize(length);

    if (length == 0) {
        // fread with an element size of 0 returns 0, which would otherwise be reported as a failure.
        return true;
    }

    // &dest[0] is writable in every standard version; std::string::data() is const before C++17.
    return fread(&dest[0], length, 1, file) == 1;
}

// Reads exactly `length` bytes from a file into `dest`. A zero-length read succeeds without touching the file.
// Returns false on a short read.
bool rwkv_fread_data(FILE * file, size_t length, void * dest) {
    return length == 0 || fread(dest, length, 1, file) == 1;
}
// Writes a single uint32 value (native byte order) to a file. Returns false if the write fails.
bool rwkv_fwrite_uint32(FILE * file, const uint32_t value) {
    return fwrite(&value, sizeof(uint32_t), 1, file) == 1;
}

// Writes the bytes of a string (without a terminator) to a file. An empty string is a successful no-op.
// fwrite with an element size of 0 returns 0, which would otherwise be reported as a failure.
bool rwkv_fwrite_string(FILE * file, const std::string & value) {
    return value.empty() || fwrite(value.data(), value.length(), 1, file) == 1;
}

// Writes `length` bytes from `data` to a file. A zero-length write is a successful no-op.
bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) {
    return length == 0 || fwrite(data, length, 1, file) == 1;
}
// --- File handling ---
// Sentinel used when a type name or id is not recognized.
#define TYPE_UNKNOWN TYPE_COUNT

// Data types as stored in the data_type fields of rwkv.cpp model files.
// These numbers are read from and written to files, so they are spelled out explicitly and must never be
// reordered or reused.
enum rwkv_type {
    TYPE_FP32 = 0,
    TYPE_FP16 = 1,
    TYPE_Q4_0 = 2,
    TYPE_Q4_1 = 3,
    TYPE_Q4_1_O = 4, // Unsupported
    TYPE_Q4_2 = 5, // Unsupported
    TYPE_Q4_3 = 6, // Unsupported
    TYPE_Q5_0 = 7,
    TYPE_Q5_1 = 8,
    TYPE_Q8_0 = 9,
    // Number of types; doubles as TYPE_UNKNOWN.
    TYPE_COUNT = 10
};
// Sentinel ggml type for rwkv_type values whose format no longer exists in ggml.
#define GGML_TYPE_UNKNOWN GGML_TYPE_COUNT
// Maps each rwkv_type (used as the index) to its ggml_type.
// Removed formats map to GGML_TYPE_UNKNOWN; the extra last slot maps TYPE_COUNT (= TYPE_UNKNOWN) to GGML_TYPE_COUNT.
extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = {
GGML_TYPE_F32, /* FP32 */
GGML_TYPE_F16, /* FP16 */
GGML_TYPE_Q4_0, /* Q4_0 */
GGML_TYPE_Q4_1, /* Q4_1 */
GGML_TYPE_UNKNOWN, /* Q4_1_O */
GGML_TYPE_UNKNOWN, /* Q4_2 */
GGML_TYPE_UNKNOWN, /* Q4_3 */
GGML_TYPE_Q5_0, /* Q5_0 */
GGML_TYPE_Q5_1, /* Q5_1 */
GGML_TYPE_Q8_0, /* Q8_0 */
GGML_TYPE_COUNT /* COUNT */
};
// Maps a ggml_type value (used as the index) back to an rwkv_type.
// TYPE_COUNT marks ggml types that have no rwkv.cpp equivalent. The per-line comments name the ggml type each
// slot is intended for.
// NOTE(review): this relies on the ggml_type enum order matching those comments. Slots past the last initializer
// are zero-initialized, i.e. TYPE_FP32 rather than TYPE_COUNT. If the linked ggml defines more types than are
// listed here (e.g. k-quant types before I8), the trailing entries are wrong. Verify against ggml.h.
extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = {
TYPE_FP32, /* FP32 */
TYPE_FP16, /* FP16 */
TYPE_Q4_0, /* Q4_0 */
TYPE_Q4_1, /* Q4_1 */
TYPE_Q4_2, /* Q4_2 */
TYPE_Q4_3, /* Q4_3 */
TYPE_Q5_0, /* Q5_0 */
TYPE_Q5_1, /* Q5_1 */
TYPE_Q8_0, /* Q8_0 */
TYPE_COUNT, /* Q8_1 */
TYPE_COUNT, /* I8 */
TYPE_COUNT, /* I16 */
TYPE_COUNT, /* I32 */
TYPE_COUNT, /* COUNT */
};
// Display names of rwkv_type values, indexed by the enum value; the extra last entry names TYPE_COUNT / TYPE_UNKNOWN.
extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {
    "FP32",
    "FP16",
    "Q4_0",
    "Q4_1",
    "Q4_1_O",
    "Q4_2",
    "Q4_3",
    "Q5_0",
    "Q5_1",
    "Q8_0",
    "unknown"
};

// Returns the type whose name equals `str` exactly (case-sensitive), or TYPE_UNKNOWN if none matches.
enum rwkv_type rwkv_type_from_string(const char * str) {
    int ord = 0;
    while (ord < TYPE_COUNT && strcmp(str, rwkv_type_to_string[ord]) != 0) {
        ++ord;
    }
    return ord < TYPE_COUNT ? static_cast<enum rwkv_type>(ord) : TYPE_UNKNOWN;
}
// Header at the start of a model file. It is read and written as one raw block of
// sizeof(struct rwkv_file_header) bytes, so the field order is the on-disk layout.
struct rwkv_file_header {
// Must equal RWKV_FILE_MAGIC.
uint32_t magic;
// Must lie within [RWKV_FILE_VERSION_MIN, RWKV_FILE_VERSION_MAX].
uint32_t version;
// Checked against the second dimension (ne[1]) of the embedding matrix when loading.
uint32_t n_vocab;
// Checked against the first dimension (ne[0]) of the embedding matrix when loading.
uint32_t n_embed;
// Number of "blocks.<i>." layers in the file.
uint32_t n_layer;
// An rwkv_type value; validated to be below TYPE_COUNT when read.
uint32_t data_type;
};
// True when `version` is within the inclusive range of model file versions this build can load.
bool rwkv_is_file_version_in_range(uint32_t version) {
    const bool too_old = version < RWKV_FILE_VERSION_MIN;
    const bool too_new = version > RWKV_FILE_VERSION_MAX;
    return !(too_old || too_new);
}
bool rwkv_fread_file_header(FILE * file, struct rwkv_file_header & header, bool verify_data_type = true) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_file_header), &header));
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_MAGIC, header.magic == RWKV_FILE_MAGIC);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_VERSION, rwkv_is_file_version_in_range(header.version), "Unsupported file version %" PRId32, header.version);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Model data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1);
if (verify_data_type) {
enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type];
RWKV_ASSERT_FALSE_MSG(
RWKV_ERROR_DATA_TYPE,
ggml_type != GGML_TYPE_UNKNOWN,
"Models in %s format cannot be loaded anymore because the format was removed.\n"
"You need to quantize the model into another format or use an older version of rwkv.cpp.\n"
"See https://github.com/saharNooby/rwkv.cpp#compatibility for more info",
rwkv_type_to_string[header.data_type]
);
RWKV_ASSERT_FALSE_MSG(
RWKV_ERROR_DATA_TYPE,
(!ggml_is_quantized(ggml_type) || header.version == RWKV_FILE_VERSION_1),
"The quantized model file in %s format was created with an old version of rwkv.cpp and can not be loaded anymore.\n"
"You need to requantize the model or use an older version of rwkv.cpp.\n"
"See https://github.com/saharNooby/rwkv.cpp#compatibility for more info",
rwkv_type_to_string[header.data_type]
);
}
return true;
}
// Writes the model file header verbatim (the in-memory struct is the on-disk layout).
// Returns false (and sets the global last error) if the write fails.
bool rwkv_fwrite_file_header(FILE * file, const struct rwkv_file_header & header) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_file_header)));
return true;
}
// Header preceding each tensor in a model file.
// On disk, `height` is present only for 2D tensors:
// - rwkv_fread_tensor_header reads every field except the last uint32, then reads `height` only when dim_count == 2;
// - rwkv_fwrite_tensor_header omits `height` when dim_count == 1.
// The field order is therefore the on-disk layout and must not change.
struct rwkv_tensor_header {
// 1 or 2.
uint32_t dim_count;
// Length in bytes of the tensor name, which follows the header in the file.
uint32_t key_length;
// An rwkv_type value.
uint32_t data_type;
uint32_t width;
// 1 for 1D tensors.
uint32_t height;
// Size in bytes of the tensor data that follows the name.
size_t size() const;
};
size_t rwkv_tensor_header::size() const {
return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height);
}
// A tensor as stored in a model file: its header, its name (key), and a pointer to its data bytes.
struct rwkv_tensor {
struct rwkv_tensor_header header;
std::string name;
// header.size() data bytes, or NULL when the data was skipped instead of read.
// NOTE(review): not owned by this struct — presumably the caller owns the buffer; confirm.
uint8_t * data;
};
// Reads and validates a tensor header.
// The file stores `height` only for 2D tensors. The fixed part (dim_count, key_length, data_type, width) is read
// first, `height` defaults to 1, and `height` is read separately when dim_count == 2.
// Returns false (and sets the global last error) on a read failure or an invalid shape or data type.
bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) {
    RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header));
    header.height = 1;

    // Header fields are uint32_t, so they are formatted with PRIu32 (PRId32 is the signed conversion).
    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRIu32 " dimensions)", header.dim_count);
    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRIu32 " > %" PRIu32 ")", header.data_type, static_cast<uint32_t>(TYPE_COUNT - 1));
    RWKV_ASSERT_FALSE_MSG(
        RWKV_ERROR_DATA_TYPE,
        rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN,
        "Tensor data type (%s) is no longer supported",
        rwkv_type_to_string[header.data_type]
    );

    if (header.dim_count == 2) {
        RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height));
    }

    return true;
}
// Writes a tensor header in the on-disk format. `height` is the last field and is omitted for 1D tensors,
// so one uint32 fewer is written when dim_count == 1.
// Returns false (and sets the global last error) if the write fails.
// NOTE(review): assumes the struct has no padding beyond its five uint32_t fields; a static_assert on its size
// would make that explicit.
bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? sizeof(uint32_t) : 0)));
return true;
}
// Moves the file position past a tensor's name and data, which follow its header.
// Returns false if the seek fails.
bool rwkv_fskip_tensor_name_and_data(FILE * file, const struct rwkv_tensor_header & header) {
    const size_t skip_bytes = header.key_length + header.size();
    const int status = fseek(file, skip_bytes, SEEK_CUR);
    return status == 0;
}
// Moves the file position past a tensor's data (header.size() bytes); use after the name has been read.
// Returns false if the seek fails.
bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) {
    const size_t data_bytes = header.size();
    const int status = fseek(file, data_bytes, SEEK_CUR);
    return status == 0;
}
// Reads and validates a tensor header, then seeks past the tensor's name and data, leaving the file positioned
// at the next tensor header. Returns false on an invalid header or a failed seek.
bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) {
RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, header));
RWKV_ASSERT_FALSE(RWKV_ERROR_DATA, rwkv_fskip_tensor_name_and_data(file, header));
return true;
}
bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) {
size_t data_size = output.header.size();
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, output.header.key_length, output.name));
if (buffer) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, data_size, buffer));
} else {
output.data = NULL;
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_name_and_data(file, output.header));
}
return true;
}
// Reads a complete tensor (header, name, then data) into `output`. The data is read into `buffer` when one is
// given, and skipped otherwise (see rwkv_fread_tensor_data). Returns false on any read or validation failure.
bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) {
RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_header(file, output.header));
RWKV_ENSURE_OR_FALSE(rwkv_fread_tensor_data(file, output, buffer));
return true;
}
bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name");
enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type];
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str());
tensor = header.dim_count == 1
? ggml_new_tensor_1d(ctx, ggml_type, header.width)
: ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor");
ggml_set_name(tensor, name.c_str());
// Tensor data may be NULL if no_alloc is true
if (tensor->data != NULL) {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, rwkv_tensor_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str());
} else {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fskip_tensor_data(file, header), "Failed to skip tensor data from %s", name.c_str());
}
return true;
}
// Reads a tensor header and then the tensor's name and data into a new tensor in `ctx`
// (see rwkv_fread_ggml_tensor_data). Returns false on an invalid header or a read failure.
bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) {
struct rwkv_tensor_header header;
RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header");
return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor);
}
// Writes a complete tensor: header, name, then header.size() bytes from tensor.data.
// Returns false if any write fails.
// NOTE(review): tensor.name.length() is presumably expected to equal header.key_length; this is not checked here.
bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) {
RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header));
RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name));
RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, tensor.header.size()));
return true;
}
// --- Model loading ---
// Parameter tensors of one RWKV block (layer), pointing into the model's ggml context.
// rwkv_set_params binds each field to the file parameter "blocks.<i>.<suffix>", for example
// ln1_weight <- "ln1.weight", att_time_mix_k <- "att.time_mix_k", att_key <- "att.key.weight".
struct rwkv_layer {
struct ggml_tensor * ln1_weight;
struct ggml_tensor * ln1_bias;
// RWKV, also called "attention" by the author.
struct ggml_tensor * att_time_mix_k;
struct ggml_tensor * att_time_mix_v;
struct ggml_tensor * att_time_mix_r;
struct ggml_tensor * att_time_first;
struct ggml_tensor * att_time_decay;
struct ggml_tensor * att_key;
struct ggml_tensor * att_value;
struct ggml_tensor * att_receptance;
struct ggml_tensor * att_output;
struct ggml_tensor * ln2_weight;
struct ggml_tensor * ln2_bias;
// FFN.
struct ggml_tensor * ffn_time_mix_k;
struct ggml_tensor * ffn_time_mix_r;
struct ggml_tensor * ffn_key;
struct ggml_tensor * ffn_value;
struct ggml_tensor * ffn_receptance;
};
// The model holds all parameter tensors and the ggml context containing them.
// Each tensor has data and can be used in computations happening in other contexts.
struct rwkv_model {
// This context holds all parameter tensors.
// It must not be used for computations.
struct ggml_context * ggml_ctx;
// Header read from the model file.
struct rwkv_file_header header;
// "emb.weight"; on load it is checked to be 2D with ne[0] == header.n_embed and ne[1] == header.n_vocab.
struct ggml_tensor * emb;
// "blocks.0.ln0.weight" / "blocks.0.ln0.bias".
struct ggml_tensor * ln0_weight;
struct ggml_tensor * ln0_bias;
// header.n_layer entries, allocated by rwkv_set_params.
std::unique_ptr<struct rwkv_layer[]> layers;
// "ln_out.weight" / "ln_out.bias".
struct ggml_tensor * ln_out_weight;
struct ggml_tensor * ln_out_bias;
// "head.weight".
struct ggml_tensor * head;
// How many layers were offloaded to the GPU.
size_t offloaded_layer_count;
// How many RWKV contexts reference this model.
int reference_count;
};
// Owns a C FILE handle and closes it on destruction. A NULL handle is allowed and is simply not closed.
// Copying is disabled because two copies would close the same handle twice; ownership can be moved instead.
struct rwkv_file {
    FILE * file;

    rwkv_file(FILE * file): file(file) {}

    rwkv_file(rwkv_file && other) noexcept: file(other.file) {
        other.file = NULL;
    }

    rwkv_file & operator=(rwkv_file && other) noexcept {
        if (this != &other) {
            if (file) {
                fclose(file);
            }
            file = other.file;
            other.file = NULL;
        }
        return *this;
    }

    rwkv_file(const rwkv_file &) = delete;
    rwkv_file & operator=(const rwkv_file &) = delete;

    ~rwkv_file() {
        if (file) {
            fclose(file);
        }
    }
};
// https://stackoverflow.com/a/6458689
// Visits every model parameter in a fixed order, calling `callback(name, dest)`. `name` is the parameter name used
// in model files, and `dest` is a reference to the rwkv_model field that holds that parameter.
// Allocates model.layers (model.header.n_layer entries) before the per-layer parameters are visited.
// Returns false as soon as the allocation fails or a callback returns false; returns true otherwise.
template<typename F>
bool rwkv_set_params(struct rwkv_model & model, F callback) {
    RWKV_ENSURE_OR_FALSE(callback("emb.weight", model.emb));
    RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.weight", model.ln0_weight));
    RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias));

    uint32_t n_layer = model.header.n_layer;
    std::unique_ptr<struct rwkv_layer[]> layers(new(std::nothrow) struct rwkv_layer[n_layer]);
    RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers");
    model.layers = std::move(layers);

    for (uint32_t i = 0; i < n_layer; i++) {
        // Holds "blocks.<i>." followed by one parameter suffix; the longest full name is far below 128 bytes.
        char buffer[128];
        // `i` is unsigned, so it is formatted with PRIu32. snprintf bounds the write to the buffer.
        size_t offset = static_cast<size_t>(snprintf(buffer, sizeof(buffer), "blocks.%" PRIu32 ".", i));
        rwkv_layer & layer = model.layers[i];

        // Each call overwrites the suffix after the "blocks.<i>." prefix and passes the whole buffer as the name.
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias));

        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output));

        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias));

        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value));
        RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance));
    }

    RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight));
    RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias));
    RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head));

    return true;
}
// Creates a ggml context and loads all parameter tensors from a model file.
bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model & model) {
struct stat file_stat;
std::unordered_map<std::string, struct ggml_tensor *> parameters;
rwkv_file file(fopen(file_path, "rb"));
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_OPEN, file.file, "Failed to open file %s", file_path);
// Be very careful when changing this code. It must support files larger than 2 GB by using 64-bit functions to get the file length.
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_STAT, fstat(fileno(file.file), &file_stat) == 0, "Failed to stat file %s", file_path);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE, rwkv_fread_file_header(file.file, model.header), "Invalid file header");
model.ggml_ctx = rwkv_init_ggml_context(
// ggml tensors must be aligned; assuming here that overhead of parameter headers, included in the file size, will account for that.
file_stat.st_size + rwkv_ggml_overhead(),
false
);
std::string name;
struct ggml_tensor * tensor;
while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) {
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, model.ggml_ctx, name, tensor), "Failed to read a model parameter");
parameters[std::move(name)] = tensor;
}
std::unordered_map<std::string, struct ggml_tensor *> & parameters_ref = parameters;
RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) {
struct ggml_tensor * tensor = parameters_ref[key];
RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key);
dest = tensor;
return true;
}));
// Verify order of dimensions
struct ggml_tensor * emb = model.emb;
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_SHAPE, emb->n_dims == 2, "Unexpected dimension count of embedding matrix %d", emb->n_dims);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[0] == model.header.n_embed, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[0]);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_DIMENSION, emb->ne[1] == model.header.n_vocab, "Unexpected dimension of embedding matrix %" PRId64, emb->ne[1]);
return true;
}
// --- Operators ---
// Number of elements of a tensor across all four ggml dimensions.
// The element-wise operators below assert same shape on all dimensions, so they must also process all of them.
// For 1D/2D tensors ggml stores 1 in the unused trailing extents, so this equals ne[0] * ne[1] there.
static int64_t rwkv_total_element_count(const struct ggml_tensor * tensor) {
    return tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
}

// Custom ggml op: dest = exp(src), element-wise. Requires contiguous F32 tensors of the same shape.
// Registered with a single task (see rwkv_exp), so ith/nth are unused, as is userdata.
void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
    GGML_ASSERT(dest->type == GGML_TYPE_F32);
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dest));
    GGML_ASSERT(ggml_is_contiguous(src));
    GGML_ASSERT(ggml_are_same_shape(src, dest));

    const int64_t element_count = rwkv_total_element_count(src);
    const float * src_data = static_cast<const float *>(src->data);
    float * dest_data = static_cast<float *>(dest->data);

    for (int64_t i = 0; i < element_count; i++) {
        dest_data[i] = expf(src_data[i]);
    }

    // Suppress warnings for unused parameters.
    (void) ith;
    (void) nth;
    (void) userdata;
}

// Custom ggml op: dest = 1 - src, element-wise. Requires contiguous F32 tensors of the same shape.
// Registered with a single task (see rwkv_1_minus_x), so ith/nth are unused, as is userdata.
void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
    GGML_ASSERT(dest->type == GGML_TYPE_F32);
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dest));
    GGML_ASSERT(ggml_is_contiguous(src));
    GGML_ASSERT(ggml_are_same_shape(src, dest));

    const int64_t element_count = rwkv_total_element_count(src);
    const float * src_data = static_cast<const float *>(src->data);
    float * dest_data = static_cast<float *>(dest->data);

    for (int64_t i = 0; i < element_count; i++) {
        dest_data[i] = 1.0F - src_data[i];
    }

    // Suppress warnings for unused parameters.
    (void) ith;
    (void) nth;
    (void) userdata;
}

// Custom ggml op: dest = 1 / (1 + exp(-src)), element-wise. Requires contiguous F32 tensors of the same shape.
// Registered with a single task (see rwkv_sigmoid), so ith/nth are unused, as is userdata.
void rwkv_sigmoid_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
    GGML_ASSERT(dest->type == GGML_TYPE_F32);
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dest));
    GGML_ASSERT(ggml_is_contiguous(src));
    GGML_ASSERT(ggml_are_same_shape(src, dest));

    const int64_t element_count = rwkv_total_element_count(src);
    const float * src_data = static_cast<const float *>(src->data);
    float * dest_data = static_cast<float *>(dest->data);

    for (int64_t i = 0; i < element_count; i++) {
        dest_data[i] = 1.0F / (1.0F + expf(-src_data[i]));
    }

    // Suppress warnings for unused parameters.
    (void) ith;
    (void) nth;
    (void) userdata;
}

// Custom ggml op: dest = max(src0, src1), element-wise. Requires contiguous F32 tensors of the same shape.
// Registered with a single task (see rwkv_max), so ith/nth are unused, as is userdata.
void rwkv_max_impl(
    struct ggml_tensor * dest,
    const struct ggml_tensor * src0,
    const struct ggml_tensor * src1,
    int ith,
    int nth,
    void * userdata
) {
    GGML_ASSERT(dest->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(dest));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(ggml_are_same_shape(src0, dest));
    GGML_ASSERT(ggml_are_same_shape(src1, dest));

    const int64_t element_count = rwkv_total_element_count(src0);
    const float * src0_data = static_cast<const float *>(src0->data);
    const float * src1_data = static_cast<const float *>(src1->data);
    float * dest_data = static_cast<float *>(dest->data);

    for (int64_t i = 0; i < element_count; i++) {
        dest_data[i] = fmaxf(src0_data[i], src1_data[i]);
    }

    // Suppress warnings for unused parameters.
    (void) ith;
    (void) nth;
    (void) userdata;
}
// Graph-building wrappers for the element-wise custom operators. Each adds a ggml custom-map node that runs
// the matching *_impl function with n_tasks = 1 and no user data.

// e^x
struct ggml_tensor * rwkv_exp(ggml_context * ctx, struct ggml_tensor * x) {
    const int n_tasks = 1;
    return ggml_map_custom1(ctx, x, rwkv_exp_impl, n_tasks, NULL);
}

// 1 - x
struct ggml_tensor * rwkv_1_minus_x(ggml_context * ctx, struct ggml_tensor * x) {
    const int n_tasks = 1;
    return ggml_map_custom1(ctx, x, rwkv_1_minus_x_impl, n_tasks, NULL);
}

// 1 / (1 + e^-x)
struct ggml_tensor * rwkv_sigmoid(ggml_context * ctx, struct ggml_tensor * x) {
    const int n_tasks = 1;
    return ggml_map_custom1(ctx, x, rwkv_sigmoid_impl, n_tasks, NULL);
}

// max(x, y)
struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y) {
    const int n_tasks = 1;
    return ggml_map_custom2(ctx, x, y, rwkv_max_impl, n_tasks, NULL);
}
// LayerNorm as used by RWKV: `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`.
// ggml_norm performs the normalization; the scale and shift are then applied in place on its result.
struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
    const float eps = 1e-5F;
    struct ggml_tensor * normalized = ggml_norm(ctx, x, eps);
    struct ggml_tensor * scaled = ggml_mul_inplace(ctx, normalized, weight);
    return ggml_add_inplace(ctx, scaled, bias);
}
// --- Implementation ---
// View tensors of a state of a single layer.
// NOTE(review): presumably views into a graph's input/output state tensor, five per layer (cf. the
// "state[5*i+...]" references in rwkv_carry_x). Confirm where these views are created.
struct rwkv_layer_state {
struct ggml_tensor * ffn_xx;
struct ggml_tensor * att_xx;
struct ggml_tensor * att_aa;
struct ggml_tensor * att_bb;
struct ggml_tensor * att_pp;
};
// A computation graph: a ggml context together with the ggml cgraph built in it.
// It can be either a serial or a sequential graph.
struct rwkv_computation_graph {
struct ggml_context * ggml_ctx;
// ggml_cgraph is so large that it can cause stack overflows if not stored on the heap.
std::unique_ptr<struct ggml_cgraph> cgraph;
// Input tensors.
struct ggml_tensor * tokens;
struct ggml_tensor * input_state;
// Presumably one entry per model layer — confirm.
std::unique_ptr<struct rwkv_layer_state[]> input_layers;
// Output tensors.
struct ggml_tensor * output_state;
// Presumably one entry per model layer — confirm.
std::unique_ptr<struct rwkv_layer_state[]> output_layers;
struct ggml_tensor * logits;
// ggml graph counters before the graph was extended with logits tensor.
// NOTE(review): presumably kept so that the graph can be evaluated with or without computing logits. Confirm.
int pre_logits_nodes;
int pre_logits_leafs;
// ggml graph counters after the graph was extended with logits tensor.
int post_logits_nodes;
int post_logits_leafs;
};
// The context holds the model and both serial and sequential computation graphs.
struct rwkv_context {
// The model may be shared between contexts (see rwkv_model::reference_count).
struct rwkv_model * model;
// The serial graph implements the traditional RNN mode that processes only one token at a time (serial mode).
struct rwkv_computation_graph serial_graph;
// The sequence graph implements the "sequence mode" (or transformer/GPT mode) that processes multiple tokens at a time.
// This can be an order of magnitude or so faster than serial execution if used properly.
struct rwkv_computation_graph sequential_graph;
// Presumably the token count the sequential graph was last built for — confirm.
size_t last_used_sequence_length;
// Presumably the number of threads used for graph computation — confirm.
uint32_t n_threads;
// Error flags recorded by the RWKV_CTX_ASSERT* macros; read and cleared by rwkv_get_last_error(ctx).
enum rwkv_error_flags last_error;
// Whether RWKV_CTX_MSG prints to stderr for this context; set by rwkv_set_print_errors(ctx, ...).
bool print_errors;
};
// Enables or disables printing of error messages to stderr. The setting applies to `ctx` when one is given,
// and otherwise to the calling thread's global setting.
void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) {
    if (ctx != NULL) {
        ctx->print_errors = print_errors;
        return;
    }

    global_print_errors = print_errors;
}
// Returns whether error messages are printed: for `ctx` when one is given, otherwise the calling thread's
// global setting.
bool rwkv_get_print_errors(struct rwkv_context * ctx) {
    if (ctx == NULL) {
        return global_print_errors;
    }

    return ctx->print_errors;
}
// Returns the accumulated error flags and resets them to RWKV_ERROR_NONE. Uses `ctx`'s flags when one is
// given, and otherwise the calling thread's global flags.
enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx) {
    enum rwkv_error_flags & slot = ctx ? ctx->last_error : global_last_error;
    const enum rwkv_error_flags previous = slot;
    slot = RWKV_ERROR_NONE;
    return previous;
}
// Layer-normalizes `x` and builds the "previous token" input used for token shifting.
// On return:
// - `x` is the normalized input;
// - `x_prev` holds, for each position, the normalized vector of the previous position (position 0 takes `carry`);
// - `carry` is the normalized vector of the last position, to be carried over to the next call.
// Assumes x has ne[0] = n_embed and ne[1] = sequence length — TODO confirm with callers.
void rwkv_carry_x(struct ggml_context * ctx,
struct ggml_tensor * weight,
struct ggml_tensor * bias,
struct ggml_tensor *& x,
struct ggml_tensor *& x_prev,
struct ggml_tensor *& carry
) {
const size_t n_embed = x->ne[0];
const size_t sequence_len = x->ne[1];
if (sequence_len == 1) {
// Single token: the previous vector is exactly the carried one.
// self.layer_norm(x, self.w.blocks[i].ln2)
x = rwkv_layer_norm(ctx, x, weight, bias);
// xx = state[5*i+0]
x_prev = carry;
// state[5*i+0] = x
carry = x;
} else {
// Several tokens: weight and bias are repeated to x's shape before the element-wise multiply/add.
// self.layer_norm(x, self.w.blocks[i].ln2)
x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x));
// xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:]))
// x_prev is a new F32 tensor: row 0 <- carry; rows 1..n-1 <- rows 0..n-2 of x (written at a byte offset of one row).
x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len);
x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0);
x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float));
// state[5*i+0] = x[-1,:]
// A view of the last row of x, not a copy.
carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float));
}
}
// Computes the receptance (r), key (k) and value (v) projections of the
// time-mixing block. Each projection input is a per-channel interpolation
// x * time_mix + x_prev * (1 - time_mix), using att_time_mix_{k,v,r}.
void rwkv_att_rkv(
    struct ggml_context * ctx,
    struct rwkv_layer layer,
    struct ggml_tensor * x,
    struct ggml_tensor * x_prev,
    struct ggml_tensor *& r,
    struct ggml_tensor *& k,
    struct ggml_tensor *& v
) {
    // x * time_mix + x_prev * (1 - time_mix)
    auto lerp_with_prev = [ctx, x, x_prev](struct ggml_tensor * time_mix) {
        return ggml_add_inplace(ctx,
            ggml_mul(ctx, x, time_mix),
            ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, time_mix))
        );
    };

    // Mixed inputs are built in the order k, v, r.
    struct ggml_tensor * xk = lerp_with_prev(layer.att_time_mix_k);
    struct ggml_tensor * xv = lerp_with_prev(layer.att_time_mix_v);
    struct ggml_tensor * xr = lerp_with_prev(layer.att_time_mix_r);

    // r = torch.sigmoid(rw @ xr)
    r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr));
    // k = kw @ xk
    k = ggml_mul_mat(ctx, layer.att_key, xk);
    // v = vw @ xv
    v = ggml_mul_mat(ctx, layer.att_value, xv);
}
// Computes one token's WKV term and advances the (aa, bb, pp) state in place.
// Exponentials are taken relative to a running maximum (torch.maximum in the
// reference code), so every exp() argument is <= 0.
//   aa, bb - numerator / denominator accumulators (replaced with updated tensors);
//   pp     - exponent offset shared by aa and bb (replaced with the updated offset).
// Returns wkv = a / b for the current token.
struct ggml_tensor * rwkv_att_wkv(
    struct ggml_context * ctx,
    struct ggml_tensor * att_time_first,
    struct ggml_tensor * att_time_decay,
    struct ggml_tensor * k,
    struct ggml_tensor * v,
    struct ggml_tensor *& aa,
    struct ggml_tensor *& bb,
    struct ggml_tensor *& pp
) {
    // --- Output for the current token (time_first is added to the current key). ---
    // ww = time_first + k; qq = max(pp, ww)
    struct ggml_tensor * bonus_k = ggml_add(ctx, att_time_first, k);
    struct ggml_tensor * out_max = rwkv_max(ctx, pp, bonus_k);
    // e1 = exp(pp - qq); e2 = exp(ww - qq)
    struct ggml_tensor * out_w_state = rwkv_exp(ctx, ggml_sub(ctx, pp, out_max));
    struct ggml_tensor * out_w_token = rwkv_exp(ctx, ggml_sub(ctx, bonus_k, out_max));
    // a = e1 * aa + e2 * v; b = e1 * bb + e2
    struct ggml_tensor * numerator = ggml_add_inplace(ctx, ggml_mul(ctx, out_w_state, aa), ggml_mul(ctx, out_w_token, v));
    struct ggml_tensor * denominator = ggml_add_inplace(ctx, ggml_mul(ctx, out_w_state, bb), out_w_token);

    // --- State update (time_decay is added to the stored exponent). ---
    // ww = pp + time_decay; qq = max(ww, k)
    struct ggml_tensor * decayed_p = ggml_add(ctx, pp, att_time_decay);
    struct ggml_tensor * next_p = rwkv_max(ctx, decayed_p, k);
    // e1 = exp(ww - qq); e2 = exp(k - qq)
    struct ggml_tensor * next_w_state = rwkv_exp(ctx, ggml_sub(ctx, decayed_p, next_p));
    struct ggml_tensor * next_w_token = rwkv_exp(ctx, ggml_sub(ctx, k, next_p));

    // state[5*i+2] = e1 * aa + e2 * v; state[5*i+3] = e1 * bb + e2; state[5*i+4] = qq
    aa = ggml_add_inplace(ctx, ggml_mul(ctx, next_w_state, aa), ggml_mul(ctx, next_w_token, v));
    bb = ggml_add_inplace(ctx, ggml_mul(ctx, next_w_state, bb), next_w_token);
    pp = next_p;

    // wkv = a / b
    return ggml_div(ctx, numerator, denominator);
}
// Time-mixing block of one layer. It:
//   - normalizes x with ln1 and token-shifts it against state.att_xx;
//   - computes r/k/v and the WKV term, updating state.att_aa/att_bb/att_pp;
//   - returns att_output @ (r * wkv).
struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
    struct ggml_tensor * x_prev;
    rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

    struct ggml_tensor * r;
    struct ggml_tensor * k;
    struct ggml_tensor * v;
    rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v);

    struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp);

    // ow @ (r * wkv)
    struct ggml_tensor * gated = ggml_mul(ctx, r, wkv);
    return ggml_mul_mat(ctx, layer.att_output, gated);
}
// Channel-mixing (feed-forward) block of one layer. It normalizes x with ln2,
// token-shifts it against state.ffn_xx, and returns
// sigmoid(rw @ xr) * (vw @ relu(kw @ xk)^2).
struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
    struct ggml_tensor * x_prev;
    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx);

    // x * time_mix + x_prev * (1 - time_mix); captures the normalized x.
    auto lerp_with_prev = [ctx, x, x_prev](struct ggml_tensor * time_mix) {
        return ggml_add_inplace(
            ctx,
            ggml_mul(ctx, x, time_mix),
            ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, time_mix))
        );
    };

    // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k)
    struct ggml_tensor * xk = lerp_with_prev(layer.ffn_time_mix_k);
    // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r)
    struct ggml_tensor * xr = lerp_with_prev(layer.ffn_time_mix_r);

    // r = torch.sigmoid(rw @ xr)
    struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr));

    // k = torch.square(torch.relu(kw @ xk))
    struct ggml_tensor * key_proj = ggml_mul_mat(ctx, layer.ffn_key, xk);
    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, key_proj));

    // r * (vw @ k)
    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
}
void rwkv_create_input_and_output_views(
struct rwkv_layer_state * inputs,
struct rwkv_layer_state * outputs,
struct ggml_tensor * input,
struct ggml_tensor * output,
struct ggml_context * ctx,
size_t n_layer,
size_t n_embed
) {
for (size_t i = 0; i < n_layer; i++) {
struct rwkv_layer_state & input_state = inputs[i];
input_state.ffn_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 0) * sizeof(float));
input_state.att_xx = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 1) * sizeof(float));
input_state.att_aa = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 2) * sizeof(float));
input_state.att_bb = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 3) * sizeof(float));
input_state.att_pp = ggml_view_1d(ctx, input, n_embed, n_embed * (i * 5 + 4) * sizeof(float));
struct rwkv_layer_state & output_state = outputs[i];
output_state.ffn_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 0) * sizeof(float));
output_state.att_xx = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 1) * sizeof(float));
output_state.att_aa = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 2) * sizeof(float));
output_state.att_bb = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 3) * sizeof(float));
output_state.att_pp = ggml_view_1d(ctx, output, n_embed, n_embed * (i * 5 + 4) * sizeof(float));
}
}
// Creates and sets the input and output ggml tensors, builds the computation graph.
bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) {
graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
struct rwkv_file_header & header = model.header;
const size_t n_vocab = header.n_vocab;
const size_t n_embed = header.n_embed;
const size_t n_layer = header.n_layer;
struct ggml_context * ctx = graph.ggml_ctx;
// Creates a 1-element tensor.
graph.tokens = ggml_new_i32(ctx, 0);
struct ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer);
struct ggml_tensor * output = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * 5 * n_layer);
// We collect parts of input state here. Each part is (n_embed) vector.
std::unique_ptr<struct rwkv_layer_state[]> inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts");
// We collect parts of output state here. Each part is (n_embed) vector.
std::unique_ptr<struct rwkv_layer_state[]> outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]);
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts");
rwkv_create_input_and_output_views(inputs.get(), outputs.get(), input, output, ctx, n_layer, n_embed);
graph.logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
// x = self.w.emb.weight[token]
struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, graph.tokens);
// x = self.layer_norm(x, self.w.blocks[0].ln0)
x = rwkv_layer_norm(ctx, x, model.ln0_weight, model.ln0_bias);