Skip to content

Commit

Permalink
refactoring to model version detection
Browse files Browse the repository at this point in the history
  • Loading branch information
schamane committed Nov 29, 2023
1 parent 31a0834 commit f14e2ce
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions python/convert_pytorch_to_ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,21 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
n_vocab: int = emb_weight.shape[0]
n_embed: int = emb_weight.shape[1]

is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict

if is_v5_2:
print('Detected RWKV v5.2')
elif is_v5_1_or_2:
print('Detected RWKV v5.1')
else:
print('Detected RWKV v4')
version = 4
keys = list(state_dict.keys())
for k in keys:
if 'ln_x' in k:
version = max(5, version)
if 'gate.weight' in k:
version = max(5.1, version)
if int(version) == 5 and 'att.time_decay' in k:
if len(state_dict[k].shape) > 1:
if (state_dict[k].shape[1]) > 1:
version = max(5.2, version)
if "time_maa" in k:
version = max(6, version)

print(f'Model detected v{version:.1f}')

with open(dest_path, 'wb') as out_file:
is_FP16: bool = data_type == 'FP16' or data_type == 'float16'
Expand All @@ -57,15 +63,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
1 if is_FP16 else 0
))

for k in state_dict.keys():
keys = list(state_dict.keys())
for k in keys:
tensor: torch.Tensor = state_dict[k].float()

if '.time_' in k:
tensor = tensor.squeeze()

if is_v5_1_or_2:
if int(version) == 5:
if '.time_decay' in k:
if is_v5_2:
if version == 5.2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
else:
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)
Expand Down

0 comments on commit f14e2ce

Please sign in to comment.