diff --git a/ggml b/ggml
index f52d2a0..a1d0ea7 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit f52d2a05cf8327baf6c0d49e7b231953179e03d3
+Subproject commit a1d0ea7c2abd44f56822ffdfcfe0a0fcf7170885
diff --git a/rwkv.cpp b/rwkv.cpp
index 0ee6518..7664698 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -287,6 +287,8 @@ struct rwkv_tensor_header {
     uint32_t data_type;
     uint32_t width;
     uint32_t height;
+
+    const size_t size() const;
 };
 
 struct rwkv_tensor {
@@ -319,22 +321,8 @@ bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & he
     return true;
 }
 
-size_t rwkv_tensor_size(enum ggml_type type, const int64_t width, const int64_t height = 1) {
-    struct ggml_tensor decoy {};
-    decoy.type = type;
-    decoy.ne[0] = width;
-    decoy.ne[1] = height;
-    decoy.ne[2] = 1;
-    decoy.ne[3] = 1;
-    return ggml_nbytes(&decoy);
-}
-
-size_t rwkv_tensor_size(const struct rwkv_tensor_header & header) {
-    return rwkv_tensor_size(rwkv_type_to_ggml[header.data_type], header.width, header.height);
-}
-
 bool rwkv_fskip_tensor_data(FILE * file, const struct rwkv_tensor_header & header) {
-    return fseek(file, header.key_length + rwkv_tensor_size(header), SEEK_CUR) == 0;
+    return fseek(file, header.key_length + header.size(), SEEK_CUR) == 0;
 }
 
 bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header & header) {
@@ -344,7 +332,7 @@ bool rwkv_fread_tensor_header_and_skip(FILE * file, struct rwkv_tensor_header &
 }
 
 bool rwkv_fread_tensor_data(FILE * file, struct rwkv_tensor & output, void * buffer = NULL) {
-    size_t data_size = rwkv_tensor_size(output.header);
+    size_t data_size = output.header.size();
 
     RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, output.header.key_length, output.name));
 
     if (buffer) {
@@ -389,7 +377,7 @@ bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string
 bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) {
     RWKV_ENSURE_OR_FALSE(rwkv_fwrite_tensor_header(file, tensor.header));
     RWKV_ENSURE_OR_FALSE(rwkv_fwrite_string(file, tensor.name));
-    RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, rwkv_tensor_size(tensor.header)));
+    RWKV_ENSURE_OR_FALSE(rwkv_fwrite_data(file, tensor.data, tensor.header.size()));
     return true;
 }
 
@@ -482,40 +470,165 @@ struct ggml_tensor * rwkv_max(ggml_context * ctx, struct ggml_tensor * x, struct
 struct ggml_tensor * rwkv_layer_norm(ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * weight, struct ggml_tensor * bias) {
     // LayerNorm in RWKV is `x = (x - mean(x)) / sqrt(variance(x) + 1e-5) * weight + bias`
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.
-    return ggml_add_inplace(ctx, ggml_mul(ctx, ggml_norm(ctx, x), weight), bias);
+    return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x), weight), bias);
 }
 
 // --- Implementation ---
 
+// Used as a helper during rwkv_ctx_size calculation.
+struct rwkv_future_tensor;
+
 // Used to calculate the memory usage of ggml contexts before allocating them.
-// Since ggml uses an internal bump allocator that can't be grown at runtime,
-// we need to ensure we have enough space; while at the same time not using more memory than necessary.
+// Since ggml uses an internal bump allocator that can't be grown at runtime, we need to ensure we have enough space,
+// while at the same time not using more memory than necessary.
+struct rwkv_future_ctx { size_t objects_count = 0; - size_t objects_size = 0; + size_t memory_size = 0; size_t scratch_size = 0; + + // Align to GGML_MEM_ALIGN, which can currently be up to 16 + static const size_t align(const size_t size) { + return ((size + 15) & ~15); + } + + void add_objects(const size_t size, const size_t count = 1) { + this->objects_count += count; + + if (size && count) { + this->add_memory(size, count); + } + } + + void add_memory(const size_t size, const size_t count = 1) { + this->memory_size += this->align(size) * count; + } + + void add_scratch(const size_t size, const size_t count = 1) { + this->scratch_size += this->align(size) * count; + } + + void add_data(const bool use_scratch, const size_t size, const size_t count = 1) { + if (use_scratch) { + this->add_scratch(size, count); + } else { + this->add_memory(size, count); + } + } + + struct rwkv_future_tensor declare(const enum ggml_type type, const uint64_t width, const uint64_t height = 1); + + struct rwkv_future_tensor alloc(const enum ggml_type type, const uint64_t width, const uint64_t height = 1, const bool use_scratch = true); }; +struct rwkv_future_tensor { + enum ggml_type type = GGML_TYPE_COUNT; + uint64_t width = 0; + uint64_t height = 0; + + static const size_t size(const enum ggml_type type, const uint64_t width, const uint64_t height) { + struct ggml_tensor decoy {}; + decoy.type = type; + decoy.ne[0] = width; + decoy.ne[1] = height; + decoy.ne[2] = 1; + decoy.ne[3] = 1; + return ggml_nbytes(&decoy); + } + + rwkv_future_tensor() {} + rwkv_future_tensor(const enum ggml_type type, const uint64_t width, const uint64_t height = 1): type(type), width(width), height(height) {} + rwkv_future_tensor(const struct ggml_tensor * ref): type(ref->type), width(ref->ne[0]), height(ref->ne[1]) {} + + struct rwkv_future_tensor alloc(struct rwkv_future_ctx & ctx, const bool use_scratch = true) const { + ctx.add_objects(sizeof(struct ggml_tensor)); + ctx.add_data(use_scratch, rwkv_future_tensor::size(type, width, height)); + return *this; + } + + struct rwkv_future_tensor view(struct rwkv_future_ctx & ctx) const { + ctx.add_objects(sizeof(struct ggml_tensor)); + return *this; + } + + struct rwkv_future_tensor subview(struct rwkv_future_ctx & ctx, const uint32_t width, const uint32_t height = 1) const { + ctx.add_objects(sizeof(struct ggml_tensor), 2); + ctx.add_memory(sizeof(uint32_t) * 2); + return rwkv_future_tensor(type, width, height); + } + + struct rwkv_future_tensor dup(struct rwkv_future_ctx & ctx) const { + return this->alloc(ctx); + } + + struct rwkv_future_tensor layer_norm(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & weight, const struct rwkv_future_tensor & bias) const { + return this->dup(ctx).view(ctx).view(ctx); + } + + struct rwkv_future_tensor repeat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor reference) const { + return reference.dup(ctx); + } + + struct rwkv_future_tensor set_inplace(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor src) { + ctx.add_objects(sizeof(struct ggml_tensor)); + ctx.add_memory(sizeof(uint32_t) * 5); + return this->view(ctx); + } + + struct rwkv_future_tensor consume(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) { + return this->view(ctx); + } + + struct rwkv_future_tensor combine(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return this->dup(ctx); + } + + struct rwkv_future_tensor fn(struct rwkv_future_ctx & ctx) const { + ctx.add_objects(sizeof(struct 
ggml_tensor)); + ctx.add_memory(sizeof(void *) / sizeof(uint32_t)); + return this->dup(ctx); + } + + struct rwkv_future_tensor mul_mat(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return ctx.alloc(GGML_TYPE_F32, this->height, other.height); + } + + struct rwkv_future_tensor get_rows(struct rwkv_future_ctx & ctx, const struct rwkv_future_tensor & other) const { + return ctx.alloc(GGML_TYPE_F32, this->width, other.width); + } +}; + +const size_t rwkv_tensor_header::size() const { + return rwkv_future_tensor::size(rwkv_type_to_ggml[this->data_type], this->width, this->height); +} + +struct rwkv_future_tensor rwkv_future_ctx::declare(const enum ggml_type type, const uint64_t width, const uint64_t height) { + return rwkv_future_tensor(type, width, height); +} + +struct rwkv_future_tensor rwkv_future_ctx::alloc(const enum ggml_type type, const uint64_t width, const uint64_t height, const bool use_scratch) { + return this->declare(type, width, height).alloc(*this, use_scratch); +} + struct rwkv_ggml_context { std::unique_ptr scratch; struct ggml_context * ctx; rwkv_ggml_context(): ctx(NULL) {} - rwkv_ggml_context(struct rwkv_ctx_size size): ctx(NULL) { - scratch.reset(new(std::nothrow) uint8_t[size.scratch_size]); + rwkv_ggml_context(const struct rwkv_future_ctx future_ctx): ctx(NULL) { + scratch.reset(new(std::nothrow) uint8_t[future_ctx.scratch_size]); if (!scratch) { return; } - ctx = ggml_init({ size.objects_count * GGML_OBJECT_SIZE + size.objects_size, NULL, false}); + ctx = ggml_init({ future_ctx.objects_count * GGML_OBJECT_SIZE + future_ctx.memory_size, NULL, false}); if (!ctx) { return; } - ggml_set_scratch(ctx, { 0, size.scratch_size, scratch.get() }); + ggml_set_scratch(ctx, { 0, future_ctx.scratch_size, scratch.get() }); } struct rwkv_ggml_context & operator=(struct rwkv_ggml_context && source) { @@ -598,7 +711,7 @@ struct rwkv_context { enum rwkv_error_flags last_error; bool print_errors; - uint32_t gpu_layers; + size_t gpu_layers; }; // https://stackoverflow.com/a/6458689 @@ -647,136 +760,108 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { return true; } -void rwkv_ctx_size_add_objects(struct rwkv_ctx_size & ctx_size, size_t objects, size_t object_size = sizeof(struct ggml_tensor)) { - ctx_size.objects_count += objects; - ctx_size.objects_size += ((object_size + 15) & ~15) * objects; -} - -void rwkv_ctx_size_add_scratch(struct rwkv_ctx_size & ctx_size, size_t length, size_t count = 1) { - ctx_size.scratch_size += ((length + 15) & ~15) * count; -} - -void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t objects, size_t scratch = 0, size_t scratches = 1) { - rwkv_ctx_size_add_objects(ctx_size, objects); - rwkv_ctx_size_add_scratch(ctx_size, scratch, scratches); -} - -void rwkv_ctx_size_add(struct rwkv_ctx_size & ctx_size, size_t count, const struct rwkv_ctx_size & other) { - ctx_size.objects_count += other.objects_count * count; - ctx_size.objects_size += other.objects_size * count; - ctx_size.scratch_size += other.scratch_size * count; -} - -void rwkv_ctx_size_add_tensor( - struct rwkv_ctx_size & ctx_size, - const uint64_t tensors, - const uint64_t views, - const enum ggml_type type, - const uint64_t width, - const uint64_t height = 1 +void rwkv_future_carry_x(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor weight, + const struct rwkv_future_tensor bias, + struct rwkv_future_tensor & x, + struct rwkv_future_tensor & x_prev, + struct rwkv_future_tensor & carry ) { - rwkv_ctx_size_add_objects(ctx_size, tensors 
+ views); - rwkv_ctx_size_add_scratch(ctx_size, rwkv_tensor_size(type, width, height), tensors); -} - -void rwkv_ctx_size_add_tensor(struct rwkv_ctx_size & size, const uint64_t tensors, const uint64_t views, const struct rwkv_tensor_header & header) { - rwkv_ctx_size_add_tensor(size, tensors, views, rwkv_type_to_ggml[header.data_type], header.width, header.height); -} - -struct rwkv_ctx_size rwkv_xx_size(const size_t n_embed = 0, const size_t sequence_len = 1) { - struct rwkv_ctx_size ctx_size; - - if (sequence_len == 1) { - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); + if (x.height == 1) { + x = x.layer_norm(ctx, weight, bias); + x_prev = carry; + carry = x; } else { - /* x0 */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); + x = x.layer_norm(ctx, weight.repeat(ctx, x), bias.repeat(ctx, x)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 1, 2, GGML_TYPE_F32, n_embed, sequence_len); - /* xx */ rwkv_ctx_size_add_objects(ctx_size, 2, sizeof(struct ggml_tensor) + rwkv_tensor_size(GGML_TYPE_I32, 5)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed * sequence_len - 1); + x_prev = x.dup(ctx) + .set_inplace(ctx, carry) + .set_inplace(ctx, x.subview(ctx, x.width, x.height - 1)); - /* xx */ rwkv_ctx_size_add_tensor(ctx_size, 0, 1, GGML_TYPE_F32, n_embed); + carry = x.subview(ctx, x.width); } - - return ctx_size; } -void rwkv_xx(struct ggml_context * ctx, struct ggml_tensor * weight, struct ggml_tensor * bias, struct ggml_tensor *& x, struct ggml_tensor *& xx, struct ggml_tensor *& state) { - size_t n_embed = x->ne[0]; - size_t sequence_len = x->ne[1]; +void rwkv_carry_x(struct ggml_context * ctx, + struct ggml_tensor * weight, + struct ggml_tensor * bias, + struct ggml_tensor *& x, + struct ggml_tensor *& x_prev, + struct ggml_tensor *& carry +) { + const size_t n_embed = x->ne[0]; + const size_t sequence_len = x->ne[1]; if (sequence_len == 1) { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, weight, bias); // xx = state[5*i+0] - xx = state; + x_prev = carry; // state[5*i+0] = x - state = x; + carry = x; } else { // self.layer_norm(x, self.w.blocks[i].ln2) x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, weight, x), ggml_repeat(ctx, bias, x)); // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) - xx = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); - xx = ggml_set_1d_inplace(ctx, xx, state, 0); - xx = ggml_set_1d_inplace(ctx, xx, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); + x_prev = ggml_set_1d_inplace(ctx, x_prev, carry, 0); + x_prev = ggml_set_1d_inplace(ctx, x_prev, ggml_view_1d(ctx, x, n_embed * (sequence_len - 1), 0), n_embed * sizeof(float)); // state[5*i+0] = x[-1,:] - state = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); + carry = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_len - 1) * sizeof(float)); } } -struct rwkv_ctx_size rwkv_att_rkv_size(const size_t n_embed = 0, const size_t sequence_len = 1) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, 
GGML_TYPE_F32, n_embed, sequence_len); - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* v */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); +void rwkv_future_att_rkv(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_v, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor x, + const struct rwkv_future_tensor x_prev, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + struct rwkv_future_tensor & r, + struct rwkv_future_tensor & k, + struct rwkv_future_tensor & v +) { + const struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + const struct rwkv_future_tensor xv = x.combine(ctx, time_mix_v).consume(ctx, x_prev.combine(ctx, time_mix_v.fn(ctx))); + const struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - return ctx_size; + r = att_r.mul_mat(ctx, xr).fn(ctx); + k = att_k.mul_mat(ctx, xk); + v = att_v.mul_mat(ctx, xv); } void rwkv_att_rkv( struct ggml_context * ctx, struct rwkv_layer layer, - struct ggml_tensor * x0, - struct ggml_tensor * xx, + struct ggml_tensor * x, + struct ggml_tensor * x_prev, struct ggml_tensor *& r, struct ggml_tensor *& k, struct ggml_tensor *& v ) { // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) + ggml_mul(ctx, x, layer.att_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_k)) ); // xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) struct ggml_tensor * xv = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_v), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) + ggml_mul(ctx, x, layer.att_time_mix_v), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_v)) ); // xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace(ctx, - ggml_mul(ctx, x0, layer.att_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) + ggml_mul(ctx, x, layer.att_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.att_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) @@ -787,36 +872,35 @@ void rwkv_att_rkv( v = ggml_mul_mat(ctx, layer.att_value, xv); } -struct rwkv_ctx_size rwkv_att_wkv_size(const size_t n_embed = 0) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e1 */ 
rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* a */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* b */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); - - /* ww */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* qq */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e1 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - /* e2 */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* aa */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* bb */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_embed); - /* pp */ rwkv_ctx_size_add_tensor(ctx_size, 0, 0, GGML_TYPE_F32, n_embed); - - /* wkv */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); +struct rwkv_future_tensor rwkv_future_att_wkv(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor time_first, + const struct rwkv_future_tensor time_decay, + struct rwkv_future_tensor & aa, + struct rwkv_future_tensor & bb, + struct rwkv_future_tensor & pp, + const struct rwkv_future_tensor k, + const struct rwkv_future_tensor v +) { + struct rwkv_future_tensor ww = time_first.combine(ctx, k); + struct rwkv_future_tensor qq = pp.fn(ctx); + struct rwkv_future_tensor e1 = pp.combine(ctx, qq).fn(ctx); + struct rwkv_future_tensor e2 = ww.combine(ctx, qq).fn(ctx); + + struct rwkv_future_tensor a = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); + struct rwkv_future_tensor b = e1.combine(ctx, bb).consume(ctx, e2); + + ww = pp.combine(ctx, time_decay); + qq = ww.fn(ctx); + e1 = ww.combine(ctx, qq).fn(ctx); + e2 = k.combine(ctx, qq).fn(ctx); + + // aa, bb + aa = e1.combine(ctx, aa).consume(ctx, e2.combine(ctx, v)); + bb = e1.combine(ctx, bb).consume(ctx, e2); + pp = qq; - return ctx_size; + // wkv + return a.combine(ctx, b); } struct ggml_tensor * rwkv_att_wkv( @@ -863,22 +947,42 @@ struct ggml_tensor * rwkv_att_wkv( return ggml_div(ctx, a, b); } -struct rwkv_ctx_size rwkv_att_size(const size_t n_embed = 0) { - struct rwkv_ctx_size ctx_size; - /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed)); - /* rkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_rkv_size(n_embed)); - /* wkv */ rwkv_ctx_size_add(ctx_size, 1, rwkv_att_wkv_size(n_embed)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed); - return ctx_size; +struct rwkv_future_tensor rwkv_future_att(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_v, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor time_first, + const struct rwkv_future_tensor time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor x, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & 
att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp +) { + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x, x_prev, att_xx); + + struct rwkv_future_tensor r, k, v; + rwkv_future_att_rkv(ctx, time_mix_k, time_mix_v, time_mix_r, x, x_prev, att_r, att_k, att_v, r, k, v); + + struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, time_first, time_decay, att_aa, att_bb, att_pp, k, v); + + return att_output.mul_mat(ctx, r.combine(ctx, wkv)); } struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x, x_prev, r, k, v); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, k, v, state.att_aa, state.att_bb, state.att_pp); @@ -886,74 +990,133 @@ struct ggml_tensor * rwkv_att(struct ggml_context * ctx, struct ggml_tensor * x, return ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, wkv)); } -struct rwkv_ctx_size rwkv_ffn_size(const size_t n_embed = 0, const size_t ffn_key = 0, const size_t sequence_len = 1) { - size_t ptr_nelem = sizeof(void *) / sizeof(uint32_t); - - struct rwkv_ctx_size ctx_size; - /* xx */ rwkv_ctx_size_add(ctx_size, 1, rwkv_xx_size(n_embed, sequence_len)); - - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xk */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed, sequence_len); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* xr */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); +struct rwkv_future_tensor rwkv_future_ffn(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor time_mix_k, + const struct rwkv_future_tensor time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor x, + struct rwkv_future_tensor & ffn_xx +) { + struct rwkv_future_tensor x_prev; + rwkv_future_carry_x(ctx, ln2_weight, ln2_bias, x, x_prev, ffn_xx); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* r */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, ptr_nelem); - /* k */ rwkv_ctx_size_add_tensor(ctx_size, 3, 0, GGML_TYPE_F32, ffn_key, sequence_len); + struct rwkv_future_tensor xk = x.combine(ctx, time_mix_k).consume(ctx, x_prev.combine(ctx, time_mix_k.fn(ctx))); + struct rwkv_future_tensor xr = x.combine(ctx, time_mix_r).consume(ctx, x_prev.combine(ctx, time_mix_r.fn(ctx))); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 0, GGML_TYPE_F32, n_embed, sequence_len); + struct rwkv_future_tensor r = ffn_r.mul_mat(ctx, xr).fn(ctx); + struct rwkv_future_tensor k = ffn_k.mul_mat(ctx, xk).view(ctx).view(ctx); - return ctx_size; + return r.consume(ctx, ffn_v.mul_mat(ctx, k)); } struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, 
struct rwkv_layer_state & state) { - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln2_weight, layer.ln2_bias, x0, xx, state.ffn_xx); + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); // xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) // xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) struct ggml_tensor * xk = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_k), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) + ggml_mul(ctx, x, layer.ffn_time_mix_k), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_k)) ); // xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) struct ggml_tensor * xr = ggml_add_inplace( ctx, - ggml_mul(ctx, x0, layer.ffn_time_mix_r), - ggml_mul(ctx, xx, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) + ggml_mul(ctx, x, layer.ffn_time_mix_r), + ggml_mul(ctx, x_prev, rwkv_1_minus_x(ctx, layer.ffn_time_mix_r)) ); // r = torch.sigmoid(rw @ xr) struct ggml_tensor * r = rwkv_sigmoid(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); // k = torch.square(torch.relu(kw @ xk)) - struct ggml_tensor * k = ggml_sqr(ctx, ggml_relu(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); // r * (vw @ k) - return ggml_mul(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } -struct rwkv_ctx_size rwkv_serial_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const size_t ffn_key_size) { - struct rwkv_ctx_size ctx_size; - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); +struct rwkv_future_tensor rwkv_future_graph_work(struct rwkv_future_ctx & ctx, + const enum ggml_type type, + const size_t ffn_key_height, + const size_t n_threads, + const size_t sequence_len = 1 +) { +#ifdef GGML_USE_CUBLAS + enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; +#else + enum ggml_type mul_mat_type = ggml_is_quantized(type) ? 
GGML_TYPE_Q8_1 : type; +#endif + return ctx.alloc(GGML_TYPE_I8, rwkv_future_tensor::size(mul_mat_type, ffn_key_height, sequence_len) * n_threads + 64 * (n_threads - 1)); +} + +struct rwkv_future_tensor rwkv_future_serial_graph(struct rwkv_future_ctx & ctx, + const struct rwkv_future_tensor tokens, + const size_t n_threads, + + const struct rwkv_future_tensor emb, + const struct rwkv_future_tensor ln0_weight, + const struct rwkv_future_tensor ln0_bias, + + const size_t n_layer, + + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor att_time_mix_k, + const struct rwkv_future_tensor att_time_mix_v, + const struct rwkv_future_tensor att_time_mix_r, + const struct rwkv_future_tensor att_time_first, + const struct rwkv_future_tensor att_time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp, + + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor ffn_time_mix_k, + const struct rwkv_future_tensor ffn_time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor & ffn_xx, + + const struct rwkv_future_tensor ln_out_weight, + const struct rwkv_future_tensor ln_out_bias, + const struct rwkv_future_tensor head +) { + struct rwkv_future_tensor x = emb.get_rows(ctx, tokens).layer_norm(ctx, ln0_weight, ln0_bias); - /* att */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_size(n_embed)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); - /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed); + for (size_t i = 0; i < n_layer; i++) { + x = x.consume(ctx, rwkv_future_att(ctx, + ln1_weight, ln1_bias, att_time_mix_k, att_time_mix_v, att_time_mix_r, att_time_first, att_time_decay, + att_r, att_k, att_v, att_output, x, att_xx, att_aa, att_bb, att_pp)); + + x = x.consume(ctx, rwkv_future_ffn(ctx, + ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); + + ffn_xx.view(ctx); + att_xx.view(ctx); + att_aa.view(ctx); + att_bb.view(ctx); + att_pp.view(ctx); + } - /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); + x = x.layer_norm(ctx, ln_out_weight, ln_out_bias); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 1, GGML_TYPE_F32, n_embed); - /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); + rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, tokens.width); - return ctx_size; + return head.mul_mat(ctx, x).view(ctx); } bool rwkv_build_serial_graph( @@ -995,31 +1158,78 @@ bool rwkv_build_serial_graph( return true; } -struct rwkv_ctx_size rwkv_sequence_graph_size(const size_t n_vocab, const size_t n_embed, const size_t n_layer, const size_t ffn_key_size, const size_t sequence_len) { - struct rwkv_ctx_size ctx_size; - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_F32, n_embed, sequence_len); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 4, 1, GGML_TYPE_F32, n_embed, sequence_len); +struct rwkv_future_tensor rwkv_future_sequence_graph(struct 
rwkv_future_ctx & ctx, + const struct rwkv_future_tensor tokens, + const size_t n_threads, + + const struct rwkv_future_tensor emb, + const struct rwkv_future_tensor ln0_weight, + const struct rwkv_future_tensor ln0_bias, + + const size_t n_layer, + + const struct rwkv_future_tensor ln1_weight, + const struct rwkv_future_tensor ln1_bias, + const struct rwkv_future_tensor att_time_mix_k, + const struct rwkv_future_tensor att_time_mix_v, + const struct rwkv_future_tensor att_time_mix_r, + const struct rwkv_future_tensor att_time_first, + const struct rwkv_future_tensor att_time_decay, + const struct rwkv_future_tensor att_r, + const struct rwkv_future_tensor att_k, + const struct rwkv_future_tensor att_v, + const struct rwkv_future_tensor att_output, + struct rwkv_future_tensor & att_xx, + struct rwkv_future_tensor & att_aa, + struct rwkv_future_tensor & att_bb, + struct rwkv_future_tensor & att_pp, + + const struct rwkv_future_tensor ln2_weight, + const struct rwkv_future_tensor ln2_bias, + const struct rwkv_future_tensor ffn_time_mix_k, + const struct rwkv_future_tensor ffn_time_mix_r, + const struct rwkv_future_tensor ffn_k, + const struct rwkv_future_tensor ffn_v, + const struct rwkv_future_tensor ffn_r, + struct rwkv_future_tensor & ffn_xx, + + const struct rwkv_future_tensor ln_out_weight, + const struct rwkv_future_tensor ln_out_bias, + const struct rwkv_future_tensor head +) { + struct rwkv_future_tensor x = emb.get_rows(ctx, tokens); + x = x.layer_norm(ctx, ln0_weight.repeat(ctx, x), ln0_bias.repeat(ctx, x)); - /* xx */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_xx_size(n_embed, sequence_len)); - /* rkv */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_att_rkv_size(n_embed, sequence_len)); + for (size_t i = 0; i < n_layer; i++) { + struct rwkv_future_tensor x0 = x, x_prev; + rwkv_future_carry_x(ctx, ln1_weight, ln1_bias, x0, x_prev, att_xx); + + struct rwkv_future_tensor r, k, v; + rwkv_future_att_rkv(ctx, att_time_mix_k, att_time_mix_v, att_time_mix_r, x0, x_prev, att_r, att_k, att_v, r, k, v); + + for (size_t i = 0; i < tokens.width; i++) { + struct rwkv_future_tensor kt = k.subview(ctx, emb.width); + struct rwkv_future_tensor vt = v.subview(ctx, emb.width); + struct rwkv_future_tensor xt = x_prev.subview(ctx, emb.width); + struct rwkv_future_tensor wkv = rwkv_future_att_wkv(ctx, att_time_first, att_time_decay, att_aa, att_bb, att_pp, k, v); + wkv.view(ctx); + } - /* kt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* vt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* wkv */ rwkv_ctx_size_add(ctx_size, n_layer * sequence_len, rwkv_att_wkv_size(n_embed)); - /* xt */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * sequence_len, GGML_TYPE_F32, n_embed); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, n_layer * 2, 0, GGML_TYPE_F32, n_embed, sequence_len); + x = x.consume(ctx, att_output.mul_mat(ctx, r.combine(ctx, x_prev))); + x = x.consume(ctx, rwkv_future_ffn(ctx, ln2_weight, ln2_bias, ffn_time_mix_k, ffn_time_mix_r, ffn_k, ffn_v, ffn_r, x, ffn_xx)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); - /* ffn */ rwkv_ctx_size_add(ctx_size, n_layer, rwkv_ffn_size(n_embed, ffn_key_size, sequence_len)); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer, GGML_TYPE_F32, n_embed, sequence_len); + ffn_xx.view(ctx); + att_xx.view(ctx); + att_aa.view(ctx); + 
att_bb.view(ctx); + att_pp.view(ctx); + } - /* output */ rwkv_ctx_size_add_tensor(ctx_size, 0, n_layer * 5, GGML_TYPE_F32, n_embed); + x = x.subview(ctx, emb.width).layer_norm(ctx, ln_out_weight, ln_out_bias); - /* x */ rwkv_ctx_size_add_tensor(ctx_size, 2, 2, GGML_TYPE_F32, n_embed); - /* logits */ rwkv_ctx_size_add_tensor(ctx_size, 1, 1, GGML_TYPE_F32, n_vocab); + rwkv_future_graph_work(ctx, ffn_k.type, ffn_k.height, n_threads, tokens.width); - return ctx_size; + return head.mul_mat(ctx, x).view(ctx); } bool rwkv_build_sequence_graph( @@ -1041,23 +1251,23 @@ bool rwkv_build_sequence_graph( struct rwkv_layer & layer = model.layers[i]; struct rwkv_layer_state state = inputs[i]; - struct ggml_tensor * x0 = x, * xx; - rwkv_xx(ctx, layer.ln1_weight, layer.ln1_bias, x0, xx, state.att_xx); + struct ggml_tensor * x0 = x, * x_prev; + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x0, x_prev, state.att_xx); struct ggml_tensor * r, * k, * v; - rwkv_att_rkv(ctx, layer, x0, xx, r, k, v); + rwkv_att_rkv(ctx, layer, x0, x_prev, r, k, v); ggml_build_forward_expand(cgraph, r); for (uint32_t t = 0; t < sequence_len; t++) { struct ggml_tensor * kt = ggml_view_1d(ctx, k, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * vt = ggml_view_1d(ctx, v, n_embed, n_embed * sizeof(float) * t); - struct ggml_tensor * xt = ggml_view_1d(ctx, xx, n_embed, n_embed * sizeof(float) * t); + struct ggml_tensor * xt = ggml_view_1d(ctx, x_prev, n_embed, n_embed * sizeof(float) * t); struct ggml_tensor * wkv = rwkv_att_wkv(ctx, layer.att_time_first, layer.att_time_decay, kt, vt, state.att_aa, state.att_bb, state.att_pp); ggml_build_forward_expand(cgraph, ggml_cpy(ctx, wkv, xt)); } - x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, xx))); + x = ggml_add_inplace(ctx, x, ggml_mul_mat(ctx, layer.att_output, ggml_mul(ctx, r, x_prev))); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); struct rwkv_layer_state & output = outputs[i]; @@ -1077,15 +1287,6 @@ bool rwkv_build_sequence_graph( return true; } -size_t rwkv_estimate_graph_work(const enum ggml_type type, const size_t ffn_key_size, const uint32_t n_threads, const size_t sequence_len = 1) { -#ifdef GGML_USE_CUBLAS - enum ggml_type mul_mat_type = GGML_TYPE_F16; -#else - enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type; -#endif - return rwkv_tensor_size(GGML_TYPE_I8, rwkv_tensor_size(mul_mat_type, ffn_key_size, sequence_len) * n_threads + 64 * (n_threads - 1)); -} - void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors) { bool * ptr = ctx ? 
&ctx->print_errors : &global_print_errors; *ptr = print_errors; @@ -1132,14 +1333,14 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst struct rwkv_tensor_header tensor_header; std::string name; - struct rwkv_ctx_size ctx_size; + struct rwkv_future_ctx future_ctx; while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_tensor_header(file.file, tensor_header), "Invalid tensor header"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, tensor_header.size(), SEEK_CUR) == 0, "Failed to read tensor data"); - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); + future_ctx.alloc(rwkv_type_to_ggml[tensor_header.data_type], tensor_header.width, tensor_header.height); if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { ffn_key_size = tensor_header.height; @@ -1149,7 +1350,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, ffn_key_size, "Model is missing parameter blocks.0.ffn.key.weight"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, sizeof(struct rwkv_file_header), SEEK_SET) == 0, "Failed to seek in file"); - ctx = ctx_size; + ctx = future_ctx; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, ctx.ctx, "Failed to allocate model context"); struct ggml_tensor * tensor; @@ -1188,14 +1389,20 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptrffn_key_size)); - /* work */ rwkv_ctx_size_add(graph_ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], instance->ffn_key_size, n_threads)); + struct rwkv_future_ctx graph_future_ctx; + const struct rwkv_future_tensor future_token = graph_future_ctx.alloc(GGML_TYPE_I32, 1, 1, false); + + const struct rwkv_model & model = instance->model; + const struct rwkv_layer & layer = model.layers[0]; + const struct rwkv_layer_state & state = inputs[0]; + struct rwkv_future_tensor ffn_xx = state.ffn_xx; + struct rwkv_future_tensor att_xx = state.att_xx; + struct rwkv_future_tensor att_aa = state.att_aa; + struct rwkv_future_tensor att_bb = state.att_bb; + struct rwkv_future_tensor att_pp = state.att_pp; + + const struct rwkv_future_tensor future_graph = rwkv_future_serial_graph(graph_future_ctx, future_token, n_threads, + model.emb, + model.ln0_weight, model.ln0_bias, + + n_layer, + layer.ln1_weight, layer.ln1_bias, + layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, + layer.att_time_first, layer.att_time_decay, + layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, + att_xx, att_aa, att_bb, att_pp, + + layer.ln2_weight, layer.ln2_bias, + layer.ffn_time_mix_k, layer.ffn_time_mix_r, + layer.ffn_key, layer.ffn_value, layer.ffn_receptance, + ffn_xx, + + model.ln_out_weight, model.ln_out_weight, + model.head + ); struct rwkv_graph serial_graph; - serial_graph.ctx = graph_ctx_size; + serial_graph.ctx = graph_future_ctx; RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, serial_graph.ctx.ctx, "Failed to allocate serial graph context"); serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0); 
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); @@ -1278,25 +1512,31 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32 bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) { #ifdef GGML_USE_CUBLAS - uint32_t layers_to_offload = std::min(n_layers, ctx->instance->model.header.n_layer - ctx->gpu_layers); - - for (uint32_t i = 0; i < layers_to_offload; i++) { - const struct rwkv_layer & layer = ctx->instance->model.layers[ctx->gpu_layers + i]; - - // Use cuBLAS only for heavy matrices; other operations are not supported for the GPU at the moment - ggml_cuda_transform_tensor(layer.att_key); - ggml_cuda_transform_tensor(layer.att_value); - ggml_cuda_transform_tensor(layer.att_receptance); - ggml_cuda_transform_tensor(layer.att_output); + const auto offload = [&](struct ggml_tensor * tensor) { + // TODO support multi-GPU + tensor->backend = GGML_BACKEND_GPU; + ggml_cuda_transform_tensor(tensor->data, tensor); + }; + + const size_t n_gpu = std::min(n_layers, ctx->instance->model.header.n_layer); + + if (ctx->gpu_layers < n_gpu) { + for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) { + const struct rwkv_layer & layer = ctx->instance->model.layers[i]; + + // TODO also offload other operations to GPU with ggml_cuda_assign_buffers + offload(layer.att_key); + offload(layer.att_value); + offload(layer.att_receptance); + offload(layer.att_output); + + offload(layer.ffn_key); + offload(layer.ffn_value); + offload(layer.ffn_receptance); + } - ggml_cuda_transform_tensor(layer.ffn_key); - ggml_cuda_transform_tensor(layer.ffn_value); - ggml_cuda_transform_tensor(layer.ffn_receptance); + return true; } - - ctx->gpu_layers += layers_to_offload; - - return layers_to_offload > 0; #endif return false; } @@ -1351,25 +1591,50 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co if (ctx->sequence_len != sequence_len) { // Build new sequence graph - struct rwkv_ctx_size ctx_size; - /* tokens */ rwkv_ctx_size_add_tensor(ctx_size, 1, 0, GGML_TYPE_I32, sequence_len); - /* graph */ rwkv_ctx_size_add(ctx_size, 1, rwkv_sequence_graph_size(n_vocab, n_embed, n_layer, ctx->instance->ffn_key_size, sequence_len)); - /* work */ rwkv_ctx_size_add(ctx_size, 1, rwkv_estimate_graph_work(rwkv_type_to_ggml[header.data_type], ctx->instance->ffn_key_size, 1, sequence_len)); - - struct rwkv_graph graph; - graph.ctx = ctx_size; - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, graph.ctx.ctx, "Failed to allocate sequence graph context"); - graph.tokens = ggml_new_tensor_1d(graph.ctx.ctx, GGML_TYPE_I32, sequence_len); - graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, graph.cgraph, "Failed to allocate sequence graph"); - graph.cgraph->n_threads = 1; - RWKV_ASSERT_FALSE( - RWKV_ERROR_GRAPH, - rwkv_build_sequence_graph(graph.ctx.ctx, ctx->instance->model, graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, graph.cgraph.get()) + + struct rwkv_future_ctx graph_future_ctx; + const struct rwkv_future_tensor future_tokens = graph_future_ctx.alloc(GGML_TYPE_I32, sequence_len); + + const struct rwkv_model & model = ctx->instance->model; + const struct rwkv_layer & layer = model.layers[0]; + const struct rwkv_layer_state & state = ctx->input_layers[0]; + struct rwkv_future_tensor ffn_xx = state.ffn_xx; + struct rwkv_future_tensor att_xx = state.att_xx; + struct rwkv_future_tensor att_aa = state.att_aa; + struct rwkv_future_tensor att_bb = state.att_bb; + 
struct rwkv_future_tensor att_pp = state.att_pp; + + const struct rwkv_future_tensor future_graph = rwkv_future_sequence_graph(graph_future_ctx, future_tokens, 1, + model.emb, + model.ln0_weight, model.ln0_bias, + + n_layer, + layer.ln1_weight, layer.ln1_bias, + layer.att_time_mix_k, layer.att_time_mix_v, layer.att_time_mix_r, + layer.att_time_first, layer.att_time_decay, + layer.att_receptance, layer.att_key, layer.att_value, layer.att_output, + att_xx, att_aa, att_bb, att_pp, + + layer.ln2_weight, layer.ln2_bias, + layer.ffn_time_mix_k, layer.ffn_time_mix_r, + layer.ffn_key, layer.ffn_value, layer.ffn_receptance, + ffn_xx, + + model.ln_out_weight, model.ln_out_weight, + model.head ); - ((struct rwkv_context *) ctx)->sequence_len = sequence_len; - ((struct rwkv_context *) ctx)->sequence_graph = std::move(graph); + struct rwkv_graph sequence_graph; + sequence_graph.ctx = graph_future_ctx; + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, sequence_graph.ctx.ctx, "Failed to allocate sequence graph context"); + sequence_graph.tokens = ggml_new_tensor_1d(sequence_graph.ctx.ctx, GGML_TYPE_I32, sequence_len); + sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph"); + sequence_graph.cgraph->n_threads = 1; + RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(sequence_graph.ctx.ctx, ctx->instance->model, sequence_graph.tokens, ctx->input_layers.get(), ctx->output_layers.get(), ctx->logits, sequence_graph.cgraph.get())); + + ctx->sequence_len = sequence_len; + ctx->sequence_graph = std::move(sequence_graph); } // Allow building the sequence graph without actually evaluating, by specifying sequence = NULL. @@ -1486,7 +1751,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const struct rwkv_tensor_header header; RWKV_ASSERT_FALSE(RWKV_ERROR_FILE, rwkv_fread_tensor_header_and_skip(in_file.file, header)); - size_t in_size = rwkv_tensor_size(header); + size_t in_size = header.size(); if (in_size > max_in_size) { max_in_size = in_size; @@ -1498,14 +1763,14 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const max_out_size = in_size; } - size_t f32_size = rwkv_tensor_size(GGML_TYPE_F32, header.width, header.height); + size_t f32_size = rwkv_future_tensor::size(GGML_TYPE_F32, header.width, header.height); if (f32_size > max_in_size) { max_in_size = f32_size; } } - size_t out_size = rwkv_tensor_size(out_type, header.width, header.height); + size_t out_size = rwkv_future_tensor::size(out_type, header.width, header.height); if (out_size > max_out_size) { max_out_size = out_size; @@ -1540,7 +1805,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); data = header.data_type == TYPE_FP16 ? out_buf : in_buf; - size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; + size_t orig_size = header.size(), new_size = orig_size; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); // Quantize only 2D tensors, except embedding and head matrices.
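
The buffer sizing in rwkv_quantize_model_file above leans on header.size() and rwkv_future_tensor::size(), which build a stub ggml_tensor and defer to ggml_nbytes(). A minimal standalone sketch of the same arithmetic follows; the block parameters (and the QX format) are made-up placeholders standing in for ggml's real type traits, not actual ggml values.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Illustrative per-type parameters; real values come from ggml's type traits
// (ggml_blck_size / ggml_type_size), not from this sketch.
struct type_info { size_t block_elems; size_t block_bytes; };

constexpr type_info F16 {1, 2};
constexpr type_info F32 {1, 4};
constexpr type_info QX  {32, 18}; // hypothetical 4-bit block format

// Same shape of computation as rwkv_future_tensor::size(): bytes for a
// width x height matrix stored row-major in blocks of block_elems elements.
size_t tensor_bytes(type_info t, size_t width, size_t height) {
    return (width / t.block_elems) * t.block_bytes * height;
}

int main() {
    // Mirrors the buffer-sizing loop in rwkv_quantize_model_file: the read buffer
    // must hold the largest source tensor (or its F32 copy used for dequantization),
    // the write buffer must hold the largest quantized tensor.
    const size_t width = 2560, height = 2560; // made-up example dimensions
    size_t max_in  = std::max(tensor_bytes(F16, width, height),
                              tensor_bytes(F32, width, height));
    size_t max_out = tensor_bytes(QX, width, height);
    std::printf("in buffer: %zu bytes, out buffer: %zu bytes\n", max_in, max_out);
}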
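The rwkv_future_ctx / rwkv_future_tensor machinery introduced earlier in this patch follows a two-pass pattern: first replay the graph construction symbolically to count objects and aligned data bytes, then create the real ggml context with exactly that budget. Below is a self-contained sketch of that pattern; OBJECT_OVERHEAD and the example tensor sizes are placeholders rather than ggml's actual constants.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for ggml's per-object bookkeeping overhead (GGML_OBJECT_SIZE).
constexpr size_t OBJECT_OVERHEAD = 32;

// Round up to a 16-byte boundary, matching the align() helper in the patch.
constexpr size_t align16(size_t size) { return (size + 15) & ~size_t(15); }

struct size_plan {
    size_t objects = 0;       // number of tensor objects the graph will create
    size_t data_bytes = 0;    // aligned bytes kept in the main buffer
    size_t scratch_bytes = 0; // aligned bytes routed to the scratch buffer

    void add(size_t bytes, bool use_scratch) {
        objects += 1;
        (use_scratch ? scratch_bytes : data_bytes) += align16(bytes);
    }

    size_t main_buffer_size() const { return objects * OBJECT_OVERHEAD + data_bytes; }
};

int main() {
    // First pass: describe the tensors a graph would need, without allocating them.
    size_plan plan;
    plan.add(1024 * sizeof(float), /*use_scratch=*/false); // e.g. persistent state
    plan.add(4096 * sizeof(float), /*use_scratch=*/true);  // e.g. intermediate activation

    // Second pass: size the buffers once, then build the real graph inside them.
    std::vector<uint8_t> main_buffer(plan.main_buffer_size());
    std::vector<uint8_t> scratch(plan.scratch_bytes);
    std::printf("main: %zu bytes, scratch: %zu bytes\n", main_buffer.size(), scratch.size());
}

In the patch itself the estimate and the real graph are built by parallel code paths (rwkv_future_serial_graph versus rwkv_build_serial_graph), so any new operation added to one has to be mirrored in the other.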
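For reference, the per-token recurrence that rwkv_att_wkv assembles as a ggml graph (and that rwkv_future_att_wkv mirrors for sizing) can be written in plain scalar form. This is a sketch of the numerically stable WKV update as commonly used for RWKV v4 inference, assuming time_decay has already been stored as -exp(decay) at conversion time; it illustrates the math and is not a drop-in for the ggml graph.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// One channel-wise step of the stable WKV recurrence.
// aa, bb, pp are the running state (state[5*i+2..4] in the Python reference).
void wkv_step(const std::vector<float> & k, const std::vector<float> & v,
              const std::vector<float> & time_first, const std::vector<float> & time_decay,
              std::vector<float> & aa, std::vector<float> & bb, std::vector<float> & pp,
              std::vector<float> & wkv_out) {
    const size_t n = k.size();
    for (size_t i = 0; i < n; i++) {
        // Output: softmax-like mix of the current token and the accumulated past.
        float ww = time_first[i] + k[i];
        float qq = std::max(pp[i], ww);
        float e1 = std::exp(pp[i] - qq);
        float e2 = std::exp(ww - qq);
        wkv_out[i] = (e1 * aa[i] + e2 * v[i]) / (e1 * bb[i] + e2);

        // State update: decay the past, then fold in the current token.
        ww = pp[i] + time_decay[i]; // time_decay assumed to already be -exp(decay)
        qq = std::max(ww, k[i]);
        e1 = std::exp(ww - qq);
        e2 = std::exp(k[i] - qq);
        aa[i] = e1 * aa[i] + e2 * v[i];
        bb[i] = e1 * bb[i] + e2;
        pp[i] = qq;
    }
}

int main() {
    // Tiny two-channel example; pp starts at a very negative value, as in the reference.
    std::vector<float> k {0.5f, -1.0f}, v {1.0f, 2.0f};
    std::vector<float> time_first {0.1f, 0.1f}, time_decay {-0.9f, -0.9f};
    std::vector<float> aa(2, 0.0f), bb(2, 0.0f), pp(2, -1e30f), wkv(2);
    wkv_step(k, v, time_first, time_decay, aa, bb, pp, wkv);
    std::printf("wkv[0] = %f, wkv[1] = %f\n", wkv[0], wkv[1]);
}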