diff --git a/ggml b/ggml
index 4b20bbd..46f083d 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 4b20bbdf1b6e586addf9d065518b594e94dfa43f
+Subproject commit 46f083d15bb31c62933300ffbfffa5aa6ae2ecae
diff --git a/rwkv.cpp b/rwkv.cpp
index 124602d..6fae152 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -47,6 +47,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit
 
 #include "rwkv_operators.inc"
 
+#include "rwkv_operators_wkv_v5.inc"
+
 #include "rwkv_graph.inc"
 
 // API function.
diff --git a/rwkv_graph.inc b/rwkv_graph.inc
index a8a8c6d..90dda81 100644
--- a/rwkv_graph.inc
+++ b/rwkv_graph.inc
@@ -265,65 +265,38 @@ static struct ggml_tensor * rwkv_att_v5(
         );
     }
 
-    struct ggml_tensor * tf = layer.att_time_faaaa != NULL ?
-        layer.att_time_faaaa :
-        layer.att_time_first;
+    // dup is not strictly required; doing it just in case.
+    struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads);
 
-    struct ggml_tensor * a = rwkv_transpose_then_cont(
-        ctx,
-        ggml_mul_mat(
-            ctx,
-            k,
-            rwkv_transpose_then_cont(ctx, v)
-        )
-    );
-
-    struct ggml_tensor * tf_a = ggml_mul_inplace(
-        ctx,
-        ggml_repeat(ctx, tf, a),
-        a
-    );
-
-    struct ggml_tensor * x_new = ggml_new_tensor_2d(ctx, x->type, n_embed, sequence_length);
-
-    struct ggml_tensor * last_state = state.att_heads;
-
-    for (size_t t = 0; t < sequence_length; t++) {
-        struct ggml_tensor * s = ggml_reshape_3d(ctx, last_state, head_size, head_size, head_count);
-
-        struct ggml_tensor * tf_a_s = ggml_add_inplace(
-            ctx,
-            rwkv_get_from_dim_3(ctx, tf_a, t),
-            s
-        );
+    struct ggml_tensor * time_first;
+    struct ggml_tensor * time_decay;
 
-        struct ggml_tensor * x_new_vector = ggml_mul_mat(
-            ctx,
-            rwkv_get_from_dim_3(ctx, r, t),
-            rwkv_transpose_then_cont(ctx, tf_a_s)
-        );
-
-        struct ggml_tensor * td_s = ggml_mul_inplace(
-            ctx,
-            ggml_repeat(ctx, layer.att_time_decay, s),
-            s
-        );
-
-        s = ggml_add_inplace(ctx, td_s, rwkv_get_from_dim_3(ctx, a, t));
-
-        last_state = s;
+    if (arch_version_minor >= 2) {
+        time_first = layer.att_time_faaaa;
+        time_decay = layer.att_time_decay;
+    } else {
+        struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, head_size, head_count);
 
-        x_new = ggml_set_1d_inplace(
-            ctx,
-            x_new,
-            rwkv_flatten(ctx, x_new_vector),
-            t * n_embed * sizeof(float)
-        );
+        time_first = ggml_repeat(ctx, layer.att_time_first, dummy);
+        time_decay = ggml_repeat(ctx, layer.att_time_decay, dummy);
     }
 
-    state.att_heads = last_state;
+    x = rwkv_wkv_v5(
+        ctx,
+        sequence_length,
+        n_embed,
+        head_count,
+        head_size,
+        x,
+        k,
+        v,
+        r,
+        time_first,
+        time_decay,
+        state_out
+    );
 
-    x = x_new;
+    state.att_heads = state_out;
 
     // ggml_group_norm considers groups in the third dimension.
     x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
diff --git a/rwkv_operators.inc b/rwkv_operators.inc
index 0862035..af808e7 100644
--- a/rwkv_operators.inc
+++ b/rwkv_operators.inc
@@ -110,26 +110,3 @@ struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tens
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.
     return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias);
 }
-
-static struct ggml_tensor * rwkv_transpose_then_cont(struct ggml_context * ctx, struct ggml_tensor * x) {
-    return ggml_cont(ctx, ggml_transpose(ctx, x));
-}
-
-static struct ggml_tensor * rwkv_get_from_dim_3(struct ggml_context * ctx, struct ggml_tensor * x, int64_t index) {
-    return ggml_view_4d(
-        ctx,
-        x,
-        x->ne[0],
-        x->ne[1],
-        x->ne[2],
-        1,
-        x->nb[1],
-        x->nb[2],
-        x->nb[3],
-        index * (x->ne[0] * x->ne[1] * x->ne[2]) * sizeof(float)
-    );
-}
-
-static struct ggml_tensor * rwkv_flatten(struct ggml_context * ctx, struct ggml_tensor * x) {
-    return ggml_view_1d(ctx, x, ggml_nelements(x), 0);
-}
diff --git a/rwkv_operators_wkv_v5.inc b/rwkv_operators_wkv_v5.inc
new file mode 100644
index 0000000..f9502d4
--- /dev/null
+++ b/rwkv_operators_wkv_v5.inc
@@ -0,0 +1,180 @@
+// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
+// Original code by Harrison Vanderbyl.
+// TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512
+/*#ifdef __AVX512F__
+    #include <immintrin.h>
+    #define SIMD_WIDTH 16
+    #define LOAD(x) _mm512_load_ps(x)
+    #define STORE(x, y) _mm512_store_ps(x, y)
+    #define SET1(x) _mm512_set1_ps(x)
+    #define MULTIPLY(x, y) _mm512_mul_ps(x, y)
+    #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
+#elif __AVX2__
+    #include <immintrin.h>
+    #define SIMD_WIDTH 8
+    #define LOAD(x) _mm256_load_ps(x)
+    #define STORE(x, y) _mm256_store_ps(x, y)
+    #define SET1(x) _mm256_set1_ps(x)
+    #define MULTIPLY(x, y) _mm256_mul_ps(x, y)
+    #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
+#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
+    #include <arm_neon.h>
+    #define SIMD_WIDTH 4
+    #define LOAD(x) vld1q_f32(x)
+    #define STORE(x, y) vst1q_f32(x, y)
+    #define SET1(x) vdupq_n_f32(x)
+    #define MULTIPLY(x, y) vmulq_f32(x, y)
+    #define MULTADD(x, y, z) vmlaq_f32(z, x, y)
+#else*/
+    #define SIMD_WIDTH 1
+    #define LOAD(x) *x
+    #define STORE(x, y) *x = y
+    #define SET1(x) x
+    #define MULTIPLY(x, y) x * y
+    #define MULTADD(x, y, z) x * y + z
+//#endif
+
+// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57
+// Original code by Harrison Vanderbyl.
+static void rwkv_wkv_v5_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+    const size_t T = result->ne[1];
+    const size_t C = result->ne[0];
+    const size_t H = result->src[1]->ne[2];
+
+    float * result_data = (float *) result->data;
+
+    memset(result_data, 0, T * C * sizeof(float));
+
+    float * k = (float *) result->src[1]->data;
+    float * v = (float *) result->src[2]->data;
+    float * r = (float *) result->src[3]->data;
+    float * time_f = (float *) result->src[4]->data;
+    float * time_decay = (float *) result->src[5]->data;
+    float * state = (float *) result->src[6]->data;
+
+    size_t t_stride = H * (C / H);
+
+    size_t h_stride = C / H;
+    size_t h_stride_2d = (C / H) * (C / H);
+
+    for (size_t t = 0; t < T; t++) {
+        size_t t_offset = t * t_stride;
+
+        for (size_t h = 0; h < H; h++) {
+            size_t h_offset = h * h_stride;
+            size_t t_h_offset = t_offset + h_offset;
+            size_t h_2d_offset = h * h_stride_2d;
+
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_i_offset = h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                auto k_val = SET1(k[t_h_i_offset]);
+                auto r_val = SET1(r[t_h_i_offset]);
+                auto time_f_val = SET1(time_f[h_i_offset]);
+                auto time_decay_val = SET1(time_decay[h_i_offset]);
+
+                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    auto v_val = LOAD(&v[t_h_j_offset]);
+
+                    auto kv_val = MULTIPLY(v_val, k_val);
+
+                    auto prev_state_val = LOAD(&state[h_2d_i_j_offset]);
+
+                    auto temp_val = MULTADD(kv_val, time_f_val, prev_state_val);
+
+                    auto prev_result_data = LOAD(&result_data[t_h_j_offset]);
+
+                    STORE(&result_data[t_h_j_offset], MULTADD(temp_val, r_val, prev_result_data));
+
+                    STORE(&state[h_2d_i_j_offset], MULTADD(prev_state_val, time_decay_val, kv_val));
+                }
+            }
+        }
+    }
+
+    // Suppress "unused parameter" warnings.
+    (void) src;
+    (void) ith;
+    (void) nth;
+    (void) userdata;
+}
+
+// Parameters:
+// - T: sequence length
+// - C: channel count, same as n_embed
+// - H: head count
+// - S: head size
+// Shapes (in ggml order):
+// - x: [C, T, 1, 1]
+// - k: [1, S, H, T]
+// - v: [S, 1, H, T]
+// - r: [S, 1, H, T]
+// - time_f: [1, S, H, 1]
+// - time_decay: [1, S, H, 1]
+// - state: [S * S * H, 1, 1, 1]
+// - result: same as x
+// time_f and time_decay must be preprocessed as necessary -- exp() applied, etc.
+// state will be written to.
+static struct ggml_tensor * rwkv_wkv_v5(
+    struct ggml_context * ctx,
+    const size_t T,
+    const size_t C,
+    const size_t H,
+    const size_t S,
+    struct ggml_tensor * x,
+    struct ggml_tensor * k,
+    struct ggml_tensor * v,
+    struct ggml_tensor * r,
+    // time_first for v5.1, time_faaaa for v5.2.
+    struct ggml_tensor * time_f,
+    struct ggml_tensor * time_decay,
+    struct ggml_tensor * state
+) {
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(k->type == GGML_TYPE_F32);
+    GGML_ASSERT(v->type == GGML_TYPE_F32);
+    GGML_ASSERT(r->type == GGML_TYPE_F32);
+    GGML_ASSERT(time_f->type == GGML_TYPE_F32);
+    GGML_ASSERT(time_decay->type == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(time_f));
+    GGML_ASSERT(ggml_is_contiguous(time_decay));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(x->ne[0] == C && x->ne[1] == T && x->ne[2] == 1 && x->ne[3] == 1);
+    GGML_ASSERT(k->ne[0] == 1 && k->ne[1] == S && k->ne[2] == H && k->ne[3] == T);
+    GGML_ASSERT(v->ne[0] == S && v->ne[1] == 1 && v->ne[2] == H && v->ne[3] == T);
+    GGML_ASSERT(r->ne[0] == S && r->ne[1] == 1 && r->ne[2] == H && r->ne[3] == T);
+    GGML_ASSERT(ggml_nelements(state) == S * S * H);
+
+    k = ggml_cont_inplace(ctx, ggml_transpose(ctx, k));
+    v = ggml_cont_inplace(ctx, ggml_transpose(ctx, v));
+    r = ggml_cont_inplace(ctx, ggml_transpose(ctx, r));
+
+    struct ggml_tensor * result = ggml_map_custom1(
+        ctx,
+        x,
+        rwkv_wkv_v5_impl,
+        1,
+        NULL
+    );
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = r;
+    result->src[4] = time_f;
+    result->src[5] = time_decay;
+    // GGML_MAX_SRC must be increased from 6 to 8 for this.
+    result->src[6] = state;
+
+    return result;
+}
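
Note (not part of the patch): the kernel in rwkv_operators_wkv_v5.inc is easier to follow when restated as plain scalar C, equivalent to its SIMD_WIDTH == 1 fallback path. The sketch below is illustrative only; the name wkv_v5_reference and the flat [T, H, S] / [H, S, S] layouts are invented here for clarity and do not appear in the patch.

/*
 * Illustrative scalar reference for the recurrence computed by rwkv_wkv_v5_impl
 * (the SIMD_WIDTH == 1 path). Not part of the patch.
 */
#include <stddef.h>
#include <string.h>

/* k, v, r: [T, H, S]; tf, td: [H, S] (already exp()-preprocessed);
   state: [H, S, S], updated in place; out: [T, H, S], overwritten. */
static void wkv_v5_reference(size_t T, size_t H, size_t S,
                             const float * k, const float * v, const float * r,
                             const float * tf, const float * td,
                             float * state, float * out) {
    memset(out, 0, T * H * S * sizeof(float));

    for (size_t t = 0; t < T; t++) {
        for (size_t h = 0; h < H; h++) {
            for (size_t i = 0; i < S; i++) {
                const float k_i  = k[(t * H + h) * S + i];
                const float r_i  = r[(t * H + h) * S + i];
                const float tf_i = tf[h * S + i];
                const float td_i = td[h * S + i];

                for (size_t j = 0; j < S; j++) {
                    float * s = &state[(h * S + i) * S + j];
                    const float kv = k_i * v[(t * H + h) * S + j];

                    /* Output: receptance-weighted sum of the "bonus" term and the running state. */
                    out[(t * H + h) * S + j] += r_i * (tf_i * kv + *s);

                    /* State update: decay the old state, add the new key-value product. */
                    *s = *s * td_i + kv;
                }
            }
        }
    }
}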
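
A small hypothetical driver for the sketch above, again illustration only: one timestep, one head, head size 2, zero initial state. It shows the expected output (the receptance-weighted mix) and the updated state (the outer product of k and v), mirroring what rwkv_wkv_v5_impl writes into result and state.

/* Hypothetical driver for wkv_v5_reference; not part of the patch. */
#include <stdio.h>

int main(void) {
    float k[2]     = {1.0f, 0.0f};
    float v[2]     = {2.0f, 3.0f};
    float r[2]     = {1.0f, 1.0f};
    float tf[2]    = {0.5f, 0.5f};  /* time_first, already preprocessed */
    float td[2]    = {0.9f, 0.9f};  /* time_decay, already preprocessed */
    float state[4] = {0.0f};        /* S * S * H floats, zero initial state */
    float out[2];

    wkv_v5_reference(1, 1, 2, k, v, r, tf, td, state, out);

    /* With zero initial state, out[j] = v[j] * sum_i(r[i] * tf[i] * k[i]) = 0.5 * v[j]. */
    printf("out   = [%f, %f]\n", out[0], out[1]);   /* expect [1.0, 1.5] */
    printf("state = [%f, %f, %f, %f]\n",
           state[0], state[1], state[2], state[3]); /* expect [2, 3, 0, 0] */
    return 0;
}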