Commit

att G
YuChuXi committed Apr 3, 2024
1 parent d8f13ff commit d89754c
Showing 3 changed files with 215 additions and 17 deletions.
14 changes: 12 additions & 2 deletions python/convert_pytorch_to_ggml.py
@@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

    is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
    is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
    is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

    if is_v5_2:
    if is_v6_0:
        print('Detected RWKV v6.0')
    elif is_v5_2:
        print('Detected RWKV v5.2')
    elif is_v5_1_or_2:
        print('Detected RWKV v5.1')
@@ -63,7 +66,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
        if '.time_' in k:
            tensor = tensor.squeeze()

        if is_v5_1_or_2:
        if is_v6_0:
            if '.time_first' in k:
                tensor = torch.exp(tensor).reshape(-1, 1, 1)

            if '.time_faaaa' in k:
                tensor = tensor.unsqueeze(-1)

        elif is_v5_1_or_2:
            if '.time_decay' in k:
                if is_v5_2:
                    tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
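For context, a minimal sketch of what the new v6.0 branch does to an attention tensor during conversion. The shapes below are toy values, assuming the usual RWKV v6 layout of (head_count, head_size) for time_faaaa; this is an illustration, not part of the script:

import torch

# Hypothetical v6.0 tensor: head_count = 2, head_size = 4.
time_faaaa = torch.randn(2, 4)

# The v6.0 branch adds a trailing singleton axis so the graph code can
# broadcast it per head: (head_count, head_size) -> (head_count, head_size, 1).
print(time_faaaa.unsqueeze(-1).shape)  # torch.Size([2, 4, 1])

# A '.time_first' tensor, if present, would instead get
# torch.exp(tensor).reshape(-1, 1, 1), mirroring the v5.1 handling.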
10 changes: 8 additions & 2 deletions python/merge_lora_into_ggml.py
@@ -13,7 +13,7 @@
def parse_args():
    parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
    parser.add_argument('src_path', help='Path to source rwkv.cpp model')
    parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
    parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0'])
    parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
    parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
    parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwritten with the merged model')
@@ -47,7 +47,7 @@ def main() -> None:

    arch_version: str = args.rwkv_arch_version

    if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'):
    if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'):
        raise ValueError(f'Invalid RWKV architecture version {arch_version}')

    print(f'Reading {args.lora_path}')
@@ -108,6 +108,12 @@ def main() -> None:
        if '.time_' in key:
            replacement = replacement.squeeze()

        if arch_version == 'v6.0':
            if '.time_first' in key:
                replacement = torch.exp(replacement).reshape(-1, 1, 1)

            if '.time_faaaa' in key:
                replacement = replacement.unsqueeze(-1)
        if arch_version == 'v5.1' or arch_version == 'v5.2':
            if '.time_decay' in key:
                if arch_version == 'v5.2':
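For reference, a rough sketch of the LoRA merge that precedes these per-version reshapes, assuming the conventional lora_A/lora_B factorization and alpha/rank scaling; the script's actual key handling may differ:

import torch

def merge_lora(weight: torch.Tensor,
               lora_A: torch.Tensor,
               lora_B: torch.Tensor,
               lora_alpha: int,
               lora_r: int) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / r) * (B @ A).
    return weight + (lora_alpha / lora_r) * (lora_B @ lora_A)

# Toy shapes: an 8x8 weight with rank-2 adapters.
w = torch.zeros(8, 8)
a = torch.randn(2, 8)   # lora_A: (r, in_features)
b = torch.randn(8, 2)   # lora_B: (out_features, r)
merged = merge_lora(w, a, b, lora_alpha=32, lora_r=2)
print(merged.shape)     # torch.Size([8, 8])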
208 changes: 195 additions & 13 deletions rwkv_graph.inc
@@ -1,3 +1,5 @@
#include "ggml.h"

// View tensors of a state of a single layer.
struct rwkv_layer_state {
    struct ggml_tensor * ffn_xx;
@@ -190,21 +192,9 @@ static struct ggml_tensor * rwkv_att_v5(
    size_t n_embed = x->ne[0];
    size_t sequence_length = x->ne[1];

    x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias);

    struct ggml_tensor * x_prev;

    if (sequence_length > 1) {
        x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length);
        x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0);
        x_prev = ggml_set_1d_inplace(
            ctx,
            x_prev,
            ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float)
        );
    } else {
        x_prev = state.att_xx;
    }
    rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

    struct ggml_tensor * xk = ggml_add_inplace(
        ctx,
@@ -320,6 +310,170 @@ static struct ggml_tensor * rwkv_att_v5(
    return ggml_mul_mat(ctx, layer.att_output, x);
}
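The rwkv_carry_x helper that replaces the explicit sequence-shift block above keeps the same token-shift semantics. A rough Python sketch of that shift, assuming a (sequence_length, n_embed) layout and ignoring the layer norm the C helper also applies:

import torch

def carry_x(x: torch.Tensor, carry: torch.Tensor):
    # x: (sequence_length, n_embed); carry: (n_embed,), the last token of the
    # previous chunk. x_prev[t] holds the embedding of token t - 1.
    x_prev = torch.cat([carry.unsqueeze(0), x[:-1]], dim=0)
    new_carry = x[-1]
    return x_prev, new_carry

x = torch.randn(3, 8)
x_prev, carry = carry_x(x, torch.zeros(8))
print(x_prev.shape, carry.shape)  # torch.Size([3, 8]) torch.Size([8])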

static struct ggml_tensor * rwkv_att_v6(
    struct ggml_context * ctx,
    struct ggml_tensor * x,
    struct rwkv_layer layer,
    struct rwkv_layer_state & state,
    const int64_t head_count,
    const int64_t head_size,
    const uint32_t arch_version_minor
) {
    size_t n_embed = x->ne[0];
    size_t sequence_length = x->ne[1];

    struct ggml_tensor * x_prev;

    rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx);

    x_prev = ggml_sub_inplace(ctx, x_prev, x);

    struct ggml_tensor * xxx = ggml_add_inplace(
        ctx,
        ggml_mul(ctx, x_prev, layer.att_x_maa),
        x
    );

    // WKVRG: data-dependent token-shift coefficients for W, K, V, R and G.

    xxx = ggml_reshape_4d(
        ctx,
        ggml_tanh(
            ctx,
            ggml_mul_mat(ctx, xxx, layer.att_tm_w1)
        ),
        5, -1, 1, sequence_length
    );
    xxx = ggml_reshape_3d(
        ctx,
        ggml_mul_inplace(ctx, xxx, layer.att_tm_w2),
        5, -1, sequence_length
    );

    // Note: mw, mk, mv, mr and mg are presumably the five W/K/V/R/G slices of
    // xxx computed above; their extraction is not present in the code shown here.
    struct ggml_tensor * xw = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            ggml_add_inplace(ctx, mw, layer.att_w_maa),
            x_prev
        ),
        x
    );

    struct ggml_tensor * xk = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            ggml_add_inplace(ctx, mk, layer.att_k_maa),
            x_prev
        ),
        x
    );

    struct ggml_tensor * xv = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            ggml_add_inplace(ctx, mv, layer.att_v_maa),
            x_prev
        ),
        x
    );

    struct ggml_tensor * xr = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            ggml_add_inplace(ctx, mr, layer.att_r_maa),
            x_prev
        ),
        x
    );

    struct ggml_tensor * xg = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            ggml_add_inplace(ctx, mg, layer.att_g_maa),
            x_prev
        ),
        x
    );

    struct ggml_tensor * time_decay_w = ggml_reshape_4d(
        ctx,
        ggml_add_inplace(
            ctx,
            ggml_mul_mat(
                ctx,
                ggml_tanh_inplace(
                    ctx,
                    ggml_mul_mat(ctx, xw, layer.att_td_w1)
                ),
                layer.att_td_w2
            ),
            layer.att_time_decay
        ),
        1, head_size, head_count, sequence_length
    );

    // Note: the RWKV v6 reference computes exp(-exp(w)); here 1 - exp(w) is used in place of the negation.
    time_decay_w = rwkv_exp(ctx,
        rwkv_1_minus_x(ctx,
            rwkv_exp(ctx, time_decay_w)
        )
    );

    state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float));

    struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length);
    struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length);
    struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length);

    struct ggml_tensor * g = ggml_silu_inplace(
        ctx,
        ggml_mul_mat(ctx, layer.att_gate, xg)
    );

    // dup is not strictly required; doing it just in case.
    struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads);

    x = rwkv_wkv_v5(
        ctx,
        sequence_length,
        n_embed,
        head_count,
        head_size,
        x,
        k,
        v,
        r,
        layer.att_time_faaaa,
        time_decay_w,
        state_out
    );

    state.att_heads = state_out;

    // ggml_group_norm considers groups in the third dimension.
    x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
    x = ggml_group_norm_inplace(ctx, x, head_count);
    // Convert back to a regular vector.
    x = ggml_reshape_2d(ctx, x, n_embed, sequence_length);
    x = ggml_add_inplace(
        ctx,
        ggml_mul_inplace(
            ctx,
            x,
            layer.att_ln_x_weight
        ),
        layer.att_ln_x_bias
    );

    x = ggml_mul_inplace(ctx, x, g);

    return ggml_mul_mat(ctx, layer.att_output, x);
}
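For orientation, the data-dependent token shift that rwkv_att_v6 builds corresponds roughly to the RWKV v6 ("Finch") formulation. Below is a loose PyTorch sketch with illustrative names; these are not the identifiers used in rwkv_graph.inc:

import torch

def ddlerp(x, x_prev, mu_x, mu, A, B):
    # xxx  = x + (x_prev - x) * mu_x          (shared first interpolation)
    # lora = tanh(xxx @ A) @ B                (low-rank, per-token adjustment)
    # out  = x + (x_prev - x) * (mu + lora)   (one such mix each for W, K, V, R, G)
    xxx = x + (x_prev - x) * mu_x
    lora = torch.tanh(xxx @ A) @ B
    return x + (x_prev - x) * (mu + lora)

def decay(xw, decay_base, A_d, B_d):
    # w = exp(-exp(decay_base + tanh(xw @ A_d) @ B_d)) in the reference formulation.
    return torch.exp(-torch.exp(decay_base + torch.tanh(xw @ A_d) @ B_d))

n_embed, rank = 8, 4
x, x_prev = torch.randn(n_embed), torch.randn(n_embed)
xw = ddlerp(x, x_prev, torch.rand(n_embed), torch.rand(n_embed),
            torch.randn(n_embed, rank), torch.randn(rank, n_embed))
print(decay(xw, torch.zeros(n_embed),
            torch.randn(n_embed, rank), torch.randn(rank, n_embed)).shape)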

static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
    struct ggml_tensor * x_prev;
    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx);
@@ -349,6 +503,34 @@ static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tens
    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
}

static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) {
    struct ggml_tensor * x_prev;
    rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx);

    // xk = x + sx * time_maa_k
    // xr = x + sx * time_maa_r
    struct ggml_tensor * xk = ggml_add_inplace(
        ctx,
        x,
        ggml_mul(ctx, x_prev, layer.ffn_time_maa_k)
    );

    struct ggml_tensor * xr = ggml_add_inplace(
        ctx,
        x,
        ggml_mul(ctx, x_prev, layer.ffn_time_maa_r)
    );

    // r = torch.sigmoid(rw @ xr)
    struct ggml_tensor * r = rwkv_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr));

    // k = torch.square(torch.relu(kw @ xk))
    struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk)));

    // r * (vw @ k)
    return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k));
}
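The comments above already give the PyTorch-style formulas; here is a compact, single-token equivalent as a sketch only, with illustrative parameter names and sx taken as the shifted-minus-current difference used by the RWKV reference code:

import torch

def channel_mix_v6(x, x_prev, time_maa_k, time_maa_r, kw, vw, rw):
    sx = x_prev - x                     # token-shift difference
    xk = x + sx * time_maa_k
    xr = x + sx * time_maa_r
    r = torch.sigmoid(rw @ xr)
    k = torch.square(torch.relu(kw @ xk))
    return r * (vw @ k)

n_embed, hidden = 8, 32
out = channel_mix_v6(torch.randn(n_embed), torch.randn(n_embed),
                     torch.rand(n_embed), torch.rand(n_embed),
                     torch.randn(hidden, n_embed), torch.randn(n_embed, hidden),
                     torch.randn(n_embed, n_embed))
print(out.shape)  # torch.Size([8])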

static void rwkv_create_input_and_output_views(
    struct ggml_context * ctx,
    struct rwkv_layer_state * inputs,

0 comments on commit d89754c
