diff --git a/rwkv/merge_lora_into_ggml.py b/rwkv/merge_lora_into_ggml.py index e5c7d3a..e7d9523 100644 --- a/rwkv/merge_lora_into_ggml.py +++ b/rwkv/merge_lora_into_ggml.py @@ -113,29 +113,32 @@ def main() -> None: del lora_state_dict[key] - lora_A_key: str = key.replace('.weight', '') + '.lora_A.weight' - lora_B_key: str = key.replace('.weight', '') + '.lora_B.weight' + for suffix in ['.weight', '']: + lora_A_key: str = key.replace('.weight', '') + '.lora_A' + suffix + lora_B_key: str = key.replace('.weight', '') + '.lora_B' + suffix - if lora_A_key in lora_state_dict: - lora_A: torch.Tensor = lora_state_dict[lora_A_key] - lora_B: torch.Tensor = lora_state_dict[lora_B_key] + if lora_A_key in lora_state_dict: + lora_A: torch.Tensor = lora_state_dict[lora_A_key] + lora_B: torch.Tensor = lora_state_dict[lora_B_key] - assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \ - f'{lora_A.shape}, {lora_B.shape}' + assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \ + f'{lora_A.shape}, {lora_B.shape}' - lora_R: int = lora_B.shape[1] + lora_R: int = lora_B.shape[1] - replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R) + replacement: torch.Tensor = parameter + lora_B @ lora_A * (args.lora_alpha / lora_R) - if parameter.dtype == torch.float16: - replacement = replacement.half() + if parameter.dtype == torch.float16: + replacement = replacement.half() - parameter = replacement + parameter = replacement + + print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}') - print(f'Merged LoRA into parameter {key}, lora_r = {lora_R}') + del lora_state_dict[lora_A_key] + del lora_state_dict[lora_B_key] - del lora_state_dict[lora_A_key] - del lora_state_dict[lora_B_key] + break write_parameter(out_file, key, parameter)