Commit

v6 but a?
YuChuXi committed Apr 4, 2024
1 parent d56d64f commit 1abd4e8
Showing 3 changed files with 75 additions and 51 deletions.
8 changes: 8 additions & 0 deletions python/convert_pytorch_to_ggml.py
@@ -69,6 +69,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
if is_v6_0:
if '.time_faaaa' in k:
tensor = tensor.unsqueeze(-1)
if '.time_maa_w1' in k or '.time_decay_w' in k:
tensor = tensor.transpose(0, 1)
tensor = tensor.contiguous()

if '.time_maa_w2' in k:
# PyTorch (5, 32, 2048) -> (5, 2048, 32); stored by ggml as ne = (32, 2048, 5)
tensor = tensor.permute(0, 2, 1)
tensor = tensor.contiguous()

elif is_v5_1_or_2:
if '.time_decay' in k:
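
For reference, a minimal PyTorch sketch of what these v6 conversions do to the mixing tensors. The 2048 embedding size and rank 32 are assumptions taken from the shape comments in rwkv_graph.inc below; ggml lists dimensions fastest-first, so a contiguous PyTorch tensor of shape (a, b, c) is stored with ne = (c, b, a).

import torch

n_embed, rank = 2048, 32                        # assumed sizes, see the shape comments below
time_maa_w1 = torch.randn(n_embed, 5 * rank)    # PyTorch (2048, 160)
time_maa_w2 = torch.randn(5, rank, n_embed)     # PyTorch (5, 32, 2048)

w1 = time_maa_w1.transpose(0, 1).contiguous()   # PyTorch (160, 2048) -> ggml ne = (2048, 160)
w2 = time_maa_w2.permute(0, 2, 1).contiguous()  # PyTorch (5, 2048, 32) -> ggml ne = (32, 2048, 5)

print(w1.shape, w2.shape)                       # torch.Size([160, 2048]) torch.Size([5, 2048, 32])
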
117 changes: 66 additions & 51 deletions rwkv_graph.inc
@@ -1,4 +1,5 @@
#include "ggml.h"
#include <stdio.h>
// View tensors of a state of a single layer.
struct rwkv_layer_state {
struct ggml_tensor * ffn_xx;
@@ -69,7 +70,6 @@ static void rwkv_carry_x(
if (sequence_len == 1) {
// xx = state[5*i+0]
x_prev = carry;

// state[5*i+0] = x
carry = x;
} else {
@@ -190,19 +190,7 @@ static struct ggml_tensor * rwkv_att_v5(

struct ggml_tensor * x_prev;

x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);

if (sequence_length > 1) {
x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length);
x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0);
x_prev = ggml_set_1d_inplace(
ctx,
x_prev,
ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float)
);
} else {
x_prev = state.att_xx;
}
rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

struct ggml_tensor * xk = ggml_add_inplace(
ctx,
@@ -248,7 +236,6 @@ static struct ggml_tensor * rwkv_att_v5(
);
}

state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float));

struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length);
struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length);
@@ -325,6 +312,19 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
(t1->ne[3]%t0->ne[3] == 0);
}


// Debug helpers: print a label followed by a tensor's shape, or just a message.
static inline void print_shape(const char * echo, const struct ggml_tensor * t) {
printf("%s", echo);
printf(" (%ld, %ld, %ld, %ld)\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
}

static inline void pppp(const char * echo) {
printf("%s", echo);
printf("\n");
}


static struct ggml_tensor * rwkv_att_v6(
struct ggml_context * ctx,
struct ggml_tensor * x,
@@ -341,41 +341,56 @@ static struct ggml_tensor * rwkv_att_v6(

rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

// sx = state[i1] - x
x_prev = ggml_sub_inplace(ctx, x_prev, x);

// xxx = x + sx * x_maa
// xxx = x + sx * x_maa (C, T)
struct ggml_tensor * xxx = ggml_add_inplace(
ctx,
ggml_mul(ctx, x_prev, layer.att_time_maa_x),
x
);

// xxx = torch.tanh(xxx @ tm_w1).view(T, 5, 1, -1).permute(1,0,2,3)
xxx = ggml_permute(
// xxx (2048, T) @ att_time_maa_w1 (2048, 160) -> (T, 160) -> (T, 5, 32, 1) -> xxx (32, 1, 5, T)
xxx = ggml_reshape_4d(
ctx,
ggml_reshape_4d(
ggml_tanh(
ctx,
ggml_tanh(
ctx,
ggml_mul_mat(ctx, xxx, layer.att_time_maa_w1)
),
sequence_length, 5, 1, layer.att_time_maa_w2->ne[1]
ggml_mul_mat(ctx, xxx, layer.att_time_maa_w1)
),
1, 0, 2, 3
sequence_length, 5, layer.att_time_maa_w2->ne[0], 1
);

xxx = ggml_permute(
ctx,
xxx,
3, 2, 0, 1
);

// xxx = torch.bmm(xxx, tm_w2).view(5, T, -1)
// xxx (32, 1, 5, T) @ att_time_maa_w2 (32, 2048, 5) -> (1, 2048, 5, T) -> (5, 2048, T, 1) -> xxx (5, 2048, T)
xxx = ggml_reshape_3d(
ctx,
ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2),
sequence_length, 5, n_embed
ggml_cont(
ctx,
ggml_permute(ctx,ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), 3, 1, 0, 2)
),
5, layer.att_time_maa_w2->ne[1], sequence_length
);
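
For comparison, a sketch of the reference PyTorch computation that the ops above mirror (single sequence, no batch dimension; C = 2048 and rank = 32 are assumptions taken from the shape comments):

import torch

C, rank, T = 2048, 32, 3
x, sx = torch.randn(T, C), torch.randn(T, C)    # current tokens and (x_prev - x)
maa_x = torch.randn(C)
tm_w1 = torch.randn(C, 5 * rank)
tm_w2 = torch.randn(5, rank, C)

xxx = x + sx * maa_x                            # (T, C)
xxx = torch.tanh(xxx @ tm_w1).view(T, 5, rank)  # (T, 5, rank)
xxx = torch.bmm(xxx.transpose(0, 1), tm_w2)     # (5, T, rank) @ (5, rank, C) -> (5, T, C)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)          # five (T, C) mixing offsets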

// mw (2048, T) * att_time_maa_w (2048) + sx (2048, T) -> xw (2048, T)
// !!!! Check whether this is correct, because of the transpose
struct ggml_tensor * mw = ggml_view_1d(ctx, xxx, n_embed, n_embed * 0 * sizeof(float));
struct ggml_tensor * mk = ggml_view_1d(ctx, xxx, n_embed, n_embed * 1 * sizeof(float));
struct ggml_tensor * mv = ggml_view_1d(ctx, xxx, n_embed, n_embed * 2 * sizeof(float));
struct ggml_tensor * mr = ggml_view_1d(ctx, xxx, n_embed, n_embed * 3 * sizeof(float));
struct ggml_tensor * mg = ggml_view_1d(ctx, xxx, n_embed, n_embed * 4 * sizeof(float));

struct ggml_tensor * xw = ggml_add_inplace(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[0], layer.att_time_maa_w),
ggml_add_inplace(ctx, mw, layer.att_time_maa_w),
x_prev
),
x
Expand All @@ -385,7 +400,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[1], layer.att_time_maa_k),
ggml_add_inplace(ctx, mk, layer.att_time_maa_k),
x_prev
),
x
Expand All @@ -395,7 +410,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[2], layer.att_time_maa_v),
ggml_add_inplace(ctx, mv, layer.att_time_maa_v),
x_prev
),
x
Expand All @@ -405,7 +420,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[3], layer.att_time_maa_r),
ggml_add_inplace(ctx, mr, layer.att_time_maa_r),
x_prev
),
x
Expand All @@ -415,37 +430,37 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[4], layer.att_time_maa_g),
ggml_add_inplace(ctx, mg, layer.att_time_maa_g),
x_prev
),
x
);
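
Continuing the sketch above, each of the five branches interpolates the input with its own mix (maa_w .. maa_g are the per-channel time_maa_* parameters; mw .. mg come from the split above):

maa_w, maa_k, maa_v, maa_r, maa_g = (torch.randn(C) for _ in range(5))
xw = x + sx * (maa_w + mw)
xk = x + sx * (maa_k + mk)
xv = x + sx * (maa_v + mv)
xr = x + sx * (maa_r + mr)
xg = x + sx * (maa_g + mg)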

struct ggml_tensor * time_decay_w = ggml_reshape_4d(

// w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2)).view(H, S, 1)
// att_time_decay_w1 (2048, 64) @ xw (2048, T) -> (64, T)
// att_time_decay_w2 (64, 2048) @ (64, T) -> (2048, T) -> w (1, S, H, T) <=> (1, S, H, T) att_heads
struct ggml_tensor * w = ggml_add_inplace(
ctx,
ggml_add_inplace(
ggml_mul_mat(
ctx,
ggml_mul_mat(
layer.att_time_decay_w2,
ggml_tanh_inplace(
ctx,
ggml_tanh_inplace(
ctx,
ggml_mul_mat(ctx, xw, layer.att_time_decay_w1)
),
layer.att_time_decay_w2
),
layer.att_time_decay
ggml_mul_mat(ctx,layer.att_time_decay_w1, xw)
)
),
1, head_size, head_count, sequence_length
);

time_decay_w = rwkv_exp(ctx,
rwkv_1_minus_x(ctx,
rwkv_exp(ctx, time_decay_w)
)
layer.att_time_decay
);
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length);

state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float));
// w = torch.exp(-torch.exp(w))
w = rwkv_exp(ctx,rwkv_1_minus_x(ctx,rwkv_exp(ctx, w)));
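
Continuing the sketch, the decay path as described by the reference formula in the comments above (the rank 64 of the decay LoRA and the H = 32, S = 64 head layout are assumptions for a 2048-wide model):

H, S = 32, 64                                    # head_count, head_size; H * S == C
td_w1 = torch.randn(C, 64)
td_w2 = torch.randn(64, C)
time_decay = torch.randn(C)

w = time_decay + torch.tanh(xw @ td_w1) @ td_w2  # (T, C)
w = torch.exp(-torch.exp(w))                     # per-channel decay in (0, 1)
w = w.view(T, H, S, 1)                           # one column vector per head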

// r = (rw @ xr).view(H, 1, S)
// k = (kw @ xk).view(H, S, 1)
// v = (vw @ xv).view(H, 1, S)
// g = F.silu(gw @ xg)
struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length);
struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length);
struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length);
@@ -468,7 +483,7 @@ static struct ggml_tensor * rwkv_att_v6(
v,
r,
layer.att_time_faaaa,
time_decay_w,
w,
state_out
);
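
For reference, a sketch of the RWKV v5/v6 per-head WKV recurrence that the reshaped r, k, v and the decay w feed into (reference formulation; u stands for att_time_faaaa; shapes follow the view comments above):

r = torch.randn(T, H, 1, S)
k = torch.randn(T, H, S, 1)
v = torch.randn(T, H, 1, S)
u = torch.randn(H, S, 1)                        # att_time_faaaa, bonus for the current token
state = torch.zeros(H, S, S)

outs = []
for t in range(T):
    kv = k[t] @ v[t]                            # (H, S, S)
    outs.append(r[t] @ (u * kv + state))        # (H, 1, S)
    state = kv + w[t] * state                   # w from the decay sketch above
out = torch.stack(outs)                         # (T, H, 1, S)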

1 change: 1 addition & 0 deletions rwkv_model_loading.inc
@@ -147,6 +147,7 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate));

RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r));
