Sync ggml with upstream (#38)
* Sync ggml with upstream

* Remove file filters from Actions triggers

* Update ggml

* Add Q4_2 and Q4_3 support

* Improve output of perplexity measuring script

* Add tests for new formats

* Add token limit argument to perplexity measuring script

* Update README

* Update README

* Update ggml

* Use master branch of ggml
saharNooby committed Apr 22, 2023
1 parent ac66363 commit 3587ff9
Showing 9 changed files with 73 additions and 61 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/build.yml
@@ -8,11 +8,9 @@ on:
         description: 'Create new release'
         required: true
         type: boolean
-  push:
-    paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']
+  push: {}
   pull_request:
     types: [opened, synchronize, edited, reopened, review_requested, ready_for_review]
-    paths: ['**/CMakeLists.txt', '**/*.h', '**/*.c', '**/*.cpp']
 
 env:
   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
10 changes: 7 additions & 3 deletions README.md
@@ -89,9 +89,13 @@ python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-
 
 Formats available:
 
-- `4`: `Q4_1_O`, OK quality, moderately fast (20% slower than `FP16`).
-- `3`: `Q4_1`, worst quality, fast (comparable to `FP16`).
-- `2`: `Q4_0`, poor quality, very fast.
+- `6`: `Q4_3`, OK quality, fast.
+- `5`: `Q4_2`, poor quality, fast.
+- `4`: `Q4_1_O`, best quality, slow (20% slower than `FP16`).
+- `3`: `Q4_1`, poor quality, very fast.
+- `2`: `Q4_0`, worst quality, very fast.
+
+If you use `rwkv.cpp` for anything serious (just having fun is serious enough!), please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.
 
 ### 4. Run the model
 
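As an illustration of the advice above, a hypothetical run of the perplexity script (paths are placeholders; the positional arguments are model path, text path, tokens to skip, and an optional token limit, matching the script's argument parser shown later in this commit):

    python rwkv/measure_pexplexity.py ~/Downloads/rwkv.cpp-169M-Q4_3.bin ~/Downloads/text.txt 1024 10000
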
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated from 033090 to bfa8d5
43 changes: 18 additions & 25 deletions rwkv.cpp
@@ -43,12 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) {
     return true;
 }
 
-static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = {
+static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[7] = {
     GGML_TYPE_F32,
     GGML_TYPE_F16,
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1,
-    GGML_TYPE_Q4_1_O
+    GGML_TYPE_Q4_1_O,
+    GGML_TYPE_Q4_2,
+    GGML_TYPE_Q4_3
 };
 
 // --- Model definition and loading utilities ---
@@ -204,15 +206,7 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr
     RWKV_ASSERT_NULL(model->n_layer > 0, "Non-positive n_layer %d", model->n_layer);
 
     read_int32(file, &(model->data_type));
-    RWKV_ASSERT_NULL(
-        model->data_type == 0 ||
-        model->data_type == 1 ||
-        model->data_type == 2 ||
-        model->data_type == 3 ||
-        model->data_type == 4,
-        "Unsupported model data type %d",
-        model->data_type
-    );
+    RWKV_ASSERT_NULL(model->data_type >= 0 && model->data_type <= 6, "Unsupported model data type %d", model->data_type);
 
     // Parameter tensors would take at least this amount in memory.
     size_t file_size;
@@ -262,15 +256,7 @@
 
     int32_t data_type;
     read_int32(file, &data_type);
-    RWKV_ASSERT_NULL(
-        data_type == 0 ||
-        data_type == 1 ||
-        data_type == 2 ||
-        data_type == 3 ||
-        data_type == 4,
-        "Unsupported parameter data type %d",
-        data_type
-    );
+    RWKV_ASSERT_NULL(data_type >= 0 && data_type <= 6, "Unsupported parameter data type %d", data_type);
 
     ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type];
 
@@ -581,9 +567,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
 
     memcpy(logits_out, ctx->logits->data, ctx->logits->ne[0] * FP32_SIZE);
 
-    // Uncomment to measure used memory for adding the value into get_memory_required_mb.
-    //fprintf(stderr, "Used mem: %d MB\n", ggml_used_mem(ctx->ctx) / 1024 / 1024);
-
     return true;
 }
 
@@ -597,7 +580,7 @@ void rwkv_free(struct rwkv_context * ctx) {
 }
 
 bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) {
-    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type);
+    RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4 || q_type == 5 || q_type == 6, "Unsupported quantization type %d", q_type);
 
     // Needed to initialize FP16 lookup table
     {
@@ -690,7 +673,9 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             "F16",
             "Q4_0",
             "Q4_1",
-            "Q4_1_O"
+            "Q4_1_O",
+            "Q4_2",
+            "Q4_3"
         };
         printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]);
 
@@ -761,6 +746,14 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                 {
                     cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
                 } break;
+            case GGML_TYPE_Q4_2:
+                {
+                    cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                } break;
+            case GGML_TYPE_Q4_3:
+                {
+                    cur_size = ggml_quantize_q4_3(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                } break;
             default:
                 {
                     fprintf(stderr, "unsupported quantization type %d\n", type);
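For reference, the integer codes validated above index directly into FORMAT_TYPE_TO_GGML_TYPE: 0 is GGML_TYPE_F32, 1 is GGML_TYPE_F16, 2 is GGML_TYPE_Q4_0, 3 is GGML_TYPE_Q4_1, 4 is GGML_TYPE_Q4_1_O, 5 is GGML_TYPE_Q4_2, and 6 is GGML_TYPE_Q4_3; this is why the data type asserts were widened from the explicit set 0..4 to the range 0 to 6.
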
2 changes: 1 addition & 1 deletion rwkv.h
@@ -56,7 +56,7 @@ extern "C" {
     // Returns false on any error. Error messages would be printed to stderr.
     // - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
     // - model_file_path_out: quantized model will be written here.
-    // - q_type: set to 2 for GGML_TYPE_Q4_0, set to 3 for GGML_TYPE_Q4_1, set to 4 for GGML_TYPE_Q4_1_O.
+    // - q_type: set to 2 for GGML_TYPE_Q4_0, 3 for GGML_TYPE_Q4_1, 4 for GGML_TYPE_Q4_1_O, 5 for GGML_TYPE_Q4_2, 6 for GGML_TYPE_Q4_3.
     RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type);
 
     // Returns system information string.
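A minimal sketch of driving this API through the Python wrapper that quantize.py uses (an assumption: rwkv_cpp_shared_library is importable from the rwkv directory and its wrapper method mirrors the C signature; file names are placeholders):

    import rwkv_cpp_shared_library

    library = rwkv_cpp_shared_library.load_rwkv_shared_library()
    # q_type 6 selects GGML_TYPE_Q4_3, per the mapping documented above.
    library.rwkv_quantize_model_file('model-FP16.bin', 'model-Q4_3.bin', 6)
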
4 changes: 2 additions & 2 deletions rwkv/convert_pytorch_to_ggml.py
@@ -12,7 +12,7 @@
 #   int32 n_vocab;
 #   int32 n_embed;
 #   int32 n_layer;
-#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O.
+#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
 #   int32 data_type;
 #   // Read until EOF.
 #   Parameter[] parameters;
@@ -21,7 +21,7 @@
 # Parameter {
 #   int32 dim_count;
 #   int32 key_length;
-#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O.
+#   // 0 if float32, 1 if float16, 2 if Q4_0, 3 if Q4_1, 4 if Q4_1_O, 5 if Q4_2, 6 if Q4_3.
 #   int32 data_type;
 #   // Compared to PyTorch's tensor.shape, dimension order is reversed here!
 #   int32[dim_count] shape;
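The header comment above maps onto plain little-endian int32 reads; a minimal sketch of consuming these fields (illustrative only; the excerpt starts at line 12 of the spec, so any fields before n_vocab are not shown here):

    import struct

    def read_int32(f) -> int:
        # Every header field is a little-endian int32.
        return struct.unpack('<i', f.read(4))[0]

    # With a binary file object f positioned at n_vocab:
    # n_vocab, n_embed, n_layer, data_type = [read_int32(f) for _ in range(4)]
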
24 changes: 16 additions & 8 deletions rwkv/measure_pexplexity.py
@@ -14,9 +14,10 @@
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
-    parser.add_argument('model_path', help='Path to model checkpoint file')
-    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding')
-    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024)
+    parser.add_argument('model_path', help='Path to model checkpoint file', type=str)
+    parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
+    parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
+    parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
     return parser.parse_args()
 
 args = parse_args()
@@ -33,6 +34,15 @@ def parse_args():
 token_count: int = len(tokens)
 print(f'{token_count} tokens in the text')
 
+token_limit: int = args.token_limit
+
+assert token_limit == -1 or token_limit > 0, 'Invalid token_limit'
+
+if token_limit != -1 and token_count > token_limit:
+    tokens = tokens[0:token_limit]
+    token_count = token_limit
+    print(f'Text was limited to {token_limit} tokens')
+
 assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
 
 # ---
@@ -73,7 +83,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
     loss_sum += losses
     loss_count += 1
 
-    if i % 10 == 0:
+    if run_count <= 5 or i % (run_count // 10) == 0:
         avg_loss_so_far = loss_sum / loss_count
 
         duration: float = time.time() - start
@@ -90,11 +100,9 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
     else:
         print()
 
-print()
-print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
-
 print()
 print(f'Model: {os.path.basename(args.model_path)}, '
       f'data: {os.path.basename(args.text_path)} with {token_count} tokens, '
       f'skipped {args.ignore_first_n_tokens} tokens, '
-      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
+      f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}, '
+      f'latency {int((time.time() - start) * 1000 / run_count)} ms per token')
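The format_loss_with_perplexity helper referenced above is not part of this diff; its name points to the standard relation perplexity = exp(average cross-entropy loss). A sketch of such a helper, under that assumption (not the script's actual body):

    import torch

    def format_loss_with_perplexity(loss: torch.Tensor) -> str:
        # Perplexity is the exponential of the average cross-entropy loss.
        return f'loss {loss.item():.3f}, perplexity {torch.exp(loss).item():.3f}'
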
17 changes: 7 additions & 10 deletions rwkv/quantize.py
@@ -1,4 +1,4 @@
-# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended).
+# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1, Q4_1_O, Q4_2, Q4_3.
 # Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4
 
 import argparse
@@ -8,20 +8,17 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1')
     parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file')
     parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten')
-    parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4)
+    parser.add_argument('data_type', help='Data type, '
+                        '2 (GGML_TYPE_Q4_0), '
+                        '3 (GGML_TYPE_Q4_1), '
+                        '4 (GGML_TYPE_Q4_1_O), '
+                        '5 (Q4_2), '
+                        '6 (Q4_3)', type=int, choices=[2, 3, 4, 5, 6], default=4)
     return parser.parse_args()
 
 def main() -> None:
     args = parse_args()
 
-    if args.data_type == 2 or args.data_type == 3:
-        print()
-        print('WARNING!')
-        print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.')
-        print('For best quality preservation, it is recommended to use Q4_1_O.')
-        print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12')
-        print()
-
     library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 
     library.rwkv_quantize_model_file(
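Mirroring the README example earlier in this commit, a hypothetical invocation targeting one of the new formats (paths are placeholders):

    python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_2.bin 5
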
28 changes: 20 additions & 8 deletions tests/test_tiny_rwkv.c
@@ -69,17 +69,21 @@ int main(int argc, const char ** argv) {
     ASSERT(elements_read == N_VOCAB, "Failed to read expected_logits.bin, read %zd elements", elements_read);
     fclose(file);
 
-    float expected_difference_sum[8] = {
+    float expected_difference_sum[12] = {
         0.000000F,
         -0.005320F,
 
         -0.501214F,
         -1.092427F,
-        -0.370606F,
+        -0.268956F,
+        0.676837F,
+        0.237099F,
 
         -0.501073F,
-        -1.103214F,
-        -0.244590F
+        -0.372169F,
+        -0.244590F,
+        0.674874F,
+        0.243007F
     };
 
     test_model("tiny-rwkv-660K-FP32.bin", expected_logits, expected_difference_sum[0]);
@@ -88,18 +92,26 @@ int main(int argc, const char ** argv) {
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_0.bin", 2);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1.bin", 3);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_1_O.bin", 4);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_2.bin", 5);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q4_3.bin", 6);
 
     test_model("tiny-rwkv-660K-FP32-Q4_0.bin", expected_logits, expected_difference_sum[2]);
     test_model("tiny-rwkv-660K-FP32-Q4_1.bin", expected_logits, expected_difference_sum[3]);
     test_model("tiny-rwkv-660K-FP32-Q4_1_O.bin", expected_logits, expected_difference_sum[4]);
+    test_model("tiny-rwkv-660K-FP32-Q4_2.bin", expected_logits, expected_difference_sum[5]);
+    test_model("tiny-rwkv-660K-FP32-Q4_3.bin", expected_logits, expected_difference_sum[6]);
 
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_0.bin", 2);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1.bin", 3);
     rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_1_O.bin", 4);
-
-    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[5]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[6]);
-    test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[7]);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_2.bin", 5);
+    rwkv_quantize_model_file("tiny-rwkv-660K-FP16.bin", "tiny-rwkv-660K-FP16-Q4_3.bin", 6);
+
+    test_model("tiny-rwkv-660K-FP16-Q4_0.bin", expected_logits, expected_difference_sum[7]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1.bin", expected_logits, expected_difference_sum[8]);
+    test_model("tiny-rwkv-660K-FP16-Q4_1_O.bin", expected_logits, expected_difference_sum[9]);
+    test_model("tiny-rwkv-660K-FP16-Q4_2.bin", expected_logits, expected_difference_sum[10]);
+    test_model("tiny-rwkv-660K-FP16-Q4_3.bin", expected_logits, expected_difference_sum[11]);
 
     free(expected_logits);
 
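The name expected_difference_sum suggests that test_model (whose body is outside this diff) sums the element-wise differences between the logits a model file produces and the reference logits; a rough sketch of that metric, stated as an assumption:

    import numpy as np

    def difference_sum(actual_logits: np.ndarray, expected_logits: np.ndarray) -> float:
        # Signed sum of element-wise logit differences; near zero for lossless formats.
        return float(np.sum(actual_logits - expected_logits))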
