Skip to content

Commit

Permalink
Add RWKV v5.1 and v5.2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Nov 11, 2023
1 parent 9a43933 commit 6eb67ef
Show file tree
Hide file tree
Showing 34 changed files with 693 additions and 159 deletions.
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from d925ed to a0fec8
33 changes: 27 additions & 6 deletions python/convert_pytorch_to_ggml.py
Expand Up @@ -32,6 +32,16 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
n_vocab: int = emb_weight.shape[0]
n_embed: int = emb_weight.shape[1]

is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict

if is_v5_2:
print('Detected RWKV v5.2')
elif is_v5_1_or_2:
print('Detected RWKV v5.1')
else:
print('Detected RWKV v4')

with open(dest_path, 'wb') as out_file:
is_FP16: bool = data_type == 'FP16' or data_type == 'float16'

Expand All @@ -50,16 +60,27 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
for k in state_dict.keys():
tensor: torch.Tensor = state_dict[k].float()

# Same processing as in "RWKV_in_150_lines.py"
if '.time_' in k:
# (1, 1, n_embed) -> (n_embed)
tensor = tensor.squeeze()

if '.time_decay' in k:
tensor = -torch.exp(tensor)
if is_v5_1_or_2:
if '.time_decay' in k:
if is_v5_2:
tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
else:
tensor = torch.exp(-torch.exp(tensor)).reshape(-1, 1, 1)

if '.time_first' in k:
tensor = torch.exp(tensor).reshape(-1, 1, 1)

if '.time_faaaa' in k:
tensor = tensor.unsqueeze(-1)
else:
if '.time_decay' in k:
tensor = -torch.exp(tensor)

# Keep 1-dim vectors in FP32
if is_FP16 and len(tensor.shape) > 1:
# Keep 1-dim vectors and small matrices in FP32
if is_FP16 and len(tensor.shape) > 1 and '.time_' not in k:
tensor = tensor.half()

shape = tensor.shape
Expand Down
25 changes: 21 additions & 4 deletions python/merge_lora_into_ggml.py
Expand Up @@ -13,8 +13,9 @@
def parse_args():
parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
parser.add_argument('src_path', help='Path to source rwkv.cpp model')
parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
parser.add_argument('lora_alpha', type=int, help='Value of lora_alpha parameter used when training this LoRA checkpoint')
parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model')
return parser.parse_args()

Expand Down Expand Up @@ -44,6 +45,10 @@ def write_parameter(out_file, key: str, parameter: torch.Tensor) -> None:
def main() -> None:
args = parse_args()

arch_version: str = args.rwkv_arch_version

assert arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2', f'Invalid RWKV architecture version {arch_version}'

print(f'Reading {args.lora_path}')

lora_state_dict: Dict[str, torch.Tensor] = torch.load(args.lora_path, map_location='cpu')
Expand Down Expand Up @@ -96,11 +101,23 @@ def main() -> None:

# Same processing as in convert_pytorch_to_ggml.py
if '.time_' in key:
# (1, 1, n_embed) -> (n_embed)
replacement = replacement.squeeze()

if '.time_decay' in key:
replacement = -torch.exp(replacement)
if arch_version == 'v5.1' or arch_version == 'v5.2':
if '.time_decay' in key:
if arch_version == 'v5.2':
replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)
else:
replacement = torch.exp(-torch.exp(replacement)).reshape(-1, 1, 1)

if '.time_first' in key:
replacement = torch.exp(replacement).reshape(-1, 1, 1)

if '.time_faaaa' in key:
replacement = replacement.unsqueeze(-1)
else:
if '.time_decay' in key:
replacement = -torch.exp(replacement)

if parameter.dtype == torch.float16:
replacement = replacement.half()
Expand Down
6 changes: 5 additions & 1 deletion rwkv.cpp
Expand Up @@ -121,7 +121,11 @@ size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
size_t rwkv_get_state_len(const struct rwkv_context * ctx) {
const struct rwkv_file_header & header = ctx->model->header;

return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
if (ctx->model->arch_version_major >= 5) {
return (size_t) header.n_embed * (2 + ctx->model->head_size) * (size_t) header.n_layer;
} else {
return (size_t) header.n_embed * 5 * (size_t) header.n_layer;
}
}

// API function.
Expand Down
10 changes: 6 additions & 4 deletions rwkv_eval.inc
Expand Up @@ -176,16 +176,18 @@ bool rwkv_eval_sequence_in_chunks(

// API function.
void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
memset(state, 0, rwkv_get_state_len(ctx) * sizeof(float));

if (ctx->model->arch_version_major >= 5) {
return;
}

const struct rwkv_file_header & header = ctx->model->header;
const size_t layer_size = (size_t) header.n_embed * 5;
const size_t layer_zero = (size_t) header.n_embed * 4;
const size_t layers_size = (size_t) header.n_layer * layer_size;

for (size_t start = 0; start < layers_size; start += layer_size) {
for (size_t i = 0; i < layer_zero; i++) {
state[start + i] = 0.0F;
}

for (size_t i = layer_zero; i < layer_size; i++) {
state[start + i] = -1e30F;
}
Expand Down
52 changes: 40 additions & 12 deletions rwkv_file_format.inc
Expand Up @@ -129,37 +129,61 @@ struct rwkv_tensor_header {
uint32_t dim_count;
uint32_t key_length;
uint32_t data_type;
uint32_t width;
uint32_t height;
uint32_t size0;
uint32_t size1;
uint32_t size2;

size_t size() const;
};

size_t rwkv_tensor_header::size() const {
return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->width, this->height);
return rwkv_tensor_nbytes(rwkv_type_to_ggml[this->data_type], this->size0, this->size1, this->size2);
}

static bool rwkv_fread_tensor_header(FILE * file, struct rwkv_tensor_header & header) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t), &header));
header.height = 1;
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_SHAPE, header.dim_count == 1 || header.dim_count == 2, "Tensor has an invalid shape (%" PRId32 " dimensions)", header.dim_count);
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, sizeof(struct rwkv_tensor_header) - sizeof(uint32_t) * 2, &header));
header.size1 = 1;
header.size2 = 1;

RWKV_ASSERT_FALSE_MSG(
RWKV_ERROR_SHAPE,
header.dim_count == 1 || header.dim_count == 2 || header.dim_count == 3,
"Tensor has an invalid shape (%" PRId32 " dimensions)",
header.dim_count
);

RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_DATA_TYPE, header.data_type < TYPE_COUNT, "Tensor data type out of range (%" PRId32 " > %" PRId32 ")", header.data_type, TYPE_COUNT - 1);

RWKV_ASSERT_FALSE_MSG(
RWKV_ERROR_DATA_TYPE,
rwkv_type_to_ggml[header.data_type] != GGML_TYPE_UNKNOWN,
"Tensor data type (%s) is no longer supported",
rwkv_type_to_string[header.data_type]
);

if (header.dim_count == 2) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.height));
if (header.dim_count >= 2) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.size1));
}

if (header.dim_count >= 3) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_READ, rwkv_fread_uint32(file, header.size2));
}

return true;
}

static bool rwkv_fwrite_tensor_header(FILE * file, const struct rwkv_tensor_header & header) {
RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - (header.dim_count == 1 ? sizeof(uint32_t) : 0)));
size_t sub;

if (header.dim_count == 1) {
sub = sizeof(uint32_t) * 2;
} else if (header.dim_count == 2) {
sub = sizeof(uint32_t);
} else {
sub = 0;
}

RWKV_ASSERT_FALSE(RWKV_ERROR_FILE_WRITE, rwkv_fwrite_data(file, &header, sizeof(struct rwkv_tensor_header) - sub));

return true;
}
Expand Down Expand Up @@ -204,9 +228,13 @@ static bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::
name.c_str()
);

tensor = header.dim_count == 1
? ggml_new_tensor_1d(ctx, ggml_type, header.width)
: ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height);
if (header.dim_count == 1) {
tensor = ggml_new_tensor_1d(ctx, ggml_type, header.size0);
} else if (header.dim_count == 2) {
tensor = ggml_new_tensor_2d(ctx, ggml_type, header.size0, header.size1);
} else {
tensor = ggml_new_tensor_3d(ctx, ggml_type, header.size0, header.size1, header.size2);
}

RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor != NULL, "Failed to allocate tensor");

Expand Down
4 changes: 4 additions & 0 deletions rwkv_gpu_offload.inc
Expand Up @@ -40,6 +40,10 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers)
offload(layer.att_receptance);
offload(layer.att_output);

if (layer.att_gate != NULL) {
offload(layer.att_gate);
}

offload(layer.ffn_key);
offload(layer.ffn_value);
offload(layer.ffn_receptance);
Expand Down

0 comments on commit 6eb67ef

Please sign in to comment.