Skip to content

Commit

Permalink
Use FP16 in Q4_1 block to reduce file size
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Apr 4, 2023
1 parent 8604ed4 commit 96ea1c3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 48 deletions.
101 changes: 54 additions & 47 deletions ggml.c
Expand Up @@ -25,6 +25,35 @@
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif

// https://gist.github.com/rygorous/2144712
// Public domain, by Fabian "ryg" Giesen
inline static float ggml_half_to_float_simple(uint16_t value) {
union FP32 {
uint32_t u;
float f;
};

const union FP32 magic = { (254UL - 15UL) << 23 };
const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };

union FP32 out;

// Exponent/mantissa bits
out.u = (value & 0x7FFFU) << 13;
// Exponent adjust
out.f *= magic.f;

// Make sure Inf/NaN survive
if (out.f >= was_inf_nan.f) {
out.u |= 255UL << 23;
}

// Sign bit
out.u |= (value & 0x8000UL) << 16;

return out.f;
}

#if defined _MSC_VER || defined(__MINGW32__)

#if !defined(__MINGW32__)
Expand Down Expand Up @@ -326,42 +355,13 @@ static float table_f32_f16[1 << 16];
// This is also true for POWER9.
#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)

// https://gist.github.com/rygorous/2144712
// Public domain, by Fabian "ryg" Giesen
inline static float ggml_half_to_float(uint16_t value) {
union FP32 {
uint32_t u;
float f;
};

const union FP32 magic = { (254UL - 15UL) << 23 };
const union FP32 was_inf_nan = { (127UL + 16UL) << 23 };

union FP32 out;

// Exponent/mantissa bits
out.u = (value & 0x7FFFU) << 13;
// Exponent adjust
out.f *= magic.f;

// Make sure Inf/NaN survive
if (out.f >= was_inf_nan.f) {
out.u |= 255UL << 23;
}

// Sign bit
out.u |= (value & 0x8000UL) << 16;

return out.f;
}

inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
// For some reason, lookup table does not work on my machine:
// - Windows SDK version 10.0.19041.0
// - CMAKE_SYSTEM_PROCESSOR: AMD64
// Replaced lookup with some conversion code found online.
// TODO This must be properly debugged and fixed
return ggml_half_to_float(f);
return ggml_half_to_float_simple(f);
}

#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
Expand Down Expand Up @@ -508,14 +508,15 @@ static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block si
// blocks of QK elements
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
typedef struct {
// TODO Use fp16
float d;
float m;
ggml_fp16_t d;
ggml_fp16_t m;
// 16 bits for the in-block index is overkill, since we need only 5 bits;
// but IDK how to compress these fields further.
uint16_t outlier_index;
float outlier_value;
ggml_fp16_t outlier_value;
uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 3 + 2 + QK / 2, "wrong q4_1 block size/padding");
static_assert(sizeof(block_q4_1) == 8 + QK / 2, "wrong q4_1 block size/padding");

// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
Expand Down Expand Up @@ -758,7 +759,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
}

y[i].outlier_index = outlier_index;
y[i].outlier_value = outlier_value;
y[i].outlier_value = GGML_COMPUTE_FP32_TO_FP16(outlier_value);

float min = FLT_MAX;
float max = -FLT_MAX;
Expand All @@ -777,8 +778,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;

y[i].d = d;
y[i].m = min;
y[i].d = GGML_COMPUTE_FP32_TO_FP16(d);
y[i].m = GGML_COMPUTE_FP32_TO_FP16(min);

for (int l = 0; l < QK; l += 2) {
float v0 = (x[i*QK + l + 0] - min)*id;
Expand Down Expand Up @@ -1050,8 +1051,11 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in

#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
const float x_d = ggml_half_to_float_simple(x[i].d);
const float x_m = ggml_half_to_float_simple(x[i].m);

const __m256 d_v = _mm256_broadcast_ss(&x_d);
const __m256 d_m = _mm256_broadcast_ss(&x_m);

const uint8_t * restrict pp = x[i].qs;

Expand Down Expand Up @@ -1079,12 +1083,15 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
}

// Restore the outlier
y[i * QK + x[i].outlier_index] = x[i].outlier_value;
y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value);
}
#elif defined(__ARM_NEON)
for (int i = 0; i < nb; i++) {
const float32x4_t vd = vdupq_n_f32(x[i].d);
const float32x4_t vm = vdupq_n_f32(x[i].m);
const float x_d = ggml_half_to_float_simple(x[i].d);
const float x_m = ggml_half_to_float_simple(x[i].m);

const float32x4_t vd = vdupq_n_f32(x_d);
const float32x4_t vm = vdupq_n_f32(x_m);

const uint8_t * restrict pp = x[i].qs;

Expand Down Expand Up @@ -1126,12 +1133,12 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
}

// Restore the outlier
y[i * QK + x[i].outlier_index] = x[i].outlier_value;
y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value);
}
#else
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const float m = x[i].m;
const float d = ggml_half_to_float_simple(x[i].d);
const float m = ggml_half_to_float_simple(x[i].m);

const uint8_t * restrict pp = x[i].qs;

Expand All @@ -1152,7 +1159,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
}

// Restore the outlier
y[i * QK + x[i].outlier_index] = x[i].outlier_value;
y[i * QK + x[i].outlier_index] = ggml_half_to_float_simple(x[i].outlier_value);
}
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion rwkv/compare_with_reference_implementation.py
Expand Up @@ -51,7 +51,7 @@ def main() -> None:
threshold = 4.0
elif data_type == 3:
# This format stores more data, so error would be lower
threshold = 1.2
threshold = 0.2

model = rwkv_cpp_model.RWKVModel(rwkv_cpp_shared_library.load_rwkv_shared_library(), args.ggml_model_path)

Expand Down

0 comments on commit 96ea1c3

Please sign in to comment.