Skip to content

Commit

Permalink
fix load error
Browse files Browse the repository at this point in the history
  • Loading branch information
YuChuXi committed Apr 3, 2024
1 parent 3ef4086 commit cdb1064
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 29 deletions.
20 changes: 10 additions & 10 deletions rwkv_graph.inc
Original file line number Diff line number Diff line change
Expand Up @@ -330,30 +330,30 @@ static struct ggml_tensor * rwkv_att_v6(

struct ggml_tensor * xxx = ggml_add_inplace(
ctx,
ggml_mul(ctx, x_prev, layer.att_x_maa),
ggml_mul(ctx, x_prev, layer.time_maa_x),
x
);

xxx = ggml_reshape_4d(
ctx,
ggml_tanh(
ctx,
ggml_mul_mat(ctx,xxx,layer.att_tm_w1)
ggml_mul_mat(ctx,xxx,layer.time_maa_w1)
),
5, -1, 1, sequence_length
);

xxx = ggml_reshape_3d(
ctx,
ggml_mul_inplace(ctx, xxx, layer.att_tm_w2),
ggml_mul_inplace(ctx, xxx, layer.time_maa_w2),
5, -1, sequence_length
);

struct ggml_tensor * xw = ggml_add_inplace(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[0], layer.att_w_maa),
ggml_add_inplace(ctx, xxx->src[0], layer.time_maa_w),
x_prev
),
x
Expand All @@ -363,7 +363,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[1], layer.att_k_maa),
ggml_add_inplace(ctx, xxx->src[1], layer.time_maa_k),
x_prev
),
x
Expand All @@ -373,7 +373,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[2], layer.att_v_maa),
ggml_add_inplace(ctx, xxx->src[2], layer.time_maa_v),
x_prev
),
x
Expand All @@ -383,7 +383,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[3], layer.att_r_maa),
ggml_add_inplace(ctx, xxx->src[3], layer.time_maa_r),
x_prev
),
x
Expand All @@ -393,7 +393,7 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_mul_inplace(
ctx,
ggml_add_inplace(ctx, xxx->src[4], layer.att_g_maa),
ggml_add_inplace(ctx, xxx->src[4], layer.time_maa_g),
x_prev
),
x
Expand All @@ -407,9 +407,9 @@ static struct ggml_tensor * rwkv_att_v6(
ctx,
ggml_tanh_inplace(
ctx,
ggml_mul_mat(ctx, xw, layer.att_td_w1)
ggml_mul_mat(ctx, xw, layer.time_decay_w1)
),
layer.att_td_w2
layer.time_decay_w2
),
layer.att_time_decay
),
Expand Down
54 changes: 35 additions & 19 deletions rwkv_model_loading.inc
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ struct rwkv_layer {
struct ggml_tensor * att_gate;

// Added in RWKV v6;
struct ggml_tensor * att_x_maa;
struct ggml_tensor * att_w_maa;
struct ggml_tensor * att_k_maa;
struct ggml_tensor * att_v_maa;
struct ggml_tensor * att_r_maa;
struct ggml_tensor * att_g_maa;
struct ggml_tensor * att_tm_w1;
struct ggml_tensor * att_tm_w2;
struct ggml_tensor * att_td_w1;
struct ggml_tensor * att_td_w2;
struct ggml_tensor * time_maa_x;
struct ggml_tensor * time_maa_w;
struct ggml_tensor * time_maa_k;
struct ggml_tensor * time_maa_v;
struct ggml_tensor * time_maa_r;
struct ggml_tensor * time_maa_g;
struct ggml_tensor * time_maa_w1;
struct ggml_tensor * time_maa_w2;
struct ggml_tensor * time_decay_w1;
struct ggml_tensor * time_decay_w2;

struct ggml_tensor * ln2_weight;
struct ggml_tensor * ln2_bias;
Expand Down Expand Up @@ -120,9 +120,22 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias));

RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r));
if (model.arch_version_major == 6){
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_x_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_w_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_k_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_v_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_r_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_g_maa));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.time_maa_w1));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.time_maa_w2));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.time_decay_w1));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.time_decay_w2));
} else {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r));
}

if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa));
Expand All @@ -140,7 +153,7 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias));

if (model.arch_version_minor >= 2) {
if (model.arch_version_minor >= 2 || model.arch_version_major == 6) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate));
}
Expand All @@ -149,11 +162,14 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias));

RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance));
if (model.arch_version_major == 6) {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r));
} else {
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value));
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance));
}
}

RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight));
Expand Down

0 comments on commit cdb1064

Please sign in to comment.