From 96ea1c3fd34db035a40f9d561bde11490bf1ae02 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Tue, 4 Apr 2023 19:47:27 +0400 Subject: [PATCH] Use FP16 in Q4_1 block to reduce file size --- ggml.c | 101 ++++++++++-------- rwkv/compare_with_reference_implementation.py | 2 +- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/ggml.c b/ggml.c index 0d2af19..f0946f2 100644 --- a/ggml.c +++ b/ggml.c @@ -25,6 +25,35 @@ #define static_assert(cond, msg) struct global_scope_noop_trick #endif +// https://gist.github.com/rygorous/2144712 +// Public domain, by Fabian "ryg" Giesen +inline static float ggml_half_to_float_simple(uint16_t value) { + union FP32 { + uint32_t u; + float f; + }; + + const union FP32 magic = { (254UL - 15UL) << 23 }; + const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; + + union FP32 out; + + // Exponent/mantissa bits + out.u = (value & 0x7FFFU) << 13; + // Exponent adjust + out.f *= magic.f; + + // Make sure Inf/NaN survive + if (out.f >= was_inf_nan.f) { + out.u |= 255UL << 23; + } + + // Sign bit + out.u |= (value & 0x8000UL) << 16; + + return out.f; +} + #if defined _MSC_VER || defined(__MINGW32__) #if !defined(__MINGW32__) @@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16]; // This is also true for POWER9. #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) -// https://gist.github.com/rygorous/2144712 -// Public domain, by Fabian "ryg" Giesen -inline static float ggml_half_to_float(uint16_t value) { - union FP32 { - uint32_t u; - float f; - }; - - const union FP32 magic = { (254UL - 15UL) << 23 }; - const union FP32 was_inf_nan = { (127UL + 16UL) << 23 }; - - union FP32 out; - - // Exponent/mantissa bits - out.u = (value & 0x7FFFU) << 13; - // Exponent adjust - out.f *= magic.f; - - // Make sure Inf/NaN survive - if (out.f >= was_inf_nan.f) { - out.u |= 255UL << 23; - } - - // Sign bit - out.u |= (value & 0x8000UL) << 16; - - return out.f; -} - inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { // For some reason, lookup table does not work on my machine: // - Windows SDK version 10.0.19041.0 // - CMAKE_SYSTEM_PROCESSOR: AMD64 // Replaced lookup with some conversion code found online. // TODO This must be properly debugged and fixed - return ggml_half_to_float(f); + return ggml_half_to_float_simple(f); } #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) @@ -508,14 +508,15 @@ static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block si // blocks of QK elements // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) typedef struct { - // TODO Use fp16 - float d; - float m; + ggml_fp16_t d; + ggml_fp16_t m; + // 16 bits for the in-block index is overkill, since we need only 5 bits; + // but IDK how to compress these fields further. uint16_t outlier_index; - float outlier_value; + ggml_fp16_t outlier_value; uint8_t qs[QK / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 3 + 2 + QK / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == 8 + QK / 2, "wrong q4_1 block size/padding"); // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { @@ -758,7 +759,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric } y[i].outlier_index = outlier_index; - y[i].outlier_value = outlier_value; + y[i].outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value); float min = FLT_MAX; float max = -FLT_MAX; @@ -777,8 +778,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric const float d = (max - min) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; - y[i].m = min; + y[i].d = GGML_COMPUTE_FP32_TO_FP16(d); + y[i].m = GGML_COMPUTE_FP32_TO_FP16(min); for (int l = 0; l < QK; l += 2) { float v0 = (x[i*QK + l + 0] - min)*id; @@ -1050,8 +1051,11 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #if defined(__AVX2__) for (int i = 0; i < nb; i++) { - const __m256 d_v = _mm256_broadcast_ss(&x[i].d); - const __m256 d_m = _mm256_broadcast_ss(&x[i].m); + const float x_d = ggml_half_to_float_simple(x[i].d); + const float x_m = ggml_half_to_float_simple(x[i].m); + + const __m256 d_v = _mm256_broadcast_ss(&x_d); + const __m256 d_m = _mm256_broadcast_ss(&x_m); const uint8_t * restrict pp = x[i].qs; @@ -1079,12 +1083,15 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in } // Restore the outlier - y[i * QK + x[i].outlier_index] = x[i].outlier_value; + y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value); } #elif defined(__ARM_NEON) for (int i = 0; i < nb; i++) { - const float32x4_t vd = vdupq_n_f32(x[i].d); - const float32x4_t vm = vdupq_n_f32(x[i].m); + const float x_d = ggml_half_to_float_simple(x[i].d); + const float x_m = ggml_half_to_float_simple(x[i].m); + + const float32x4_t vd = vdupq_n_f32(x_d); + const float32x4_t vm = vdupq_n_f32(x_m); const uint8_t * restrict pp = x[i].qs; @@ -1126,12 +1133,12 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in } // Restore the outlier - y[i * QK + x[i].outlier_index] = x[i].outlier_value; + y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value); } #else for (int i = 0; i < nb; i++) { - const float d = x[i].d; - const float m = x[i].m; + const float d = ggml_half_to_float_simple(x[i].d); + const float m = ggml_half_to_float_simple(x[i].m); const uint8_t * restrict pp = x[i].qs; @@ -1152,7 +1159,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in } // Restore the outlier - y[i * QK + x[i].outlier_index] = x[i].outlier_value; + y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value); } #endif } diff --git a/rwkv/compare_with_reference_implementation.py b/rwkv/compare_with_reference_implementation.py index 69a5828..2aee33b 100644 --- a/rwkv/compare_with_reference_implementation.py +++ b/rwkv/compare_with_reference_implementation.py @@ -51,7 +51,7 @@ def main() -> None: threshold = 4.0 elif data_type == 3: # This format stores more data, so error would be lower - threshold = 1.2 + threshold = 0.2 model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)