diff --git a/README.md b/README.md index e2b8dc7..e709bd6 100644 --- a/README.md +++ b/README.md @@ -10,9 +10,10 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrap **TODO (contributions welcome!)**: -1. Measure latency and perplexity of different model sizes (169M to 14B) and data types (FP32, FP16, Q4_0, Q4_1) -2. Test on Linux (including Colab) and MacOS -3. Make required memory calculation more robust (see #4) +1. Optimize AVX2 implementation of `Q4_1_O` matmul — currently, it is as slow as `FP32` +2. Measure latency and perplexity of different model sizes (169M to 14B) and data types (`FP32`, `FP16`, `Q4_0`, `Q4_1`, `Q4_1_O`) +3. Test on Linux (including Colab) and MacOS +4. Make required memory calculation more robust (see [#4](https://github.com/saharNooby/rwkv.cpp/issues/4)) ## How to use @@ -68,7 +69,7 @@ If everything went OK, `librwkv.so` (Linux) or `rwkv.o` (MacOS) file should appe ```commandline # Windows -python rwkv\convert_rwkv_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16 +python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16 # Linux / MacOS python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16 @@ -80,13 +81,17 @@ To convert the model into INT4 quantized format, run: ```commandline # Windows -python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1.bin 3 +python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q4_1_O.bin 4 # Linux / MacOS -python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1.bin 3 +python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin 4 ``` -Pass `2` for `Q4_0` format (smaller size, lower quality), `3` for `Q4_1` format (larger size, higher quality). +Formats available: + +- `4`: `Q4_1_O`, best quality, very slow (as `FP32`). +- `3`: `Q4_1`, poor quality, very fast (as `FP16`). +- `2`: `Q4_0`, worst quality, breaks larger models, moderately fast (between `FP16` and `FP32`). ### 4. Run the model @@ -98,20 +103,20 @@ To generate some text, run: ```commandline # Windows -python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1.bin +python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q4_1_O.bin # Linux / MacOS -python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin +python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin ``` To chat with a bot, run: ```commandline # Windows -python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1.bin +python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q4_1_O.bin # Linux / MacOS -python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1.bin +python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q4_1_O.bin ``` Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings. 
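For readers who want to drive a converted model from their own Python code instead of the bundled scripts, here is a minimal sketch built on the wrapper API that appears later in this diff (`rwkv_cpp_model.RWKVModel`, `eval`, `free`). It is only a sketch under stated assumptions: it assumes it is run from inside the `rwkv/` directory (so `20B_tokenizer.json` and the shared library are found the same way the bundled scripts find them), the model file path is just an example, and it uses plain greedy (argmax) sampling for brevity; `generate_completions.py` and `chat_with_bot.py` use temperature/top-p sampling instead.

```python
# Minimal completion sketch using the rwkv.cpp Python wrapper.
# Assumptions: run from inside the rwkv/ directory; the model file path below is an example.
import torch
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library

tokenizer = tokenizers.Tokenizer.from_file('20B_tokenizer.json')

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    '../rwkv.cpp-169M-Q4_1_O.bin'
)

logits, state = None, None

# Feed the prompt token by token, carrying the RWKV state between calls.
for token in tokenizer.encode('In a shocking finding, ').ids:
    logits, state = model.eval(token, state, state, logits)

# Greedy generation: always pick the most likely next token.
for _ in range(32):
    token = int(torch.argmax(logits).item())
    print(tokenizer.decode([token]), end='', flush=True)
    logits, state = model.eval(token, state, state, logits)

model.free()
```
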
diff --git a/ggml.c b/ggml.c index 05304a8..6c946d0 100644 --- a/ggml.c +++ b/ggml.c @@ -25,6 +25,35 @@ #define static_assert(cond, msg) struct global_scope_noop_trick #endif +// https://gist.github.com/rygorous/2144712 +// Public domain, by Fabian "ryg" Giesen +inline static float ggml_half_to_float_reference(uint16_t value) { + union FP32 { + uint32_t u; + float f; + }; + + const union FP32 magic = { (254UL - 15UL) << 23 }; + const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; + + union FP32 out; + + // Exponent/mantissa bits + out.u = (value & 0x7FFFU) << 13; + // Exponent adjust + out.f *= magic.f; + + // Make sure Inf/NaN survive + if (out.f >= was_inf_nan.f) { + out.u |= 255UL << 23; + } + + // Sign bit + out.u |= (value & 0x8000UL) << 16; + + return out.f; +} + #if defined _MSC_VER || defined(__MINGW32__) #if !defined(__MINGW32__) @@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16]; // This is also true for POWER9. #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) -// https://gist.github.com/rygorous/2144712 -// Public domain, by Fabian "ryg" Giesen -inline static float ggml_half_to_float(uint16_t value) { - union FP32 { - uint32_t u; - float f; - }; - - const union FP32 magic = { (254UL - 15UL) << 23 }; - const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; - union FP32 out; - - // Exponent/mantissa bits - out.u = (value & 0x7FFFU) << 13; - // Exponent adjust - out.f *= magic.f; - - // Make sure Inf/NaN survive - if (out.f >= was_inf_nan.f) { - out.u |= 255UL << 23; - } - - // Sign bit - out.u |= (value & 0x8000UL) << 16; - - return out.f; -} inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { - // For some reason, lookup table does not work on my machine: - // - Windows SDK version 10.0.19041.0 - // - CMAKE_SYSTEM_PROCESSOR: AMD64 - // Replaced lookup with some conversion code found online. + // For some reason, lookup table does not work on my machine. + // Replaced lookup with working reference code. // TODO This must be properly debugged and fixed - return ggml_half_to_float(f); + return ggml_half_to_float_reference(f); } #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) @@ -514,6 +514,19 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); +// Method 4 with better outlier handling. +typedef struct { + ggml_fp16_t d; + ggml_fp16_t m; + // We need only 5 bits for the in-block index, so 16 bits is overkill. + // TODO Optimize if possible + uint16_t outlier_index; + ggml_fp16_t outlier_value; + // Nibbles / quants. + uint8_t qs[QK / 2]; +} block_q4_1_o; +static_assert(sizeof(block_q4_1_o) == 8 + QK / 2, "wrong q4_1_o block size/padding"); + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { assert(k % QK == 0); @@ -1118,6 +1131,208 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #endif } +// Q4_1_O + +static inline void quantize_row_q4_1_o_reference_single_block(const float * restrict x, block_q4_1_o * restrict block) { + // An outlier is just the absmax element in the block. + // We store it separately and do not quantize it. 
+ int outlier_index = -1; + float outlier_value = 0.0F; + + for (int l = 0; l < QK; l++) { + const float v = x[l]; + + if (fabsf(v) > fabsf(outlier_value)) { + outlier_index = l; + outlier_value = v; + } + } + + block->outlier_index = outlier_index; + block->outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK; l++) { + if (l == outlier_index) { + // Ignore outlier when computing range. + continue; + } + + const float v = x[l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0F / d : 0.0F; + + block->d = GGML_COMPUTE_FP32_TO_FP16(d); + block->m = GGML_COMPUTE_FP32_TO_FP16(min); + + uint8_t pp[QK / 2]; + + for (int l = 0; l < QK; l += 2) { + float v0 = (x[l + 0] - min) * id; + float v1 = (x[l + 1] - min) * id; + + // Write some garbage but valid index for the outlier. + if (l + 0 == outlier_index) v0 = 0.0; + if (l + 1 == outlier_index) v1 = 0.0; + + const uint8_t vi0 = roundf(v0); + const uint8_t vi1 = roundf(v1); + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(block->qs, pp, sizeof(pp)); +} + +static inline void dequantize_row_q4_1_o_reference_single_block(block_q4_1_o * restrict block, float * restrict y) { + const float d = ggml_half_to_float_reference(block->d); + const float m = ggml_half_to_float_reference(block->m); + + const uint8_t * restrict pp = block->qs; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l / 2]; + + const int8_t vi0 = vi & 0xF; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0 * d + m; + const float v1 = vi1 * d + m; + + y[l + 0] = v0; + y[l + 1] = v1; + + assert(!isnan(y[l + 0])); + assert(!isnan(y[l + 1])); + } + + // Restore the outlier + y[block->outlier_index] = ggml_half_to_float_reference(block->outlier_value); +} + +static void quantize_row_q4_1_o_reference(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + block_q4_1_o * restrict y = vy; + + for (int i = 0; i < nb; i++) { + quantize_row_q4_1_o_reference_single_block(x + i * QK, y + i); + } +} + +static void quantize_row_q4_1_o(const float * restrict x, void * restrict vy, int k) { + quantize_row_q4_1_o_reference(x, vy, k); +} + +static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + const block_q4_1_o * restrict x = vx; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + const float x_d = ggml_half_to_float_reference(x[i].d); + const float x_m = ggml_half_to_float_reference(x[i].m); + + const __m256 d_v = _mm256_broadcast_ss(&x_d); + const __m256 d_m = _mm256_broadcast_ss(&x_m); + + const uint8_t * restrict pp = x[i].qs; + + for (int l = 0; l < QK; l += 32) { + // Load 32x4-bit integers into 32x8-bit integers + __m256i vx8 = bytesFromNibbles(pp+l/2); + + // Convert to 16-bit int + const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); + const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + + // Convert to 32-bit int -> float 32 + const __m256 vf[4] = { + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), + 
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale, add m and store
+            for (int j = 0; j < 4; j++) {
+                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
+                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+            }
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float x_d = ggml_half_to_float_reference(x[i].d);
+        const float x_m = ggml_half_to_float_reference(x[i].m);
+
+        const float32x4_t vd = vdupq_n_f32(x_d);
+        const float32x4_t vm = vdupq_n_f32(x_m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+
+            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+
+            // convert to 2x uint16x8_t
+            const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+
+            // multiply by d and add m
+            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l + 0,  r0);
+            vst1q_f32(y + i*QK + l + 4,  r1);
+            vst1q_f32(y + i*QK + l + 8,  r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+
+        // Restore the outlier
+        y[i * QK + x[i].outlier_index] = ggml_half_to_float_reference(x[i].outlier_value);
+    }
+#else
+    for (int i = 0; i < nb; i++) {
+        dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK);
+    }
+#endif
+}
+
 //
 // simd mappings
 //
@@ -2437,6 +2652,7 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 //
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
+    QK,
     QK,
     QK,
     1,
@@ -2446,11 +2662,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     1,
 };

-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");

 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     sizeof(block_q4_0),
     sizeof(block_q4_1),
+    sizeof(block_q4_1_o),
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -2459,7 +2676,7 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
 };

 // don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_COUNT != 8");

 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -2477,9 +2694,14 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "ABS",
     "SGN",
     "NEG",
+    "EXP",
+    "1_MINUS_X",
+    "MAX",
+    "STEP",
     "RELU",
     "GELU",
+    "SIGMOID",
     "SILU",
     "NORM",
     "RMS_NORM",
@@ -2521,9 +2743,14 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "abs(x)",
     "sgn(x)",
     "-x",
+    "e^x",
+    "1-x",
+    "max(x,y)",
+    "step(x)",
     "relu(x)",
     "gelu(x)",
+    "sigmoid(x)",
     "silu(x)",
     "norm(x)",
     "rms_norm(x)",
@@ -3186,6 +3413,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3246,6 +3477,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3300,6 +3535,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3344,6 +3583,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3386,6 +3629,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3430,6 +3677,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q4_1_O: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -4980,6 +5231,7 @@ static void ggml_compute_forward_dup( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5057,6 +5309,7 @@ static void ggml_compute_forward_add( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5109,6 +5362,7 @@ static void ggml_compute_forward_sub( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5161,6 +5415,7 @@ static void ggml_compute_forward_mul( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5213,6 +5468,7 @@ static void ggml_compute_forward_div( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5261,6 +5517,7 @@ static void ggml_compute_forward_sqr( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5309,6 +5566,7 @@ static void ggml_compute_forward_sqrt( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5367,6 +5625,7 @@ static void ggml_compute_forward_sum( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5444,6 +5703,7 @@ static void ggml_compute_forward_mean( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5508,6 +5768,7 @@ static void ggml_compute_forward_repeat( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5556,6 +5817,7 @@ static void ggml_compute_forward_abs( 
} break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5604,6 +5866,7 @@ static void ggml_compute_forward_sgn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5652,6 +5915,7 @@ static void ggml_compute_forward_neg( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5700,6 +5964,7 @@ static void ggml_compute_forward_exp( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5748,6 +6013,7 @@ static void ggml_compute_forward_1_minus_x( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5800,6 +6066,7 @@ static void ggml_compute_forward_max( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5848,6 +6115,7 @@ static void ggml_compute_forward_step( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5896,6 +6164,7 @@ static void ggml_compute_forward_relu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5961,6 +6230,7 @@ static void ggml_compute_forward_gelu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6011,6 +6281,7 @@ static void ggml_compute_forward_sigmoid( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6076,6 +6347,7 @@ static void ggml_compute_forward_silu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6162,6 +6434,7 @@ static void ggml_compute_forward_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6242,6 +6515,7 @@ static void ggml_compute_forward_rms_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6659,6 +6933,11 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .quantize_row_q = quantize_row_q4_1, .vec_dot_q = ggml_vec_dot_q4_1, }, + [GGML_TYPE_Q4_1_O] = { + .dequantize_row_q = dequantize_row_q4_1_o, + .quantize_row_q = quantize_row_q4_1_o, + .vec_dot_q = NULL, + }, }; static void ggml_compute_forward_mul_mat_q_f32( @@ -6849,6 +7128,273 @@ static void ggml_compute_forward_mul_mat_q_f32( //} } +static void ggml_compute_forward_mul_mat_q4_1_o_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + 
const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + const enum ggml_type type = src0->type; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + size_t id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + dequantize_row_q4_1_o((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by src0 rows using ggml_vec_dot_f32 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + +#if defined(__AVX2__) + float outlier_mask[QK]; + memset(outlier_mask, 0, QK * sizeof(float)); +#endif + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + +#if defined(__AVX2__) + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + const int block_count = ne00 / QK; + + const block_q4_1_o * row_blocks = (block_q4_1_o *) ((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)); + + __m256 accum = _mm256_setzero_ps(); + + // Here we do fused dequantization and dot product. 
+ for (int block_index = 0; block_index < block_count; block_index++) { + const float block_d = ggml_half_to_float_reference(row_blocks[block_index].d); + const float block_m = ggml_half_to_float_reference(row_blocks[block_index].m); + + // 0 .. 31 + const uint16_t outlier_index = row_blocks[block_index].outlier_index; + const float outlier_value = ggml_half_to_float_reference(row_blocks[block_index].outlier_value); + + const uint8_t * restrict quant_nibbles = row_blocks[block_index].qs; + + // --- + + // Broadcast values to 8x element float32 vectors + const __m256 broadcasted_d = _mm256_broadcast_ss(&block_d); + const __m256 broadcasted_m = _mm256_broadcast_ss(&block_m); + const __m256 broadcasted_outlier_value = _mm256_broadcast_ss(&outlier_value); + + // Load 32x4-bit integers into 32x8-bit integers + const __m256i quant_bytes = bytesFromNibbles(quant_nibbles); + + // Convert to 16-bit int + const __m256i quant_shorts_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 0)); + const __m256i quant_shorts_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(quant_bytes, 1)); + + // Convert to 32-bit int and then to 32-bit float + const __m256 quant_floats_0 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 0))); + const __m256 quant_floats_1 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_lo, 1))); + const __m256 quant_floats_2 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 0))); + const __m256 quant_floats_3 = _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(quant_shorts_hi, 1))); + + // Dequantize to ~original weights + const __m256 weight_0 = _mm256_fmadd_ps(quant_floats_0, broadcasted_d, broadcasted_m); + const __m256 weight_1 = _mm256_fmadd_ps(quant_floats_1, broadcasted_d, broadcasted_m); + const __m256 weight_2 = _mm256_fmadd_ps(quant_floats_2, broadcasted_d, broadcasted_m); + const __m256 weight_3 = _mm256_fmadd_ps(quant_floats_3, broadcasted_d, broadcasted_m); + + // TODO This outlier handling is VERY slow + // Set outlier mask -- this should give 1 in the most significant bit + outlier_mask[outlier_index] = -1.0F; + // Load mask into vectors + const __m256 outlier_mask_0 = _mm256_load_ps(outlier_mask); + const __m256 outlier_mask_1 = _mm256_load_ps(outlier_mask + 8); + const __m256 outlier_mask_2 = _mm256_load_ps(outlier_mask + 16); + const __m256 outlier_mask_3 = _mm256_load_ps(outlier_mask + 24); + // Reset mask array to all zeroes for the next iteration + outlier_mask[outlier_index] = 0.0F; + + // Replace the weight at the index of the outlier + const __m256 weight_0_with_outlier = _mm256_blendv_ps(weight_0, broadcasted_outlier_value, outlier_mask_0); + const __m256 weight_1_with_outlier = _mm256_blendv_ps(weight_1, broadcasted_outlier_value, outlier_mask_1); + const __m256 weight_2_with_outlier = _mm256_blendv_ps(weight_2, broadcasted_outlier_value, outlier_mask_2); + const __m256 weight_3_with_outlier = _mm256_blendv_ps(weight_3, broadcasted_outlier_value, outlier_mask_3); + + // Load 32 floats of data of the second argument + const float * src1_data = (float *) ((char *) src1->data + (block_index * QK * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13)); + + const __m256 src1_0 = _mm256_load_ps(src1_data); + const __m256 src1_1 = _mm256_load_ps(src1_data + 8); + const __m256 src1_2 = _mm256_load_ps(src1_data + 16); + const __m256 src1_3 = _mm256_load_ps(src1_data + 24); + + // Multiply weights and values of the second argument element-wise; add to 
accumulator + accum = _mm256_fmadd_ps(src1_0, weight_0_with_outlier, accum); + accum = _mm256_fmadd_ps(src1_1, weight_1_with_outlier, accum); + accum = _mm256_fmadd_ps(src1_2, weight_2_with_outlier, accum); + accum = _mm256_fmadd_ps(src1_3, weight_3_with_outlier, accum); + } + + // Add elements of accumulator + __m128 res = _mm256_extractf128_ps(accum, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(accum)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res )); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + + *((float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3))) = _mm_cvtss_f32(res); + } +#else + float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03)); + + dequantize_row_q4_1_o((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00); + + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; + + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; + + ggml_vec_dot_f32( + ne00, + (float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)), + wdata, + (float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13)) + ); + } +#endif + } +} + static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -6860,6 +7406,10 @@ static void ggml_compute_forward_mul_mat( { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; + case GGML_TYPE_Q4_1_O: + { + ggml_compute_forward_mul_mat_q4_1_o_f32(params, src0, src1, dst); + } break; case GGML_TYPE_F16: { ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); @@ -6955,6 +7505,7 @@ static void ggml_compute_forward_scale( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7111,6 +7662,7 @@ static void ggml_compute_forward_get_rows( switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -7200,6 +7752,7 @@ static void ggml_compute_forward_diag_mask_inf( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7294,6 +7847,7 @@ static void ggml_compute_forward_soft_max( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7436,6 +7990,7 @@ static void ggml_compute_forward_rope( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7704,6 +8259,7 @@ static void ggml_compute_forward_conv_1d_1s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7972,6 +8528,7 @@ static void ggml_compute_forward_conv_1d_2s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8457,6 +9014,7 @@ static void ggml_compute_forward_flash_attn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8668,6 +9226,7 @@ static void ggml_compute_forward_flash_ff( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1_O: case GGML_TYPE_I8: case GGML_TYPE_I16: case 
GGML_TYPE_I32: @@ -9498,6 +10057,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; + } else if (node->src0->type == GGML_TYPE_Q4_1_O && node->src1->type == GGML_TYPE_F32) { +#if defined(__AVX2__) + cur = 0; +#else + // Assuming that src1 is a vector + // TODO Not sure whether this is correct + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * ggml_nelements(node->src1); +#endif } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { @@ -10719,6 +11286,29 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * return (n/QK*sizeof(block_q4_1)); } +size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK == 0); + const int nb = k / QK; + + for (int j = 0; j < n; j += k) { + block_q4_1_o * restrict y = (block_q4_1_o *) dst + j / QK; + + quantize_row_q4_1_o_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK; l += 2) { + const uint8_t vi0 = y[i].qs[l / 2] & 0xF; + const uint8_t vi1 = y[i].qs[l / 2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n / QK * sizeof(block_q4_1_o)); +} + //////////////////////////////////////////////////////////////////////////////// int ggml_cpu_has_avx(void) { @@ -10837,7 +11427,6 @@ int ggml_cpu_has_vsx(void) { // Copied from https://github.com/ggerganov/llama.cpp/blob/6e7801d08d81c931a5427bae46f00763e993f54a/tests/test-quantize.c void ggml_test_quantization(void) { - #define QK 32 float src[QK]; uint8_t dst[24]; int64_t hist[16]; @@ -10847,7 +11436,7 @@ void ggml_test_quantization(void) { } size_t size = ggml_quantize_q4_0(src, dst, QK, QK, hist); - GGML_TEST_ASSERT(size == 20, "%d", size); + GGML_TEST_ASSERT(size == 20, "%zd", size); float max_result = ((float *) dst)[0]; float max_expected = src[31] / ((1 << 3) - 1); GGML_TEST_ASSERT(max_result == max_expected, "%f, %f", max_result, max_expected); @@ -10858,7 +11447,7 @@ void ggml_test_quantization(void) { } size = ggml_quantize_q4_1(src, dst, QK, QK, hist); - GGML_TEST_ASSERT(size == 24, "%d", size); + GGML_TEST_ASSERT(size == 24, "%zd", size); float delta_result = ((float *) dst)[0]; float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1); GGML_TEST_ASSERT(delta_result == delta_expected, "%f, %f", delta_result, delta_expected); @@ -10872,8 +11461,55 @@ void ggml_test_quantization(void) { } } +void ggml_test_quantization_q4_1_o(void) { + float src[QK]; + uint8_t dst[24]; + int64_t hist[16]; + + for (int i = 0; i < QK; i++) { + src[i] = (float) (i + 1); + } + + size_t size = ggml_quantize_q4_1_o(src, dst, QK, QK, hist); + GGML_TEST_ASSERT(size == 24, "%zd", size); + + float delta_result = ggml_half_to_float_reference(((block_q4_1_o *) dst)->d); + float delta_expected = (src[30] - src[0]) / ((1 << 4) - 1); + GGML_TEST_ASSERT(delta_result == delta_expected, "%f, %f", delta_result, delta_expected); + + float min_result = ggml_half_to_float_reference(((block_q4_1_o *) dst)->m); + float min_expected = src[0]; + GGML_TEST_ASSERT(min_result == min_expected, "%f, %f", min_result, min_expected); + + uint16_t outlier_index = ((block_q4_1_o *) dst)->outlier_index; + uint16_t outlier_index_expected = 31; + GGML_TEST_ASSERT(outlier_index == outlier_index_expected, "%d, %d", outlier_index, 
outlier_index_expected); + + float outlier_value_result = ggml_half_to_float_reference(((block_q4_1_o *) dst)->outlier_value); + float outlier_value_expected = src[31]; + GGML_TEST_ASSERT(outlier_value_result == outlier_value_expected, "%f, %f", outlier_value_result, outlier_value_expected); + + for (int i = 0; i < QK - 1; i++) { + uint8_t q4_result = (i % 2) ? (dst[sizeof(float) * 2 + i / 2] >> 4) : (dst[sizeof(float) * 2 + i / 2] & 0xF); + uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected); + GGML_TEST_ASSERT(q4_result == q4_expected, "%d: %d, %d", i, q4_result, q4_expected); + } + + float dequantized[QK]; + dequantize_row_q4_1_o(dst, dequantized, QK); + + for (int i = 0; i < QK; i++) { + float actual = dequantized[i]; + float expected = src[i]; + float diff = fabsf(actual - expected); + // Difference looks huge, but the range is 0..31 -- compared to range, it is not that huge + GGML_TEST_ASSERT(diff <= 1.0F, "%d: %f, %f", i, actual, expected); + } +} + void ggml_run_test_suite(void) { ggml_test_quantization(); + ggml_test_quantization_q4_1_o(); struct ggml_init_params params; params.mem_size = 16 * 1024; diff --git a/ggml.h b/ggml.h index 0b7f9a3..03b3369 100644 --- a/ggml.h +++ b/ggml.h @@ -186,7 +186,8 @@ // - to `ggml_compute_forward` and call the forward dispatch function here. // - to `ggml_compute_backward` and add `GGML_ASSERT(false)` here. // - to `ggml_graph_compute` and add `node->n_tasks = 1` here. -// 6. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one. +// 6. Add operator label to `GGML_OP_LABEL` array and operator symbol to `GGML_OP_SYMBOL` array. +// 7. Fix all assertions that check value of `GGML_OP_COUNT`: you've added 1 operator, so increment asserted value by one. // // When in doubt, consult the code of existing operators similar to that you're implementing. // Resulting operator would work for the forward pass, but will lack backward implementation and multi-threading support. @@ -225,7 +226,11 @@ struct ggml_context; enum ggml_type { GGML_TYPE_Q4_0, + // Stores min and delta per block, does quantized matmul. GGML_TYPE_Q4_1, + // Same as Q4_1, but stores outliers separately, and matmul is done in FP32. + // An outlier is the single absmax element in the quantized block. 
+ GGML_TYPE_Q4_1_O, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -806,6 +811,7 @@ enum ggml_opt_result ggml_opt( size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q4_1_o(const float * src, void * dst, int n, int k, int64_t * hist); // // system info diff --git a/rwkv.cpp b/rwkv.cpp index ede0791..0c331c3 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -43,6 +43,14 @@ bool read_int32(FILE * file, int32_t * dest) { return true; } +static const ggml_type FORMAT_TYPE_TO_GGML_TYPE[5] = { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_1_O +}; + // --- Model definition and loading utilities --- struct rwkv_layer { @@ -160,7 +168,8 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr model->data_type == 0 || model->data_type == 1 || model->data_type == 2 || - model->data_type == 3, + model->data_type == 3 || + model->data_type == 4, "Unsupported model data type %d", model->data_type ); @@ -216,20 +225,13 @@ struct rwkv_context * rwkv_init_from_file(const char * file_path, uint32_t n_thr data_type == 0 || data_type == 1 || data_type == 2 || - data_type == 3, + data_type == 3 || + data_type == 4, "Unsupported parameter data type %d", data_type ); - ggml_type ggml_data_type; - - switch (data_type) { - case 0: ggml_data_type = GGML_TYPE_F32; break; - case 1: ggml_data_type = GGML_TYPE_F16; break; - case 2: ggml_data_type = GGML_TYPE_Q4_0; break; - case 3: ggml_data_type = GGML_TYPE_Q4_1; break; - default: return NULL; - } + ggml_type ggml_data_type = FORMAT_TYPE_TO_GGML_TYPE[data_type]; struct ggml_tensor * tensor; @@ -553,17 +555,9 @@ void rwkv_free(struct rwkv_context * ctx) { } bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, uint32_t q_type) { - RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3, "Unsupported quantization type %d", q_type); - - ggml_type type; - - switch (q_type) { - case 2: type = GGML_TYPE_Q4_0; break; - case 3: type = GGML_TYPE_Q4_1; break; - default: return false; - }; + RWKV_ASSERT_FALSE(q_type == 2 || q_type == 3 || q_type == 4, "Unsupported quantization type %d", q_type); - RWKV_ASSERT_FALSE(type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1, "Unsupported data type %d", type); + ggml_type type = FORMAT_TYPE_TO_GGML_TYPE[q_type]; printf("Loading model from '%s'\n", model_file_path_in); @@ -643,22 +637,30 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode { static const char * parameter_data_type_str[] = { - "f32", - "f16", - "q4_0", - "q4_1" + "F32", + "F16", + "Q4_0", + "Q4_1", + "Q4_1_O" }; printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], parameter_data_type_str[parameter_data_type]); + + total_size_orig += (size_t) (nelements * ggml_type_sizef(FORMAT_TYPE_TO_GGML_TYPE[parameter_data_type])); } - // Quantize only 2D tensors - bool quantize = n_dims == 2; + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take not too much space, especially in bigger models; + // but they significantly increase perplexity when quantized. 
+        bool quantize = n_dims == 2 &&
+            name != std::string("emb.weight") &&
+            name != std::string("head.weight");

         if (quantize) {
-            if (parameter_data_type != 0 && parameter_data_type != 1) {
-                fprintf(stderr, "unsupported data type %d for integer quantization\n", parameter_data_type);
-                return false;
-            }
+            RWKV_ASSERT_FALSE(
+                parameter_data_type == 0 || parameter_data_type == 1,
+                "Unsupported parameter data type %d, only FP32 and FP16 can be quantized",
+                parameter_data_type
+            );

             if (parameter_data_type == 1) {
                 data_f16.resize(nelements);
@@ -706,6 +708,10 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
                     {
                         cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
                     } break;
+                case GGML_TYPE_Q4_1_O:
+                    {
+                        cur_size = ggml_quantize_q4_1_o(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
+                    } break;
                 default:
                     {
                         fprintf(stderr, "unsupported quantization type %d\n", type);
@@ -732,12 +738,11 @@ bool rwkv_quantize_model_file(const char * model_file_path_in, const char * mode
             fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
             total_size_new += data_u8.size();
         }
-
-        total_size_orig += nelements * sizeof(float);
     }

-    printf("model size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
-    printf("quant size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("original size = %8.2f MB\n", total_size_orig / 1024.0 / 1024.0);
+    printf("quantized size = %8.2f MB\n", total_size_new / 1024.0 / 1024.0);
+    printf("compression ratio = %8.2f\n", 1.0 * total_size_orig / total_size_new);

     {
         int64_t sum_all = 0;
diff --git a/rwkv/compare_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py
deleted file mode 100644
index 69a5828..0000000
--- a/rwkv/compare_with_reference_implementation.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Compares logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV.
-# Reference logits were generated with RWKV-4-Pile-169M-20220807-8023.pth model in PyTorch.
-# Reference implementation code: https://github.com/BlinkDL/ChatRWKV/blob/0d0abf181356c6f27501274cad18bdf28c83a45b/RWKV_in_150_lines.py
-# Usage: python compare_with_reference_implementation.py C:\rwkv.cpp-169M.bin
-
-import os
-import struct
-import argparse
-import torch
-import numpy as np
-import rwkv_cpp_model
-import rwkv_cpp_shared_library
-from typing import List, Tuple, Any
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='Compare logits from rwkv.cpp implementation of RWKV with logits from reference implementation of RWKV')
-    parser.add_argument('ggml_model_path', help='Path to rwkv.cpp checkpoint file')
-    return parser.parse_args()
-
-def main() -> None:
-    args = parse_args()
-
-    # Don't want to depend on tokenizer here.
- tokens: List[int] = [4, 3631, 4420, 2412, 953, 432, 391, 30567, 87, 15, 14161, 7092, 273, 416, 27767, 55, 342, - 2412, 953, 432, 3806, 7092, 273, 416, 27767, 55, 15, 187, 4, 19039, 2412, 953, 497, 4561, - 342, 416, 27767, 55, 14, 21, 14, 49, 587, 14, 17809, 46, 14, 938, 14256, 28950, 14, 1438, - 1508, 15, 81, 394, 1566, 275, 8462, 22097, 348, 15, 187, 4, 43825, 27, 15548, 7277, 64, - 3113, 64, 14005, 64, 39595, 15, 4789, 10269, 61, 18992, 61, 7265, 64, 30217, 39297, 15, - 20963, 330, 27, 190, 30567, 87, 15, 14161, 14, 17809, 46, 15, 4805] - - threshold: float - - with open(args.ggml_model_path, 'rb') as model_file: - header: Tuple[Any] = struct.unpack('=iiiiii', model_file.read(6 * 4)) - data_type: int = header[5] - - assert data_type == 0 or\ - data_type == 1 or\ - data_type == 2 or\ - data_type == 3, f'Unsupported model data type {data_type}' - - if data_type == 0: - # FP32, high precision - threshold = 0.000005 - elif data_type == 1: - # FP16, lower precision, so higher threshold - threshold = 0.0032 - elif data_type == 2: - # INT4 quantized, even lower precision, so even higher threshold - # This threshold will let some bugs pass - threshold = 4.0 - elif data_type == 3: - # This format stores more data, so error would be lower - threshold = 1.2 - - model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path) - - def compare_logits(tokens_subset: List[int]) -> None: - token_count: int = len(tokens_subset) - - logits, state = None, None - - for i in range(token_count): - token: int = tokens_subset[i] - - if token_count <= 10 or i % (token_count // 10) == 0: - print(f'{i + 1}/{token_count}') - - logits, state = model.eval(token, state, state, logits) - - actual_logits = logits - - # --- - - expected_logits_path: str = f'expected_logits_169M_20220807_8023_{len(tokens_subset)}_tokens.bin' - - if not os.path.isfile(expected_logits_path): - expected_logits_path = 'rwkv/' + expected_logits_path - - with open(expected_logits_path, 'rb') as logits_file: - expected_logits = torch.tensor(np.frombuffer(logits_file.read(), dtype=np.single)) - - # --- - - difference: float = (torch.sum(expected_logits - actual_logits) / len(expected_logits)).item() - - print(f'Reference logits: {expected_logits}') - print(f'Actual logits: {actual_logits}') - print('Difference per token: %.8f' % (difference,)) - - assert abs(difference) <= threshold, 'Difference is too big' - - compare_logits(tokens) - - print() - print('Test passes') - - if model is not None: - model.free() - -if __name__ == "__main__": - main() diff --git a/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin b/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin deleted file mode 100644 index e0409d2..0000000 Binary files a/rwkv/expected_logits_169M_20220807_8023_98_tokens.bin and /dev/null differ diff --git a/rwkv/measure_pexplexity.py b/rwkv/measure_pexplexity.py new file mode 100644 index 0000000..a2a0e2c --- /dev/null +++ b/rwkv/measure_pexplexity.py @@ -0,0 +1,100 @@ +# Measures perplexity and per-token latency of an RWKV model on a given text file. +# Perplexity is defined here as exp() of average cross-entropy loss. 
+# Usage: python measure_pexplexity.py C:\rwkv.cpp-169M.bin C:\text.txt 1024 + +import os +import time +import pathlib +import argparse +import tokenizers +import torch +import rwkv_cpp_model +import rwkv_cpp_shared_library +from typing import List + +def parse_args(): + parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file') + parser.add_argument('model_path', help='Path to model checkpoint file') + parser.add_argument('text_path', help='Path to text file in UTF-8 encoding') + parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int, default=1024) + return parser.parse_args() + +args = parse_args() + +# --- + +print('Loading 20B tokenizer') +tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json' +tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path)) + +print('Loading text') +text: str = open(args.text_path, encoding='utf-8').read() +tokens: List[int] = tokenizer.encode(text).ids +token_count: int = len(tokens) +print(f'{token_count} tokens in the text') + +assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation' + +# --- + +def format_loss(loss: torch.Tensor) -> str: + return str(['%.3f' % (loss[i].item(),) for i in range(len(loss))]).replace('\'', '')[1:-1] + +def format_loss_with_perplexity(loss: torch.Tensor) -> str: + return f'loss [{format_loss(loss)}], perplexity {"%.3f" % (torch.exp(loss[0]).item(),)}' + +# --- + +model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel( + rwkv_cpp_shared_library.load_rwkv_shared_library(), + args.model_path +) + +logits, state = None, None + +loss_sum: torch.Tensor = torch.tensor([0.0]) +loss_count: int = 0 + +start: float = time.time() + +run_count: int = token_count - 1 + +for i in range(run_count): + token: int = tokens[i] + target: int = tokens[i + 1] + + logits, state = model.eval(token, state, state, logits) + + if args.ignore_first_n_tokens == 0 or i + 1 >= args.ignore_first_n_tokens: + losses = torch.tensor([ + torch.nn.functional.cross_entropy(logits, torch.tensor(target, dtype=torch.long), reduction='none').item() + ]) + + loss_sum += losses + loss_count += 1 + + if i % 10 == 0: + avg_loss_so_far = loss_sum / loss_count + + duration: float = time.time() - start + duration_per_token: float = duration / (i + 1) + runs_remaining: int = run_count - i - 1 + duration_remaining: int = int(runs_remaining * duration_per_token) + + print(f'Token #{i}/{token_count}, ' + f'{int(100.0 * i / token_count)}%, ' + f'ETA {duration_remaining // 60} m {duration_remaining % 60} s', end='') + + if loss_count > 0: + print(f', averages so far: {format_loss_with_perplexity(avg_loss_so_far)}') + else: + print() + +print() +print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token') + +print() +print(f'Model: {os.path.basename(args.model_path)}, ' + f'data: {os.path.basename(args.text_path)} with {token_count} tokens, ' + f'skipped {args.ignore_first_n_tokens} tokens, ' + f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}') diff --git a/rwkv/quantize.py b/rwkv/quantize.py index e798855..243dc92 100644 --- a/rwkv/quantize.py +++ b/rwkv/quantize.py @@ -1,5 +1,5 @@ -# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1. 
-# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1.bin 3 +# Quantizes rwkv.cpp model file from FP32 or FP16 to Q4_0, Q4_1 or Q4_1_O (recommended). +# Usage: python quantize.py bin\Release\rwkv.dll C:\rwkv.cpp-169M-float32.bin C:\rwkv.cpp-169M-q4_1_o.bin 4 import argparse import rwkv_cpp_shared_library @@ -8,12 +8,20 @@ def parse_args(): parser = argparse.ArgumentParser(description='Quantize rwkv.cpp model file from FP32 or FP16 to Q4_0 or Q4_1') parser.add_argument('src_path', help='Path to FP32/FP16 checkpoint file') parser.add_argument('dest_path', help='Path to resulting checkpoint file, will be overwritten') - parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0) or 3 (GGML_TYPE_Q4_1)', type=int, choices=[2, 3], default=3) + parser.add_argument('data_type', help='Data type, 2 (GGML_TYPE_Q4_0), 3 (GGML_TYPE_Q4_1) or 4 (GGML_TYPE_Q4_1_O)', type=int, choices=[2, 3, 4], default=4) return parser.parse_args() def main() -> None: args = parse_args() + if args.data_type == 2 or args.data_type == 3: + print() + print('WARNING!') + print('You are using Q4_0 or Q4_1 quantization; it will heavily degrade RWKV quality.') + print('For best quality preservation, it is recommended to use Q4_1_O.') + print('More info at https://github.com/saharNooby/rwkv.cpp/issues/12') + print() + library = rwkv_cpp_shared_library.load_rwkv_shared_library() library.rwkv_quantize_model_file( diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 4f089ad..f7bb32b 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -32,14 +32,14 @@ def __init__( assert os.path.isfile(model_path), f'{model_path} is not a file' assert thread_count > 0, 'Thread count must be positive' - self.library = shared_library + self._library = shared_library - self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) - self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) - self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) + self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx) + self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx) - self.valid = True + self._valid = True def eval( self, @@ -69,7 +69,7 @@ def eval( Logits vector of shape (n_vocab); state for the next step. 
""" - assert self.valid, 'Model was freed' + assert self._valid, 'Model was freed' def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: assert buf.dtype == torch.float32, f'{name} is not of type float32' @@ -77,24 +77,24 @@ def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' if state_in is not None: - validate_buffer(state_in, 'state_in', self.state_buffer_element_count) + validate_buffer(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.storage().data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_buffer(state_out, 'state_out', self.state_buffer_element_count) + validate_buffer(state_out, 'state_out', self._state_buffer_element_count) else: - state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') + state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) + validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count) else: - logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') + logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') - self.library.rwkv_eval( - self.ctx, + self._library.rwkv_eval( + self._ctx, token, state_in_ptr, state_out.storage().data_ptr(), @@ -110,8 +110,13 @@ def free(self): The object must not be used anymore after calling this method. """ - assert self.valid, 'Already freed' + assert self._valid, 'Already freed' - self.valid = False + self._valid = False - self.library.rwkv_free(self.ctx) + self._library.rwkv_free(self._ctx) + + def __del__(self): + # Free the context on GC in case user forgot to call free() explicitly. + if hasattr(self, '_valid') and self._valid: + self.free() diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index 85ab0e6..9dc5da5 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -192,13 +192,17 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary: else: file_name = 'librwkv.so' + repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent + paths = [ # If we are in "rwkv" directory f'../bin/Release/{file_name}', # If we are in repo root directory f'bin/Release/{file_name}', + # Search relative to this file + str(repo_root_dir / 'bin' / 'Release' / file_name), # Fallback - pathlib.Path(os.path.abspath(__file__)).parent.parent / file_name + str(repo_root_dir / file_name) ] for path in paths: