
Commit

evel
YuChuXi committed Apr 3, 2024
1 parent d89754c commit 6ba24eb
Showing 2 changed files with 59 additions and 22 deletions.
62 changes: 40 additions & 22 deletions rwkv_graph.inc
@@ -334,8 +334,6 @@ static struct ggml_tensor * rwkv_att_v6(
         x
     );
 
-    // WKVRG!!!
-
     xxx = ggml_reshape_4d(
         ctx,
         ggml_tanh(
@@ -344,6 +342,7 @@ static struct ggml_tensor * rwkv_att_v6(
         ),
         5, -1, 1, sequence_length
     );
+
     xxx = ggml_reshape_3d(
         ctx,
         ggml_mul_inplace(ctx, xxx, layer.att_tm_w2),
@@ -353,49 +352,49 @@
     struct ggml_tensor * xw = ggml_add_inplace(
         ctx,
         ggml_mul_inplace(
-            ctx
-            ggml_add_inplace(ctx, mw, layer.att_w_maa),
-            x_prev,
+            ctx,
+            ggml_add_inplace(ctx, xxx->src[0], layer.att_w_maa),
+            x_prev
         ),
         x
     );
 
     struct ggml_tensor * xk = ggml_add_inplace(
         ctx,
         ggml_mul_inplace(
-            ctx
-            ggml_add_inplace(ctx, mk, layer.att_k_maa),
-            x_prev,
+            ctx,
+            ggml_add_inplace(ctx, xxx->src[1], layer.att_k_maa),
+            x_prev
         ),
         x
     );
 
     struct ggml_tensor * xv = ggml_add_inplace(
         ctx,
         ggml_mul_inplace(
-            ctx
-            ggml_add_inplace(ctx, mv, layer.att_v_maa),
-            x_prev,
+            ctx,
+            ggml_add_inplace(ctx, xxx->src[2], layer.att_v_maa),
+            x_prev
         ),
         x
     );
 
     struct ggml_tensor * xr = ggml_add_inplace(
         ctx,
         ggml_mul_inplace(
-            ctx
-            ggml_add_inplace(ctx, mr, layer.att_r_maa),
-            x_prev,
+            ctx,
+            ggml_add_inplace(ctx, xxx->src[3], layer.att_r_maa),
+            x_prev
        ),
         x
     );
 
     struct ggml_tensor * xg = ggml_add_inplace(
         ctx,
         ggml_mul_inplace(
-            ctx
-            ggml_add_inplace(ctx, mg, layer.att_g_maa),
-            x_prev,
+            ctx,
+            ggml_add_inplace(ctx, xxx->src[4], layer.att_g_maa),
+            x_prev
         ),
         x
     );
@@ -624,19 +623,38 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu
 
         struct rwkv_layer_state state = inputs[i];
 
-        x = model.arch_version_major >= 5 ?
-            ggml_add_inplace(ctx, x, rwkv_att_v5(
+        if (model.arch_version_major == 6) {
+            x = ggml_add_inplace(ctx, x, rwkv_att_v6(
                 ctx,
                 x,
                 layer,
                 state,
                 model.head_count,
                 model.head_size,
                 model.arch_version_minor
-            )) :
-            ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state));
+            ));
 
-        x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state));
+            x = ggml_add_inplace(ctx, x, rwkv_ffn_v6(ctx, x, layer, state));
+
+        } else if (model.arch_version_major == 5) {
+            x = ggml_add_inplace(ctx, x, rwkv_att_v5(
+                ctx,
+                x,
+                layer,
+                state,
+                model.head_count,
+                model.head_size,
+                model.arch_version_minor
+            ));
+
+            x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state));
+
+        } else {
+            x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state));
+
+            x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state));
+
+        }
 
         struct rwkv_layer_state & output_state = outputs[i];
 
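Note, not part of the commit: the xw/xk/xv/xr/xg expressions above implement the RWKV v6 data-dependent token shift. Each mixed input is the current embedding plus the token-shift delta, scaled per channel by a learned coefficient (att_w_maa, att_k_maa, ...) plus a per-token correction; the commit changes the source of that correction from the mw/mk/mv/mr/mg names to xxx->src[0] through xxx->src[4]. Below is a minimal scalar sketch of the per-channel arithmetic, assuming x_prev at this point already holds the delta (previous minus current embedding); the function name and layout are illustrative, not rwkv.cpp API.

    #include <stddef.h>

    /* One channel-wise mix of the RWKV v6 token shift:
     *   out = x + sx * (maa + m)
     * where sx is the token-shift delta, maa is the learned per-channel
     * coefficient (att_w_maa, att_k_maa, ...), and m is the data-dependent
     * correction for this input (one of five vectors produced by the
     * att_tm_w1 / att_tm_w2 bottleneck seen earlier in rwkv_att_v6). */
    static void rwkv_v6_token_shift_mix(const float * x, const float * sx,
                                        const float * maa, const float * m,
                                        float * out, size_t n_embed) {
        for (size_t i = 0; i < n_embed; i++) {
            out[i] = x[i] + sx[i] * (maa[i] + m[i]);
        }
    }

The ggml graph above expresses the same arithmetic over whole tensors using ggml_add_inplace and ggml_mul_inplace.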
19 changes: 19 additions & 0 deletions rwkv_model_loading.inc
@@ -6,6 +6,7 @@ struct rwkv_layer {
     struct ggml_tensor * att_time_mix_k;
     struct ggml_tensor * att_time_mix_v;
     struct ggml_tensor * att_time_mix_r;
+
     // Removed in RWKV v5.2; set to NULL for this and newer models.
     struct ggml_tensor * att_time_first;
     struct ggml_tensor * att_time_decay;
@@ -22,16 +23,34 @@ struct rwkv_layer {
     struct ggml_tensor * att_time_faaaa;
     struct ggml_tensor * att_time_mix_g;
     struct ggml_tensor * att_gate;
 
+    // Added in RWKV v6;
+    struct ggml_tensor * att_x_maa;
+    struct ggml_tensor * att_w_maa;
+    struct ggml_tensor * att_k_maa;
+    struct ggml_tensor * att_v_maa;
+    struct ggml_tensor * att_r_maa;
+    struct ggml_tensor * att_g_maa;
+    struct ggml_tensor * att_tm_w1;
+    struct ggml_tensor * att_tm_w2;
+    struct ggml_tensor * att_td_w1;
+    struct ggml_tensor * att_td_w2;
+
     struct ggml_tensor * ln2_weight;
     struct ggml_tensor * ln2_bias;
 
     // FFN.
+    // v4 v5.x
     struct ggml_tensor * ffn_time_mix_k;
     struct ggml_tensor * ffn_time_mix_r;
+    // v6
+    struct ggml_tensor * ffn_time_maa_k;
+    struct ggml_tensor * ffn_time_maa_r;
+
     struct ggml_tensor * ffn_key;
     struct ggml_tensor * ffn_value;
     struct ggml_tensor * ffn_receptance;
+
 };
 
 // The model holds all parameter tensors and the ggml context containing them.
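Note, not part of the commit: the new ffn_time_maa_k / ffn_time_maa_r fields parallel the v4 and v5.x ffn_time_mix_k / ffn_time_mix_r pair but follow the v6 "maa" convention, where each mixed input is written as the current embedding plus the token-shift delta scaled per channel. The sketch below shows the channel-mixing (FFN) block these fields feed, following the reference RWKV v6 formulation; it is a naive scalar reference under that assumption, with illustrative names and weight layouts, and is presumably what rwkv_ffn_v6 builds as a ggml graph.

    #include <math.h>
    #include <stddef.h>
    #include <stdlib.h>

    /* y = W x for a row-major (rows x cols) matrix; naive reference only. */
    static void matvec(const float * w, const float * x, float * y,
                       size_t rows, size_t cols) {
        for (size_t r = 0; r < rows; r++) {
            float acc = 0.0f;
            for (size_t c = 0; c < cols; c++) {
                acc += w[r * cols + c] * x[c];
            }
            y[r] = acc;
        }
    }

    /* RWKV v6 channel mixing (FFN) for one token:
     *   xk  = x + (x_prev - x) * ffn_time_maa_k
     *   xr  = x + (x_prev - x) * ffn_time_maa_r
     *   k   = relu(ffn_key * xk)^2
     *   out = sigmoid(ffn_receptance * xr) * (ffn_value * k) */
    static void rwkv_v6_ffn_ref(const float * x, const float * x_prev,
                                const float * maa_k, const float * maa_r,
                                const float * w_key,         /* n_hidden x n_embed */
                                const float * w_value,       /* n_embed x n_hidden */
                                const float * w_receptance,  /* n_embed x n_embed  */
                                float * out, size_t n_embed, size_t n_hidden) {
        float * xk = malloc(n_embed  * sizeof(float));
        float * xr = malloc(n_embed  * sizeof(float));
        float * k  = malloc(n_hidden * sizeof(float));
        float * r  = malloc(n_embed  * sizeof(float));

        for (size_t i = 0; i < n_embed; i++) {
            float sx = x_prev[i] - x[i];
            xk[i] = x[i] + sx * maa_k[i];
            xr[i] = x[i] + sx * maa_r[i];
        }

        matvec(w_key, xk, k, n_hidden, n_embed);
        for (size_t i = 0; i < n_hidden; i++) {
            float v = k[i] > 0.0f ? k[i] : 0.0f;  /* squared ReLU */
            k[i] = v * v;
        }

        matvec(w_value, k, out, n_embed, n_hidden);
        matvec(w_receptance, xr, r, n_embed, n_embed);
        for (size_t i = 0; i < n_embed; i++) {
            out[i] *= 1.0f / (1.0f + expf(-r[i]));  /* sigmoid gate */
        }

        free(xk); free(xr); free(k); free(r);
    }

In the updated rwkv_build_serial_graph above, models with arch_version_major == 6 route through rwkv_att_v6 and rwkv_ffn_v6, while v5 and v4 keep their existing paths.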
