Commit 3ef4086

fix Dv6
YuChuXi committed Apr 3, 2024
1 parent 6ba24eb commit 3ef4086
Showing 4 changed files with 17 additions and 9 deletions.
5 changes: 1 addition & 4 deletions python/convert_pytorch_to_ggml.py
@@ -34,7 +34,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

    is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
    is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
-   is_v6_0: bool = 'blocks.0.att.time_maa' in state_dict
+   is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

    if is_v6_0:
        print('Detected RWKV v6.0')
@@ -67,9 +67,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

        tensor = tensor.squeeze()

        if is_v6_0:
-           if '.time_first' in k:
-               tensor = torch.exp(tensor).reshape(-1, 1, 1)
-
            if '.time_faaaa' in k:
                tensor = tensor.unsqueeze(-1)

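For context, the v6.0 handling that remains after this change is small: the checkpoint is recognised by the exact key 'blocks.0.att.time_maa_x', and time_faaaa gets a trailing singleton dimension; the exp()/reshape step for '.time_first' is dropped, presumably because v6.0 checkpoints carry time_faaaa rather than time_first. A minimal Python sketch of that tensor handling follows; the key names come from the diff above, while the helper name and the '.time_' squeeze guard are illustrative assumptions, not the script's exact structure.

import torch


def convert_v6_tensor(key: str, tensor: torch.Tensor) -> torch.Tensor:
    # Assumed standalone helper mirroring the converter's v6.0 branch after this commit.
    if '.time_' in key:
        # Time-mixing parameters are flattened before any reshaping (guard assumed).
        tensor = tensor.squeeze()
    if '.time_faaaa' in key:
        # time_faaaa gains a trailing singleton dimension; the old exp()/reshape
        # applied to '.time_first' no longer runs.
        tensor = tensor.unsqueeze(-1)
    return tensor


# Example: a flat time_faaaa vector of length 64 becomes shape (64, 1).
faaaa = torch.zeros(64)
print(convert_v6_tensor('blocks.0.att.time_faaaa', faaaa).shape)  # torch.Size([64, 1])
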
3 changes: 0 additions & 3 deletions python/merge_lora_into_ggml.py
@@ -109,9 +109,6 @@ def main() -> None:

        replacement = replacement.squeeze()

        if arch_version == 'v6.0':
-           if '.time_first' in key:
-               replacement = torch.exp(replacement).reshape(-1, 1, 1)
-
            if '.time_faaaa' in key:
                replacement = replacement.unsqueeze(-1)
        if arch_version == 'v5.1' or arch_version == 'v5.2':
14 changes: 12 additions & 2 deletions rwkv_graph.inc
@@ -653,7 +653,7 @@ static bool rwkv_build_serial_graph(struct rwkv_model & model, struct rwkv_compu
x = ggml_add_inplace(ctx, x, rwkv_att(ctx, x, layer, state));

x = ggml_add_inplace(ctx, x, rwkv_ffn(ctx, x, layer, state));

}

struct rwkv_layer_state & output_state = outputs[i];
@@ -767,7 +767,17 @@ static bool rwkv_build_sequential_graph(struct rwkv_model & model, struct rwkv_c

        struct rwkv_layer_state state = inputs[i];

-       if (model.arch_version_major >= 5) {
+       if (model.arch_version_major == 6) {
+           x = ggml_add_inplace(ctx, x, rwkv_att_v6(
+               ctx,
+               x,
+               layer,
+               state,
+               model.head_count,
+               model.head_size,
+               model.arch_version_minor
+           ));
+       } else if (model.arch_version_major == 5) {
            x = ggml_add_inplace(ctx, x, rwkv_att_v5(
                ctx,
                x,
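
The substantive change in this hunk is that the sequential graph builder now matches the major architecture version exactly instead of testing '>= 5', so a model reporting arch_version_major 6 is routed to rwkv_att_v6 instead of falling through to the v5 attention. A toy Python sketch of that dispatch pattern, using stand-in callables rather than the library's real functions:

from typing import Callable


def att_v4(x: float) -> float:  # stand-in for rwkv_att
    return x + 4.0


def att_v5(x: float) -> float:  # stand-in for rwkv_att_v5
    return x + 5.0


def att_v6(x: float) -> float:  # stand-in for rwkv_att_v6
    return x + 6.0


def pick_attention(arch_version_major: int) -> Callable[[float], float]:
    # Exact match first: a bare '>= 5' test would send v6 models down the v5 path.
    if arch_version_major == 6:
        return att_v6
    if arch_version_major == 5:
        return att_v5
    return att_v4


print(pick_attention(6)(0.0))  # 6.0, i.e. the v6 branch is taken
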
4 changes: 4 additions & 0 deletions rwkv_model_loading.inc
@@ -205,6 +205,10 @@ static bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model
        }
    }

+   if (parameters.find("blocks.0.att.time_maa_x") != parameters.end()) {
+       model.arch_version_major = 6;
+   }
+
    std::unordered_map<std::string, struct ggml_tensor *> & parameters_ref = parameters;
    RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(
        model,
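
The loader-side probe added here mirrors the key-based detection used elsewhere in this commit: the presence of blocks.0.att.time_maa_x marks a v6.0 model, just as the converter distinguishes v5.1/v5.2 checkpoints by blocks.0.att.ln_x.weight and blocks.0.att.gate.weight. A rough Python sketch of that heuristic, for illustration only (the function name and the returned labels are assumptions, not part of the project):

from typing import Any, Dict


def detect_arch_version(state_dict: Dict[str, Any]) -> str:
    # Probe for version-specific parameter names, newest architecture first.
    if 'blocks.0.att.time_maa_x' in state_dict:
        return 'v6.0'
    if 'blocks.0.att.gate.weight' in state_dict:
        return 'v5.2'
    if 'blocks.0.att.ln_x.weight' in state_dict:
        return 'v5.1'
    return 'v4'  # assumed fallback for older checkpoints


print(detect_arch_version({'blocks.0.att.time_maa_x': None}))  # v6.0
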
