From d89754cf1f4241a9c892eefcba8f194db0c06be0 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Wed, 3 Apr 2024 20:16:34 +0800 Subject: [PATCH 01/11] att G --- python/convert_pytorch_to_ggml.py | 14 +- python/merge_lora_into_ggml.py | 10 +- rwkv_graph.inc | 208 ++++++++++++++++++++++++++++-- 3 files changed, 215 insertions(+), 17 deletions(-) diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 9956844..d0bf158 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict + is_v6_0: bool = 'blocks.0.att.time_maa' in state_dict - if is_v5_2: + if is_v6_0: + print('Detected RWKV v6.0') + elif is_v5_2: print('Detected RWKV v5.2') elif is_v5_1_or_2: print('Detected RWKV v5.1') @@ -63,7 +66,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if '.time_' in k: tensor = tensor.squeeze() - if is_v5_1_or_2: + if is_v6_0: + if '.time_first' in k: + tensor = torch.exp(tensor).reshape(-1, 1, 1) + + if '.time_faaaa' in k: + tensor = tensor.unsqueeze(-1) + + elif is_v5_1_or_2: if '.time_decay' in k: if is_v5_2: tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1) diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py index 3988697..b9c6bd2 100644 --- a/python/merge_lora_into_ggml.py +++ b/python/merge_lora_into_ggml.py @@ -13,7 +13,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file') parser.add_argument('src_path', help='Path to source rwkv.cpp model') - parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2']) + parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0']) parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format') parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int) parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model') @@ -47,7 +47,7 @@ def main() -> None: arch_version: str = args.rwkv_arch_version - if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'): + if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'): raise ValueError(f'Invalid RWKV architecture version {arch_version}') print(f'Reading {args.lora_path}') @@ -108,6 +108,12 @@ def main() -> None: if '.time_' in key: replacement = replacement.squeeze() + if arch_version == 'v6.0': + if '.time_first' in key: + replacement = torch.exp(replacement).reshape(-1, 1, 1) + + if '.time_faaaa' in key: + replacement = replacement.unsqueeze(-1) if arch_version == 'v5.1' or arch_version == 'v5.2': if '.time_decay' in key: if arch_version == 'v5.2': diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 90dda81..2a2d84c 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -1,3 +1,5 @@ +#include "ggml.h" + // View tensors of a state of a single layer. 
struct rwkv_layer_state { struct ggml_tensor * ffn_xx; @@ -190,21 +192,9 @@ static struct ggml_tensor * rwkv_att_v5( size_t n_embed = x->ne[0]; size_t sequence_length = x->ne[1]; - x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - struct ggml_tensor * x_prev; - if (sequence_length > 1) { - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); - x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0); - x_prev = ggml_set_1d_inplace( - ctx, - x_prev, - ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) - ); - } else { - x_prev = state.att_xx; - } + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); struct ggml_tensor * xk = ggml_add_inplace( ctx, @@ -320,6 +310,170 @@ static struct ggml_tensor * rwkv_att_v5( return ggml_mul_mat(ctx, layer.att_output, x); } +static struct ggml_tensor * rwkv_att_v6( + struct ggml_context * ctx, + struct ggml_tensor * x, + struct rwkv_layer layer, + struct rwkv_layer_state & state, + const int64_t head_count, + const int64_t head_size, + const uint32_t arch_version_minor +) { + size_t n_embed = x->ne[0]; + size_t sequence_length = x->ne[1]; + + struct ggml_tensor * x_prev; + + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); + + x_prev = ggml_sub_inplace(ctx, x_prev, x); + + struct ggml_tensor * xxx = ggml_add_inplace( + ctx, + ggml_mul(ctx, x_prev, layer.att_x_maa), + x + ); + + // WKVRG!!! + + xxx = ggml_reshape_4d( + ctx, + ggml_tanh( + ctx, + ggml_mul_mat(ctx,xxx,layer.att_tm_w1) + ), + 5, -1, 1, sequence_length + ); + xxx = ggml_reshape_3d( + ctx, + ggml_mul_inplace(ctx, xxx, layer.att_tm_w2), + 5, -1, sequence_length + ); + + struct ggml_tensor * xw = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx + ggml_add_inplace(ctx, mw, layer.att_w_maa), + x_prev, + ), + x + ); + + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx + ggml_add_inplace(ctx, mk, layer.att_k_maa), + x_prev, + ), + x + ); + + struct ggml_tensor * xv = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx + ggml_add_inplace(ctx, mv, layer.att_v_maa), + x_prev, + ), + x + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx + ggml_add_inplace(ctx, mr, layer.att_r_maa), + x_prev, + ), + x + ); + + struct ggml_tensor * xg = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx + ggml_add_inplace(ctx, mg, layer.att_g_maa), + x_prev, + ), + x + ); + + struct ggml_tensor * time_decay_w = ggml_reshape_4d( + ctx, + ggml_add_inplace( + ctx, + ggml_mul_mat( + ctx, + ggml_tanh_inplace( + ctx, + ggml_mul_mat(ctx, xw, layer.att_td_w1) + ), + layer.att_td_w2 + ), + layer.att_time_decay + ), + 1, head_size, head_count, sequence_length + ); + + time_decay_w = rwkv_exp(ctx, + rwkv_1_minus_x(ctx, + rwkv_exp(ctx, time_decay_w) + ) + ); + + state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); + + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); + + struct ggml_tensor * g = ggml_silu_inplace( + ctx, + ggml_mul_mat(ctx, layer.att_gate, xg) + ); + + // dup is not strictly required; doing it just in case. 
+ struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads); + + x = rwkv_wkv_v5( + ctx, + sequence_length, + n_embed, + head_count, + head_size, + x, + k, + v, + r, + layer.att_time_faaaa, + time_decay_w, + state_out + ); + + state.att_heads = state_out; + + // ggml_group_norm considers groups in the third dimension. + x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length); + x = ggml_group_norm_inplace(ctx, x, head_count); + // Convert back to a regular vector. + x = ggml_reshape_2d(ctx, x, n_embed, sequence_length); + x = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + x, + layer.att_ln_x_weight + ), + layer.att_ln_x_bias + ); + + x = ggml_mul_inplace(ctx, x, g); + + return ggml_mul_mat(ctx, layer.att_output, x); +} + static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { struct ggml_tensor * x_prev; rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); @@ -349,6 +503,34 @@ static struct ggml_tensor * rwkv_ffn(struct ggml_context * ctx, struct ggml_tens return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); } +static struct ggml_tensor * rwkv_ffn_v6(struct ggml_context * ctx, struct ggml_tensor * x, struct rwkv_layer layer, struct rwkv_layer_state & state) { + struct ggml_tensor * x_prev; + rwkv_carry_x(ctx, layer.ln2_weight, layer.ln2_bias, x, x_prev, state.ffn_xx); + + // xk = x + sx * time_maa_k + // xr = x + sx * time_maa_r + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + x, + ggml_mul(ctx, x_prev, layer.ffn_time_maa_k) + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + x, + ggml_mul(ctx, x_prev, layer.ffn_time_maa_r) + ); + + // r = torch.sigmoid(rw @ xr) + struct ggml_tensor * r = rwkv_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_receptance, xr)); + + // k = torch.square(torch.relu(kw @ xk)) + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer.ffn_key, xk))); + + // r * (vw @ k) + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer.ffn_value, k)); +} + static void rwkv_create_input_and_output_views( struct ggml_context * ctx, struct rwkv_layer_state * inputs, From 6ba24eb6723760d68222022206ddc6b8f63d771b Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Wed, 3 Apr 2024 21:28:47 +0800 Subject: [PATCH 02/11] evel --- rwkv_graph.inc | 62 +++++++++++++++++++++++++++--------------- rwkv_model_loading.inc | 19 +++++++++++++ 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 2a2d84c..db4580d 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -334,8 +334,6 @@ static struct ggml_tensor * rwkv_att_v6( x ); - // WKVRG!!! 
- xxx = ggml_reshape_4d( ctx, ggml_tanh( @@ -344,6 +342,7 @@ static struct ggml_tensor * rwkv_att_v6( ), 5, -1, 1, sequence_length ); + xxx = ggml_reshape_3d( ctx, ggml_mul_inplace(ctx, xxx, layer.att_tm_w2), @@ -353,9 +352,9 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xw = ggml_add_inplace( ctx, ggml_mul_inplace( - ctx - ggml_add_inplace(ctx, mw, layer.att_w_maa), - x_prev, + ctx, + ggml_add_inplace(ctx, xxx->src[0], layer.att_w_maa), + x_prev ), x ); @@ -363,9 +362,9 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xk = ggml_add_inplace( ctx, ggml_mul_inplace( - ctx - ggml_add_inplace(ctx, mk, layer.att_k_maa), - x_prev, + ctx, + ggml_add_inplace(ctx, xxx->src[1], layer.att_k_maa), + x_prev ), x ); @@ -373,9 +372,9 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xv = ggml_add_inplace( ctx, ggml_mul_inplace( - ctx - ggml_add_inplace(ctx, mv, layer.att_v_maa), - x_prev, + ctx, + ggml_add_inplace(ctx, xxx->src[2], layer.att_v_maa), + x_prev ), x ); @@ -383,9 +382,9 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xr = ggml_add_inplace( ctx, ggml_mul_inplace( - ctx - ggml_add_inplace(ctx, mr, layer.att_r_maa), - x_prev, + ctx, + ggml_add_inplace(ctx, xxx->src[3], layer.att_r_maa), + x_prev ), x ); @@ -393,9 +392,9 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xg = ggml_add_inplace( ctx, ggml_mul_inplace( - ctx - ggml_add_inplace(ctx, mg, layer.att_g_maa), - x_prev, + ctx, + ggml_add_inplace(ctx, xxx->src[4], layer.att_g_maa), + x_prev ), x ); @@ -624,8 +623,8 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu struct rwkv_layer_state state = inputs[i]; - x = model.arch_version_major >= 5 ? - ggml_add_inplace(ctx, x, rwkv_att_v5( + if (model.arch_version_major == 6) { + x = ggml_add_inplace(ctx, x, rwkv_att_v6( ctx, x, layer, @@ -633,10 +632,29 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu model.head_count, model.head_size, model.arch_version_minor - )) : - ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + )); - x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + x = ggml_add_inplace(ctx, x, rwkv_ffn_v6(ctx, x, layer, state)); + + } else if (model.arch_version_major == 5) { + x = ggml_add_inplace(ctx, x, rwkv_att_v5( + ctx, + x, + layer, + state, + model.head_count, + model.head_size, + model.arch_version_minor + )); + + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + } else { + x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); + + x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); + + } struct rwkv_layer_state & output_state = outputs[i]; diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index fef0ea9..7213590 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -6,6 +6,7 @@ struct rwkv_layer { struct ggml_tensor * att_time_mix_k; struct ggml_tensor * att_time_mix_v; struct ggml_tensor * att_time_mix_r; + // Removed in RWKV v5.2; set to NULL for this and newer models. 
struct ggml_tensor * att_time_first; struct ggml_tensor * att_time_decay; @@ -22,16 +23,34 @@ struct rwkv_layer { struct ggml_tensor * att_time_faaaa; struct ggml_tensor * att_time_mix_g; struct ggml_tensor * att_gate; + + // Added in RWKV v6; + struct ggml_tensor * att_x_maa; + struct ggml_tensor * att_w_maa; + struct ggml_tensor * att_k_maa; + struct ggml_tensor * att_v_maa; + struct ggml_tensor * att_r_maa; + struct ggml_tensor * att_g_maa; + struct ggml_tensor * att_tm_w1; + struct ggml_tensor * att_tm_w2; + struct ggml_tensor * att_td_w1; + struct ggml_tensor * att_td_w2; struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; // FFN. + // v4 v5.x struct ggml_tensor * ffn_time_mix_k; struct ggml_tensor * ffn_time_mix_r; + // v6 + struct ggml_tensor * ffn_time_maa_k; + struct ggml_tensor * ffn_time_maa_r; + struct ggml_tensor * ffn_key; struct ggml_tensor * ffn_value; struct ggml_tensor * ffn_receptance; + }; // The model holds all parameter tensors and the ggml context containing them. From 3ef4086a9f27683e871eff9e7407eb5e1672370e Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Wed, 3 Apr 2024 21:59:20 +0800 Subject: [PATCH 03/11] fix Dv6 --- python/convert_pytorch_to_ggml.py | 5 +---- python/merge_lora_into_ggml.py | 3 --- rwkv_graph.inc | 14 ++++++++++++-- rwkv_model_loading.inc | 4 ++++ 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index d0bf158..6d9154c 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -34,7 +34,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict - is_v6_0: bool = 'blocks.0.att.time_maa' in state_dict + is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict if is_v6_0: print('Detected RWKV v6.0') @@ -67,9 +67,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t tensor = tensor.squeeze() if is_v6_0: - if '.time_first' in k: - tensor = torch.exp(tensor).reshape(-1, 1, 1) - if '.time_faaaa' in k: tensor = tensor.unsqueeze(-1) diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py index b9c6bd2..ac8d517 100644 --- a/python/merge_lora_into_ggml.py +++ b/python/merge_lora_into_ggml.py @@ -109,9 +109,6 @@ def main() -> None: replacement = replacement.squeeze() if arch_version == 'v6.0': - if '.time_first' in key: - replacement = torch.exp(replacement).reshape(-1, 1, 1) - if '.time_faaaa' in key: replacement = replacement.unsqueeze(-1) if arch_version == 'v5.1' or arch_version == 'v5.2': diff --git a/rwkv_graph.inc b/rwkv_graph.inc index db4580d..9d0b03e 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -653,7 +653,7 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state)); x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state)); - + } struct rwkv_layer_state & output_state = outputs[i]; @@ -767,7 +767,17 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c struct rwkv_layer_state state = inputs[i]; - if (model.arch_version_major >= 5) { + if (model.arch_version_major == 6) { + x = ggml_add_inplace(ctx, x, rwkv_att_v6( + ctx, + x, + layer, + state, + model.head_count, + model.head_size, + model.arch_version_minor + )); + } else if (model.arch_version_major == 5) { x = ggml_add_inplace(ctx, x, rwkv_att_v5( ctx, x, 
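Both the Python converter and the C++ loader detect the architecture the same way: by checking which time-mixing tensors exist in the checkpoint. A minimal sketch of that detection, assuming a plain PyTorch state dict (the helper name detect_rwkv_version and the v4 fallback are illustrative, not part of the patch):

from typing import Dict
import torch

def detect_rwkv_version(state_dict: Dict[str, torch.Tensor]) -> str:
    # Same keys convert_pytorch_to_ggml.py checks: v6.0 introduces att.time_maa_x,
    # v5.2 has att.gate.weight, and v5.1/v5.2 both have att.ln_x.weight.
    if 'blocks.0.att.time_maa_x' in state_dict:
        return 'v6.0'
    if 'blocks.0.att.gate.weight' in state_dict:
        return 'v5.2'
    if 'blocks.0.att.ln_x.weight' in state_dict:
        return 'v5.1'
    return 'v4'

The rwkv_model_loading.inc hunk below performs the equivalent check in C++, setting model.arch_version_major = 6 when blocks.0.att.time_maa_x is present.
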
diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index 7213590..a6224f8 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -205,6 +205,10 @@ static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model } } + if (parameters.find("blocks.0.att.time_maa_x") != parameters.end()) { + model.arch_version_major = 6; + } + std::unordered_map & parameters_ref = parameters; RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params( model, From cdb10645dcc4d0f3cbdf6eb5110d17ef29cab519 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Wed, 3 Apr 2024 22:25:57 +0800 Subject: [PATCH 04/11] fix load error --- rwkv_graph.inc | 20 ++++++++-------- rwkv_model_loading.inc | 54 +++++++++++++++++++++++++++--------------- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 9d0b03e..e9147d1 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -330,7 +330,7 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * xxx = ggml_add_inplace( ctx, - ggml_mul(ctx, x_prev, layer.att_x_maa), + ggml_mul(ctx, x_prev, layer.time_maa_x), x ); @@ -338,14 +338,14 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_tanh( ctx, - ggml_mul_mat(ctx,xxx,layer.att_tm_w1) + ggml_mul_mat(ctx,xxx,layer.time_maa_w1) ), 5, -1, 1, sequence_length ); xxx = ggml_reshape_3d( ctx, - ggml_mul_inplace(ctx, xxx, layer.att_tm_w2), + ggml_mul_inplace(ctx, xxx, layer.time_maa_w2), 5, -1, sequence_length ); @@ -353,7 +353,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[0], layer.att_w_maa), + ggml_add_inplace(ctx, xxx->src[0], layer.time_maa_w), x_prev ), x @@ -363,7 +363,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[1], layer.att_k_maa), + ggml_add_inplace(ctx, xxx->src[1], layer.time_maa_k), x_prev ), x @@ -373,7 +373,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[2], layer.att_v_maa), + ggml_add_inplace(ctx, xxx->src[2], layer.time_maa_v), x_prev ), x @@ -383,7 +383,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[3], layer.att_r_maa), + ggml_add_inplace(ctx, xxx->src[3], layer.time_maa_r), x_prev ), x @@ -393,7 +393,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[4], layer.att_g_maa), + ggml_add_inplace(ctx, xxx->src[4], layer.time_maa_g), x_prev ), x @@ -407,9 +407,9 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_tanh_inplace( ctx, - ggml_mul_mat(ctx, xw, layer.att_td_w1) + ggml_mul_mat(ctx, xw, layer.time_decay_w1) ), - layer.att_td_w2 + layer.time_decay_w2 ), layer.att_time_decay ), diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index a6224f8..7a4382f 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -25,16 +25,16 @@ struct rwkv_layer { struct ggml_tensor * att_gate; // Added in RWKV v6; - struct ggml_tensor * att_x_maa; - struct ggml_tensor * att_w_maa; - struct ggml_tensor * att_k_maa; - struct ggml_tensor * att_v_maa; - struct ggml_tensor * att_r_maa; - struct ggml_tensor * att_g_maa; - struct ggml_tensor * att_tm_w1; - struct ggml_tensor * att_tm_w2; - struct ggml_tensor * att_td_w1; - struct ggml_tensor * att_td_w2; + struct ggml_tensor * time_maa_x; + struct ggml_tensor * time_maa_w; + struct ggml_tensor * time_maa_k; + struct ggml_tensor * time_maa_v; + struct ggml_tensor * 
time_maa_r; + struct ggml_tensor * time_maa_g; + struct ggml_tensor * time_maa_w1; + struct ggml_tensor * time_maa_w2; + struct ggml_tensor * time_decay_w1; + struct ggml_tensor * time_decay_w2; struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; @@ -120,9 +120,22 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + if (model.arch_version_major == 6){ + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_x_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_w_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_k_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_v_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_r_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_g_maa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.time_maa_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.time_maa_w2)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.time_decay_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.time_decay_w2)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); + } if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); @@ -140,7 +153,7 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); - if (model.arch_version_minor >= 2) { + if (model.arch_version_minor >= 2 || model.arch_version_major == 6) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); } @@ -149,11 +162,14 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); - 
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + if (model.arch_version_major == 6) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + } } RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); From 85f8031fb72dc64bbdeff12b4bfd331150d603fe Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 01:13:21 +0800 Subject: [PATCH 05/11] a? --- rwkv_graph.inc | 68 ++++++++++++++++++++++++++++-------------- rwkv_model_loading.inc | 44 +++++++++++++-------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index e9147d1..739004f 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -1,5 +1,4 @@ #include "ggml.h" - // View tensors of a state of a single layer. struct rwkv_layer_state { struct ggml_tensor * ffn_xx; @@ -194,7 +193,21 @@ static struct ggml_tensor * rwkv_att_v5( struct ggml_tensor * x_prev; - rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); + //rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); + + x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); + + if (sequence_length > 1) { + x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); + x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0); + x_prev = ggml_set_1d_inplace( + ctx, + x_prev, + ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) + ); + } else { + x_prev = state.att_xx; + } struct ggml_tensor * xk = ggml_add_inplace( ctx, @@ -310,6 +323,13 @@ static struct ggml_tensor * rwkv_att_v5( return ggml_mul_mat(ctx, layer.att_output, x); } +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} static struct ggml_tensor * rwkv_att_v6( struct ggml_context * ctx, struct ggml_tensor * x, @@ -327,33 +347,40 @@ static struct ggml_tensor * rwkv_att_v6( rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); x_prev = ggml_sub_inplace(ctx, x_prev, x); - + + // xxx = x + sx * x_maa struct ggml_tensor * xxx = ggml_add_inplace( - ctx, - ggml_mul(ctx, x_prev, layer.time_maa_x), - x + ctx, + ggml_mul(ctx, x_prev, layer.att_time_maa_x), + x ); - xxx = ggml_reshape_4d( + // xxx = torch.tanh(xxx @ tm_w1).view(T, 5, 1, -1).permute(1,0,2,3) + xxx = ggml_permute( ctx, - ggml_tanh( + ggml_reshape_4d( ctx, - ggml_mul_mat(ctx,xxx,layer.time_maa_w1) + ggml_tanh( + ctx, + ggml_mul_mat(ctx, xxx, 
layer.att_time_maa_w1) + ), + sequence_length, 5, 1, layer.att_time_maa_w2->ne[1] ), - 5, -1, 1, sequence_length + 1, 0, 2, 3 ); + // xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) xxx = ggml_reshape_3d( ctx, - ggml_mul_inplace(ctx, xxx, layer.time_maa_w2), - 5, -1, sequence_length + ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), + sequence_length, 5, n_embed ); struct ggml_tensor * xw = ggml_add_inplace( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[0], layer.time_maa_w), + ggml_add_inplace(ctx, xxx->src[0], layer.att_time_maa_w), x_prev ), x @@ -363,7 +390,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[1], layer.time_maa_k), + ggml_add_inplace(ctx, xxx->src[1], layer.att_time_maa_k), x_prev ), x @@ -373,7 +400,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[2], layer.time_maa_v), + ggml_add_inplace(ctx, xxx->src[2], layer.att_time_maa_v), x_prev ), x @@ -383,7 +410,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[3], layer.time_maa_r), + ggml_add_inplace(ctx, xxx->src[3], layer.att_time_maa_r), x_prev ), x @@ -393,7 +420,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[4], layer.time_maa_g), + ggml_add_inplace(ctx, xxx->src[4], layer.att_time_maa_g), x_prev ), x @@ -407,9 +434,9 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_tanh_inplace( ctx, - ggml_mul_mat(ctx, xw, layer.time_decay_w1) + ggml_mul_mat(ctx, xw, layer.att_time_decay_w1) ), - layer.time_decay_w2 + layer.att_time_decay_w2 ), layer.att_time_decay ), @@ -427,7 +454,6 @@ static struct ggml_tensor * rwkv_att_v6( struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); - struct ggml_tensor * g = ggml_silu_inplace( ctx, ggml_mul_mat(ctx, layer.att_gate, xg) @@ -543,7 +569,6 @@ static void rwkv_create_input_and_output_views( const int64_t head_size ) { size_t sz_float = sizeof(float); - for (size_t i = 0; i < n_layer; i++) { struct rwkv_layer_state & input_state = inputs[i]; struct rwkv_layer_state & output_state = outputs[i]; @@ -582,7 +607,6 @@ static void rwkv_create_input_and_output_views( // Creates and sets the input and output ggml tensors, builds the computation graph. 
static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_computation_graph & graph) { graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph()); - struct rwkv_file_header & header = model.header; const size_t n_vocab = header.n_vocab; const size_t n_embed = header.n_embed; diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index 7a4382f..a226efa 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -25,16 +25,16 @@ struct rwkv_layer { struct ggml_tensor * att_gate; // Added in RWKV v6; - struct ggml_tensor * time_maa_x; - struct ggml_tensor * time_maa_w; - struct ggml_tensor * time_maa_k; - struct ggml_tensor * time_maa_v; - struct ggml_tensor * time_maa_r; - struct ggml_tensor * time_maa_g; - struct ggml_tensor * time_maa_w1; - struct ggml_tensor * time_maa_w2; - struct ggml_tensor * time_decay_w1; - struct ggml_tensor * time_decay_w2; + struct ggml_tensor * att_time_maa_x; + struct ggml_tensor * att_time_maa_w; + struct ggml_tensor * att_time_maa_k; + struct ggml_tensor * att_time_maa_v; + struct ggml_tensor * att_time_maa_r; + struct ggml_tensor * att_time_maa_g; + struct ggml_tensor * att_time_maa_w1; + struct ggml_tensor * att_time_maa_w2; + struct ggml_tensor * att_time_decay_w1; + struct ggml_tensor * att_time_decay_w2; struct ggml_tensor * ln2_weight; struct ggml_tensor * ln2_bias; @@ -121,16 +121,16 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); if (model.arch_version_major == 6){ - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_x_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_w_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_k_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_v_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_r_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_g_maa)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.time_maa_w1)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.time_maa_w2)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.time_decay_w1)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.time_decay_w2)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_time_maa_x)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_time_maa_w)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_time_maa_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_v"), buffer), layer.att_time_maa_v)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_r"), buffer), layer.att_time_maa_r)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_g"), buffer), layer.att_time_maa_g)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w1"), buffer), layer.att_time_maa_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.att_time_maa_w2)); + 
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.att_time_decay_w1)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2)); } else { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); @@ -149,11 +149,11 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); - if (model.arch_version_major >= 5) { + if (model.arch_version_major == 5) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); - if (model.arch_version_minor >= 2 || model.arch_version_major == 6) { + if (model.arch_version_minor >= 2) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); } From d56d64f920ecde78d12cc6f6e9d469037cc56292 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 03:07:16 +0800 Subject: [PATCH 06/11] load --- rwkv_graph.inc | 9 ++---- rwkv_model_loading.inc | 67 +++++++++++++++++++++--------------------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 739004f..565ff80 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -64,18 +64,15 @@ static void rwkv_carry_x( const size_t n_embed = x->ne[0]; const size_t sequence_len = x->ne[1]; + // self.layer_norm(x, self.w.blocks[i].ln2) + x = rwkv_layer_norm(ctx, x, weight, bias); if (sequence_len == 1) { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, weight, bias); - // xx = state[5*i+0] x_prev = carry; // state[5*i+0] = x carry = x; } else { - // self.layer_norm(x, self.w.blocks[i].ln2) - x = rwkv_layer_norm(ctx, x, weight, bias); // xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_len); @@ -193,8 +190,6 @@ static struct ggml_tensor * rwkv_att_v5( struct ggml_tensor * x_prev; - //rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); - x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); if (sequence_length > 1) { diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index a226efa..763d096 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -120,7 +120,21 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.weight"), buffer), layer.ln1_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln1.bias"), buffer), layer.ln1_bias)); - if (model.arch_version_major == 6){ + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], 
"att.receptance.weight"), buffer), layer.att_receptance)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); + + if (model.arch_version_major == 6) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_x"), buffer), layer.att_time_maa_x)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w"), buffer), layer.att_time_maa_w)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_k"), buffer), layer.att_time_maa_k)); @@ -131,47 +145,34 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_maa_w2"), buffer), layer.att_time_maa_w2)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w1"), buffer), layer.att_time_decay_w1)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), layer.att_time_decay_w2)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); + + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r)); } else { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_k"), buffer), layer.att_time_mix_k)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_v"), buffer), layer.att_time_mix_v)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_r"), buffer), layer.att_time_mix_r)); - } - if (model.arch_version_major >= 5 && model.arch_version_minor >= 2) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); - } else { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); - } - - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay"), buffer), layer.att_time_decay)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.key.weight"), buffer), layer.att_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.value.weight"), buffer), layer.att_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.receptance.weight"), buffer), layer.att_receptance)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.output.weight"), buffer), layer.att_output)); - - if (model.arch_version_major == 5) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); - - if (model.arch_version_minor >= 2) { - 
RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_k"), buffer), layer.ffn_time_mix_k)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_mix_r"), buffer), layer.ffn_time_mix_r)); + if (model.arch_version_major == 5) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); + if (model.arch_version_minor >= 2) { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_faaaa"), buffer), layer.att_time_faaaa)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_mix_g"), buffer), layer.att_time_mix_g)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); + } + } else { + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_first"), buffer), layer.att_time_first)); } } - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.weight"), buffer), layer.ln2_weight)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ln2.bias"), buffer), layer.ln2_bias)); - - if (model.arch_version_major == 6) { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r)); - } else { - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.key.weight"), buffer), layer.ffn_key)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.value.weight"), buffer), layer.ffn_value)); - RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.receptance.weight"), buffer), layer.ffn_receptance)); - } } - RWKV_ENSURE_OR_FALSE(callback("ln_out.weight", model.ln_out_weight)); RWKV_ENSURE_OR_FALSE(callback("ln_out.bias", model.ln_out_bias)); RWKV_ENSURE_OR_FALSE(callback("head.weight", model.head)); From 1abd4e83b9f81e999834c761c8ab805af770929d Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 18:22:06 +0800 Subject: [PATCH 07/11] v6 but a? --- python/convert_pytorch_to_ggml.py | 8 ++ rwkv_graph.inc | 117 +++++++++++++++++------------- rwkv_model_loading.inc | 1 + 3 files changed, 75 insertions(+), 51 deletions(-) diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 6d9154c..a1af9ee 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -69,6 +69,14 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if is_v6_0: if '.time_faaaa' in k: tensor = tensor.unsqueeze(-1) + if '.time_maa_w1' in k or '.time_decay_w' in k: + tensor = tensor.transpose(0,1) + tensor.contiguous() + + if '.time_maa_w2' in k: + # (5, 32, 2048) -> (32, 2048, 5) + tensor = tensor.permute(0,2,1) + tensor.contiguous() elif is_v5_1_or_2: if '.time_decay' in k: diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 565ff80..aebf4fc 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -1,4 +1,5 @@ #include "ggml.h" +#include // View tensors of a state of a single layer. 
struct rwkv_layer_state { struct ggml_tensor * ffn_xx; @@ -69,7 +70,6 @@ static void rwkv_carry_x( if (sequence_len == 1) { // xx = state[5*i+0] x_prev = carry; - // state[5*i+0] = x carry = x; } else { @@ -190,19 +190,7 @@ static struct ggml_tensor * rwkv_att_v5( struct ggml_tensor * x_prev; - x = rwkv_layer_norm(ctx, x, layer.ln1_weight, layer.ln1_bias); - - if (sequence_length > 1) { - x_prev = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embed, sequence_length); - x_prev = ggml_set_1d_inplace(ctx, x_prev, state.att_xx, 0); - x_prev = ggml_set_1d_inplace( - ctx, - x_prev, - ggml_view_1d(ctx, x, n_embed * (sequence_length - 1), 0), n_embed * sizeof(float) - ); - } else { - x_prev = state.att_xx; - } + rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); struct ggml_tensor * xk = ggml_add_inplace( ctx, @@ -248,7 +236,6 @@ static struct ggml_tensor * rwkv_att_v5( ); } - state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); @@ -325,6 +312,19 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable (t1->ne[3]%t0->ne[3] == 0); } + + +static inline void print_shape(const char * echo,const struct ggml_tensor * t){ + printf(echo); + printf(" (%ld, %ld, %ld, %ld)\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); +} + +static inline void pppp(const char * echo){ + printf(echo); + printf("\n"); +} + + static struct ggml_tensor * rwkv_att_v6( struct ggml_context * ctx, struct ggml_tensor * x, @@ -341,41 +341,56 @@ static struct ggml_tensor * rwkv_att_v6( rwkv_carry_x(ctx, layer.ln1_weight, layer.ln1_bias, x, x_prev, state.att_xx); + // sx = state[i1] - x x_prev = ggml_sub_inplace(ctx, x_prev, x); - // xxx = x + sx * x_maa + // xxx = x + sx * x_maa (C, T) struct ggml_tensor * xxx = ggml_add_inplace( ctx, ggml_mul(ctx, x_prev, layer.att_time_maa_x), x ); - // xxx = torch.tanh(xxx @ tm_w1).view(T, 5, 1, -1).permute(1,0,2,3) - xxx = ggml_permute( + // xxx (2048, T) @ att_time_maa_w1 (2048, 160) -> (T, 160) -> (T, 5, 32, 1) -> xxx (32, 1, 5, T) + xxx = ggml_reshape_4d( ctx, - ggml_reshape_4d( + ggml_tanh( ctx, - ggml_tanh( - ctx, - ggml_mul_mat(ctx, xxx, layer.att_time_maa_w1) - ), - sequence_length, 5, 1, layer.att_time_maa_w2->ne[1] + ggml_mul_mat(ctx, xxx, layer.att_time_maa_w1) ), - 1, 0, 2, 3 + sequence_length, 5, layer.att_time_maa_w2->ne[0], 1 + ); + + xxx = ggml_permute( + ctx, + xxx, + 3, 2, 0, 1 ); // xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + // xxx (32, 1, 5, T) @ att_time_maa_w2 (32, 2048, 5) -> (1, 2048, 5, T) -> (5, 2048, T, 1) -> xxx (5, 2048, T) xxx = ggml_reshape_3d( ctx, - ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), - sequence_length, 5, n_embed + ggml_cont( + ctx, + ggml_permute(ctx,ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), 3, 1, 0, 2) + ), + 5, layer.att_time_maa_w2->ne[1], sequence_length ); + // mw (2048, T) * att_time_maa_w (2048) + sx (2048, T) -> xw (2048, T) + // !!!! 
检查是否正确,因为转置 + struct ggml_tensor * mw = ggml_view_1d(ctx, xxx, n_embed, n_embed * 0 * sizeof(float)); + struct ggml_tensor * mk = ggml_view_1d(ctx, xxx, n_embed, n_embed * 1 * sizeof(float)); + struct ggml_tensor * mv = ggml_view_1d(ctx, xxx, n_embed, n_embed * 2 * sizeof(float)); + struct ggml_tensor * mr = ggml_view_1d(ctx, xxx, n_embed, n_embed * 3 * sizeof(float)); + struct ggml_tensor * mg = ggml_view_1d(ctx, xxx, n_embed, n_embed * 4 * sizeof(float)); + struct ggml_tensor * xw = ggml_add_inplace( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[0], layer.att_time_maa_w), + ggml_add_inplace(ctx, mw, layer.att_time_maa_w), x_prev ), x @@ -385,7 +400,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[1], layer.att_time_maa_k), + ggml_add_inplace(ctx, mk, layer.att_time_maa_k), x_prev ), x @@ -395,7 +410,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[2], layer.att_time_maa_v), + ggml_add_inplace(ctx, mw, layer.att_time_maa_v), x_prev ), x @@ -405,7 +420,7 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[3], layer.att_time_maa_r), + ggml_add_inplace(ctx, mr, layer.att_time_maa_r), x_prev ), x @@ -415,37 +430,37 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_inplace( ctx, - ggml_add_inplace(ctx, xxx->src[4], layer.att_time_maa_g), + ggml_add_inplace(ctx, mg, layer.att_time_maa_g), x_prev ), x ); - struct ggml_tensor * time_decay_w = ggml_reshape_4d( + + // w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2)).view(H, S, 1) + // att_time_decay_w1 (2048, 64) @ xw (2048, T) -> (64, T) + // att_time_decay_w2 (64, 2048) @ (64, T) -> (2048, T) -> w (1, S, H, T) <=> (1, S, H, T) att_heads + struct ggml_tensor * w = ggml_add_inplace( ctx, - ggml_add_inplace( + ggml_mul_mat( ctx, - ggml_mul_mat( + layer.att_time_decay_w2, + ggml_tanh_inplace( ctx, - ggml_tanh_inplace( - ctx, - ggml_mul_mat(ctx, xw, layer.att_time_decay_w1) - ), - layer.att_time_decay_w2 - ), - layer.att_time_decay + ggml_mul_mat(ctx,layer.att_time_decay_w1, xw) + ) ), - 1, head_size, head_count, sequence_length - ); - - time_decay_w = rwkv_exp(ctx, - rwkv_1_minus_x(ctx, - rwkv_exp(ctx, time_decay_w) - ) + layer.att_time_decay ); + w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length); - state.att_xx = ggml_view_1d(ctx, x, n_embed, n_embed * (sequence_length - 1) * sizeof(float)); + // w = torch.exp(-torch.exp(w)) + w = rwkv_exp(ctx,rwkv_1_minus_x(ctx,rwkv_exp(ctx, w))); + // r = (rw @ xr).view(H, 1, S) + // k = (kw @ xk).view(H, S, 1) + // v = (vw @ xv).view(H, 1, S) + // g = F.silu(gw @ xg) struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_receptance, xr), head_size, 1, head_count, sequence_length); struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_key, xk), 1, head_size, head_count, sequence_length); struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer.att_value, xv), head_size, 1, head_count, sequence_length); @@ -468,7 +483,7 @@ static struct ggml_tensor * rwkv_att_v6( v, r, layer.att_time_faaaa, - time_decay_w, + w, state_out ); diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc index 763d096..5a5aaca 100644 --- a/rwkv_model_loading.inc +++ b/rwkv_model_loading.inc @@ -147,6 +147,7 @@ static bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.time_decay_w2"), buffer), 
layer.att_time_decay_w2)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.weight"), buffer), layer.att_ln_x_weight)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.ln_x.bias"), buffer), layer.att_ln_x_bias)); + RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "att.gate.weight"), buffer), layer.att_gate)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_k"), buffer), layer.ffn_time_maa_k)); RWKV_ENSURE_OR_FALSE(callback((strcpy(&buffer[offset], "ffn.time_maa_r"), buffer), layer.ffn_time_maa_r)); From e2d71430482a2240d263a31832779bdb58f6c483 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 19:55:23 +0800 Subject: [PATCH 08/11] =?UTF-8?q?l368=20=E9=99=84=E8=BF=91=E8=B0=81?= =?UTF-8?q?=E6=9D=A5=E4=BF=AE=E4=B8=80=E4=B8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rwkv_graph.inc | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index aebf4fc..b6a53df 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -361,20 +361,15 @@ static struct ggml_tensor * rwkv_att_v6( sequence_length, 5, layer.att_time_maa_w2->ne[0], 1 ); - xxx = ggml_permute( - ctx, - xxx, - 3, 2, 0, 1 - ); + xxx = ggml_cont_inplace(ctx, ggml_permute(ctx, xxx, 3, 2, 0, 1)); // !!!!! // xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) // xxx (32, 1, 5, T) @ att_time_maa_w2 (32, 2048, 5) -> (1, 2048, 5, T) -> (5, 2048, T, 1) -> xxx (5, 2048, T) + xxx = ggml_permute(ctx, ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), 3, 1, 0, 2); // !!!!! + xxx = ggml_cont_inplace(ctx, xxx); // !!!!!! xxx = ggml_reshape_3d( ctx, - ggml_cont( - ctx, - ggml_permute(ctx,ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), 3, 1, 0, 2) - ), + xxx, 5, layer.att_time_maa_w2->ne[1], sequence_length ); From 2912032451894b9a15f6fea0044962feef37c497 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 20:00:24 +0800 Subject: [PATCH 09/11] lorainto --- python/merge_lora_into_ggml.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py index ac8d517..b9fbf0f 100644 --- a/python/merge_lora_into_ggml.py +++ b/python/merge_lora_into_ggml.py @@ -110,7 +110,16 @@ def main() -> None: if arch_version == 'v6.0': if '.time_faaaa' in key: - replacement = replacement.unsqueeze(-1) + tensor = tensor.unsqueeze(-1) + if '.time_maa_w1' in key or '.time_decay_w' in key: + tensor = tensor.transpose(0,1) + tensor.contiguous() + + if '.time_maa_w2' in key: + # (5, 32, 2048) -> (32, 2048, 5) + tensor = tensor.permute(0,2,1) + tensor.contiguous() + if arch_version == 'v5.1' or arch_version == 'v5.2': if '.time_decay' in key: if arch_version == 'v5.2': From 08db1e991e4b3d10d365908be57676ec20260aab Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Thu, 4 Apr 2024 20:05:14 +0800 Subject: [PATCH 10/11] =?UTF-8?q?=E6=B8=85=E9=99=A4=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rwkv_graph.inc | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index b6a53df..1528958 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -1,5 +1,3 @@ -#include "ggml.h" -#include // View tensors of a state of a single layer. 
struct rwkv_layer_state { struct ggml_tensor * ffn_xx; @@ -305,26 +303,6 @@ static struct ggml_tensor * rwkv_att_v5( return ggml_mul_mat(ctx, layer.att_output, x); } -static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[0] == t1->ne[0]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); -} - - -static inline void print_shape(const char * echo,const struct ggml_tensor * t){ - printf(echo); - printf(" (%ld, %ld, %ld, %ld)\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); -} - -static inline void pppp(const char * echo){ - printf(echo); - printf("\n"); -} - - static struct ggml_tensor * rwkv_att_v6( struct ggml_context * ctx, struct ggml_tensor * x, From 69ffeb3cdcf547bc9ff06386272397e9bebf5128 Mon Sep 17 00:00:00 2001 From: YuChuXi Date: Wed, 29 May 2024 18:32:36 +0800 Subject: [PATCH 11/11] =?UTF-8?q?=E4=B8=BA=E4=BB=80=E4=B9=88=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E9=82=A3=E4=B9=88=E5=B7=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rwkv_graph.inc | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 1528958..2b12746 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -1,3 +1,7 @@ + +#include "ggml.h" +#include +#include // View tensors of a state of a single layer. struct rwkv_layer_state { struct ggml_tensor * ffn_xx; @@ -328,36 +332,39 @@ static struct ggml_tensor * rwkv_att_v6( ggml_mul(ctx, x_prev, layer.att_time_maa_x), x ); + + // w1 (2048, M), w2 (N, 5, 2048) + // xxx (2048, T) (M, T) (M//5 N, 1, 5, T) // xxx = torch.tanh(xxx @ tm_w1).view(T, 5, 1, -1).permute(1,0,2,3) // xxx (2048, T) @ att_time_maa_w1 (2048, 160) -> (T, 160) -> (T, 5, 32, 1) -> xxx (32, 1, 5, T) xxx = ggml_reshape_4d( ctx, - ggml_tanh( + ggml_tanh_inplace( ctx, - ggml_mul_mat(ctx, xxx, layer.att_time_maa_w1) + ggml_mul_mat(ctx, layer.att_time_maa_w1, xxx) ), - sequence_length, 5, layer.att_time_maa_w2->ne[0], 1 - ); + layer.att_time_maa_w2->ne[0], 1, 5, sequence_length + ); // (N, 1, 5, T) - xxx = ggml_cont_inplace(ctx, ggml_permute(ctx, xxx, 3, 2, 0, 1)); // !!!!! + // xxx = ggml_cont_inplace(ctx, ggml_permute(ctx, xxx, 3, 2, 0, 1)); // !!!!! // xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) - // xxx (32, 1, 5, T) @ att_time_maa_w2 (32, 2048, 5) -> (1, 2048, 5, T) -> (5, 2048, T, 1) -> xxx (5, 2048, T) - xxx = ggml_permute(ctx, ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2), 3, 1, 0, 2); // !!!!! - xxx = ggml_cont_inplace(ctx, xxx); // !!!!!! + // xxx (32, 1, 5, T) @ att_time_maa_w2 (32, 2048, 5) -> (1, 2048, 5, T) -> xxx (2048, 5, T) + xxx = ggml_mul_mat(ctx, xxx, layer.att_time_maa_w2); xxx = ggml_reshape_3d( ctx, xxx, - 5, layer.att_time_maa_w2->ne[1], sequence_length + layer.att_time_maa_w2->ne[1], 5, sequence_length ); + xxx = ggml_permute(ctx, xxx, 0, 1, 3, 2); // mw (2048, T) * att_time_maa_w (2048) + sx (2048, T) -> xw (2048, T) // !!!! 
检查是否正确,因为转置 - struct ggml_tensor * mw = ggml_view_1d(ctx, xxx, n_embed, n_embed * 0 * sizeof(float)); - struct ggml_tensor * mk = ggml_view_1d(ctx, xxx, n_embed, n_embed * 1 * sizeof(float)); - struct ggml_tensor * mv = ggml_view_1d(ctx, xxx, n_embed, n_embed * 2 * sizeof(float)); - struct ggml_tensor * mr = ggml_view_1d(ctx, xxx, n_embed, n_embed * 3 * sizeof(float)); - struct ggml_tensor * mg = ggml_view_1d(ctx, xxx, n_embed, n_embed * 4 * sizeof(float)); + struct ggml_tensor * mw = ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 0)); + struct ggml_tensor * mk = ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 1)); + struct ggml_tensor * mv = ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 2)); + struct ggml_tensor * mr = ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 3)); + struct ggml_tensor * mg = ggml_get_rows(ctx, xxx, ggml_new_i32(ctx, 4)); struct ggml_tensor * xw = ggml_add_inplace( ctx, @@ -409,7 +416,6 @@ static struct ggml_tensor * rwkv_att_v6( x ); - // w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2)).view(H, S, 1) // att_time_decay_w1 (2048, 64) @ xw (2048, T) -> (64, T) // att_time_decay_w2 (64, 2048) @ (64, T) -> (2048, T) -> w (1, S, H, T) <=> (1, S, H, T) att_heads @@ -428,7 +434,7 @@ static struct ggml_tensor * rwkv_att_v6( w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, sequence_length); // w = torch.exp(-torch.exp(w)) - w = rwkv_exp(ctx,rwkv_1_minus_x(ctx,rwkv_exp(ctx, w))); + w = rwkv_exp(ctx,ggml_neg_inplace(ctx,rwkv_exp(ctx, w))); // r = (rw @ xr).view(H, 1, S) // k = (kw @ xk).view(H, S, 1)
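
The rwkv_att_v6 hunks above follow the PyTorch expressions quoted in their comments. As a reference for the shapes involved, here is a rough, self-contained sketch of the v6 data-dependent token shift and decay; parameter names follow the checkpoint keys, but the dictionary p, the function name, and the exact shapes (x: (T, C), time_maa_w1: (C, 5*K), time_maa_w2: (5, K, C), time_decay_w1: (C, D), time_decay_w2: (D, C)) are assumptions for illustration, not the rwkv.cpp implementation:

import torch

def rwkv_v6_token_shift(x: torch.Tensor, last_x: torch.Tensor, p: dict):
    # x: (T, C) current chunk, already layer-normed (ln1 omitted here);
    # last_x: (1, C) last token of the previous chunk (the att_xx state).
    T = x.shape[0]
    sx = torch.cat((last_x, x[:-1])) - x                           # shifted input minus x

    # xxx = x + sx * time_maa_x, then two small matmuls produce five per-token mix vectors.
    xxx = x + sx * p['time_maa_x']                                 # (T, C)
    xxx = torch.tanh(xxx @ p['time_maa_w1']).view(T, 5, -1)        # (T, 5, K)
    xxx = torch.bmm(xxx.transpose(0, 1), p['time_maa_w2'])         # (5, T, C)
    mw, mk, mv, mr, mg = xxx.unbind(dim=0)

    xw = x + sx * (p['time_maa_w'] + mw)
    xk = x + sx * (p['time_maa_k'] + mk)
    xv = x + sx * (p['time_maa_v'] + mv)
    xr = x + sx * (p['time_maa_r'] + mr)
    xg = x + sx * (p['time_maa_g'] + mg)

    # w = torch.exp(-torch.exp(time_decay + tanh(xw @ td_w1) @ td_w2)), reshaped per head downstream.
    w = p['time_decay'] + torch.tanh(xw @ p['time_decay_w1']) @ p['time_decay_w2']
    w = torch.exp(-torch.exp(w))                                   # (T, C)
    return xw, xk, xv, xr, xg, w

The five unbind outputs are what the graph slices out of xxx, first with ggml_view_1d (PATCH 07) and then with ggml_get_rows (PATCH 11); the converter's transpose of time_maa_w1 and time_decay_w1/w2 and the permute of time_maa_w2 in PATCH 07 presumably exist to lay these matrices out for ggml_mul_mat, which contracts over the innermost dimension of both operands.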