Fix LoRA merge script
saharNooby committed May 6, 2023
1 parent 1fef5d1 commit 1d72fa2
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions rwkv/merge_lora_into_ggml.py
@@ -113,29 +113,32 @@ def main() -> None:
 
             del lora_state_dict[key]
 
-        lora_A_key: str = key.replace('.weight', '') + '.lora_A.weight'
-        lora_B_key: str = key.replace('.weight', '') + '.lora_B.weight'
+        for suffix in ['.weight', '']:
+            lora_A_key: str = key.replace('.weight', '') + '.lora_A' + suffix
+            lora_B_key: str = key.replace('.weight', '') + '.lora_B' + suffix
 
-        if lora_A_key in lora_state_dict:
-            lora_A: torch.Tensor = lora_state_dict[lora_A_key]
-            lora_B: torch.Tensor = lora_state_dict[lora_B_key]
+            if lora_A_key in lora_state_dict:
+                lora_A: torch.Tensor = lora_state_dict[lora_A_key]
+                lora_B: torch.Tensor = lora_state_dict[lora_B_key]
 
-            assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \
-                f'{lora_A.shape}, {lora_B.shape}'
+                assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \
+                    f'{lora_A.shape}, {lora_B.shape}'
 
-            lora_R: int = lora_B.shape[1]
+                lora_R: int = lora_B.shape[1]
 
-            replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R)
+                replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R)
 
-            if parameter.dtype == torch.float16:
-                replacement = replacement.half()
+                if parameter.dtype == torch.float16:
+                    replacement = replacement.half()
 
-            parameter = replacement
+                parameter = replacement
 
-            print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}')
+                print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}')
 
-            del lora_state_dict[lora_A_key]
-            del lora_state_dict[lora_B_key]
+                del lora_state_dict[lora_A_key]
+                del lora_state_dict[lora_B_key]
+
+                break
 
         write_parameter(out_file, key, parameter)
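The change wraps the key lookup in a loop over two suffixes, so the script now accepts LoRA checkpoints that name their adapter tensors either 'xxx.lora_A.weight' or 'xxx.lora_A' (presumably because both naming conventions occur in practice), breaking out of the loop after the first match. The merge math itself is unchanged: W' = W + B @ A * (alpha / r). Below is a minimal standalone sketch of that merge, with hypothetical tensor names and a made-up lora_alpha value, assuming the same shape conventions as the script (lora_B is out x r, lora_A is r x in):

import torch

def merge_lora(parameter: torch.Tensor,
               lora_A: torch.Tensor,
               lora_B: torch.Tensor,
               lora_alpha: float) -> torch.Tensor:
    # Rank r is the inner dimension shared by lora_B (out x r) and lora_A (r x in).
    assert lora_B.shape[1] == lora_A.shape[0]
    lora_r: int = lora_B.shape[1]

    # Standard LoRA merge: W' = W + B @ A * (alpha / r).
    merged = parameter + lora_B @ lora_A * (lora_alpha / lora_r)

    # Preserve the stored dtype, as the script does for float16 models.
    return merged.half() if parameter.dtype == torch.float16 else merged

# Toy usage: a 4x6 weight with a rank-2 adapter.
W = torch.randn(4, 6)
A = torch.randn(2, 6)  # lora_A
B = torch.randn(4, 2)  # lora_B
W_merged = merge_lora(W, A, B, lora_alpha=32.0)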
