Commit: Add wkv v5 custom operator

saharNooby committed Nov 14, 2023
1 parent 20a8549 commit e9fbb2f
Showing 5 changed files with 208 additions and 77 deletions.
ggml: 2 changes (1 addition, 1 deletion)
Submodule ggml updated from 4b20bb to 46f083
rwkv.cpp: 2 changes (2 additions, 0 deletions)
@@ -47,6 +47,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit
 
 #include "rwkv_operators.inc"
 
+#include "rwkv_operators_wkv_v5.inc"
+
 #include "rwkv_graph.inc"
 
 // API function.
rwkv_graph.inc: 79 changes (26 additions, 53 deletions)
@@ -265,65 +265,38 @@ static struct ggml_tensor * rwkv_att_v5(
         );
     }
 
-    struct ggml_tensor * tf = layer.att_time_faaaa != NULL ?
-        layer.att_time_faaaa :
-        layer.att_time_first;
-
-    struct ggml_tensor * a = rwkv_transpose_then_cont(
-        ctx,
-        ggml_mul_mat(
-            ctx,
-            k,
-            rwkv_transpose_then_cont(ctx, v)
-        )
-    );
-
-    struct ggml_tensor * tf_a = ggml_mul_inplace(
-        ctx,
-        ggml_repeat(ctx, tf, a),
-        a
-    );
-
-    struct ggml_tensor * x_new = ggml_new_tensor_2d(ctx, x->type, n_embed, sequence_length);
-
-    struct ggml_tensor * last_state = state.att_heads;
-
-    for (size_t t = 0; t < sequence_length; t++) {
-        struct ggml_tensor * s = ggml_reshape_3d(ctx, last_state, head_size, head_size, head_count);
-
-        struct ggml_tensor * tf_a_s = ggml_add_inplace(
-            ctx,
-            rwkv_get_from_dim_3(ctx, tf_a, t),
-            s
-        );
-
-        struct ggml_tensor * x_new_vector = ggml_mul_mat(
-            ctx,
-            rwkv_get_from_dim_3(ctx, r, t),
-            rwkv_transpose_then_cont(ctx, tf_a_s)
-        );
-
-        struct ggml_tensor * td_s = ggml_mul_inplace(
-            ctx,
-            ggml_repeat(ctx, layer.att_time_decay, s),
-            s
-        );
-
-        s = ggml_add_inplace(ctx, td_s, rwkv_get_from_dim_3(ctx, a, t));
-
-        last_state = s;
-
-        x_new = ggml_set_1d_inplace(
-            ctx,
-            x_new,
-            rwkv_flatten(ctx, x_new_vector),
-            t * n_embed * sizeof(float)
-        );
-    }
-
-    state.att_heads = last_state;
-
-    x = x_new;
+    // dup is not strictly required; doing it just in case.
+    struct ggml_tensor * state_out = ggml_dup(ctx, state.att_heads);
+
+    struct ggml_tensor * time_first;
+    struct ggml_tensor * time_decay;
+
+    if (arch_version_minor >= 2) {
+        time_first = layer.att_time_faaaa;
+        time_decay = layer.att_time_decay;
+    } else {
+        struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, head_size, head_count);
+
+        time_first = ggml_repeat(ctx, layer.att_time_first, dummy);
+        time_decay = ggml_repeat(ctx, layer.att_time_decay, dummy);
+    }
+
+    x = rwkv_wkv_v5(
+        ctx,
+        sequence_length,
+        n_embed,
+        head_count,
+        head_size,
+        x,
+        k,
+        v,
+        r,
+        time_first,
+        time_decay,
+        state_out
+    );
+
+    state.att_heads = state_out;
 
     // ggml_group_norm considers groups in the third dimension.
     x = ggml_reshape_4d(ctx, x, 1, 1, n_embed, sequence_length);
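The net effect: the per-token graph loop is replaced by a single call to the new rwkv_wkv_v5 custom operator. state.att_heads is duplicated into state_out up front because the operator updates the state tensor in place, and for v5.1 models time_first/time_decay are broadcast with ggml_repeat into the [1, head_size, head_count] layout that the operator indexes (v5.2's att_time_faaaa presumably already has a compatible layout, so it is passed through directly).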
rwkv_operators.inc: 23 changes (0 additions, 23 deletions)
@@ -110,26 +110,3 @@ struct ggml_tensor * rwkv_layer_norm(struct ggml_context * ctx, struct ggml_tens
     // Looks like ggml_norm does the first part, we only need to apply weight & bias.
     return ggml_add_inplace(ctx, ggml_mul_inplace(ctx, ggml_norm(ctx, x, 1e-5F), weight), bias);
 }
-
-static struct ggml_tensor * rwkv_transpose_then_cont(struct ggml_context * ctx, struct ggml_tensor * x) {
-    return ggml_cont(ctx, ggml_transpose(ctx, x));
-}
-
-static struct ggml_tensor * rwkv_get_from_dim_3(struct ggml_context * ctx, struct ggml_tensor * x, int64_t index) {
-    return ggml_view_4d(
-        ctx,
-        x,
-        x->ne[0],
-        x->ne[1],
-        x->ne[2],
-        1,
-        x->nb[1],
-        x->nb[2],
-        x->nb[3],
-        index * (x->ne[0] * x->ne[1] * x->ne[2]) * sizeof(float)
-    );
-}
-
-static struct ggml_tensor * rwkv_flatten(struct ggml_context * ctx, struct ggml_tensor * x) {
-    return ggml_view_1d(ctx, x, ggml_nelements(x), 0);
-}
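These three helpers were used only by the per-token WKV loop removed from rwkv_graph.inc above, so they become dead code in this commit.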
rwkv_operators_wkv_v5.inc: 179 changes (179 additions, 0 deletions)
@@ -0,0 +1,179 @@
// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
// Original code by Harrison Vanderbyl.
#ifdef __AVX512F__
#include <immintrin.h>
#define SIMD_WIDTH 16
#define LOAD(x) _mm512_load_ps(x)
#define STORE(x, y) _mm512_store_ps(x, y)
#define SET1(x) _mm512_set1_ps(x)
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
#elif __AVX2__
#include <immintrin.h>
#define SIMD_WIDTH 8
#define LOAD(x) _mm256_load_ps(x)
#define STORE(x, y) _mm256_store_ps(x, y)
#define SET1(x) _mm256_set1_ps(x)
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
#define LOAD(x) vld1q_f32(x)
#define STORE(x, y) vst1q_f32(x, y)
#define SET1(x) vdupq_n_f32(x)
#define MULTIPLY(x, y) vmulq_f32(x, y)
#define MULTADD(x, y, z) vmlaq_f32(z, x, y)
#else
#define SIMD_WIDTH 1
#define LOAD(x) (*(x))
#define STORE(x, y) (*(x) = (y))
#define SET1(x) (x)
#define MULTIPLY(x, y) ((x) * (y))
#define MULTADD(x, y, z) ((x) * (y) + (z))
#endif
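
// Note: whichever branch was taken above, the kernel below consumes and
// produces SIMD_WIDTH floats per step, and the scalar fallback lets the same
// code compile on targets without AVX or NEON. On x86 the LOAD/STORE macros
// use aligned load/store intrinsics, so the buffers they touch must be
// suitably aligned.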

// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L57
// Original code by Harrison Vanderbyl.
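// Note: per head h and timestep t, the kernel below computes, for all
// i, j in [0, C/H) (vectorized over j in steps of SIMD_WIDTH):
//   result[t][h][j] += r[t][h][i] * (time_f[h][i] * k[t][h][i] * v[t][h][j] + state[h][i][j])
//   state[h][i][j]   = time_decay[h][i] * state[h][i][j] + k[t][h][i] * v[t][h][j]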
static void rwkv_wkv_v5_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
    const size_t T = result->ne[1];
    const size_t C = result->ne[0];
    const size_t H = result->src[1]->ne[2];

    float * result_data = (float *) result->data;

    memset(result_data, 0, T * C * sizeof(float));

    float * k          = (float *) result->src[1]->data;
    float * v          = (float *) result->src[2]->data;
    float * r          = (float *) result->src[3]->data;
    float * time_f     = (float *) result->src[4]->data;
    float * time_decay = (float *) result->src[5]->data;
    float * state      = (float *) result->src[6]->data;

    size_t t_stride = H * (C / H);

    size_t h_stride = C / H;
    size_t h_stride_2d = (C / H) * (C / H);

    for (size_t t = 0; t < T; t++) {
        size_t t_offset = t * t_stride;

        for (size_t h = 0; h < H; h++) {
            size_t h_offset = h * h_stride;
            size_t t_h_offset = t_offset + h_offset;
            size_t h_2d_offset = h * h_stride_2d;

            for (size_t i = 0; i < C / H; i++) {
                size_t t_h_i_offset = t_h_offset + i;
                size_t h_i_offset = h_offset + i;
                size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                auto k_val = SET1(k[t_h_i_offset]);
                auto r_val = SET1(r[t_h_i_offset]);
                auto time_f_val = SET1(time_f[h_i_offset]);
                auto time_decay_val = SET1(time_decay[h_i_offset]);

                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
                    size_t t_h_j_offset = t_h_offset + j;
                    size_t h_2d_i_j_offset = h_2d_i_offset + j;

                    auto v_val = LOAD(&v[t_h_j_offset]);

                    auto kv_val = MULTIPLY(v_val, k_val);

                    auto prev_state_val = LOAD(&state[h_2d_i_j_offset]);

                    auto temp_val = MULTADD(kv_val, time_f_val, prev_state_val);

                    auto prev_result_data = LOAD(&result_data[t_h_j_offset]);

                    STORE(&result_data[t_h_j_offset], MULTADD(temp_val, r_val, prev_result_data));

                    STORE(&state[h_2d_i_j_offset], MULTADD(prev_state_val, time_decay_val, kv_val));
                }
            }
        }
    }

    // Suppress "unused parameter" warnings.
    (void) src;
    (void) ith;
    (void) nth;
    (void) userdata;
}

// Parameters:
// - T: sequence length
// - C: channel count, same as n_embed
// - H: head count
// - S: head size
// Shapes (in ggml order):
// - x: [C, T, 1, 1]
// - k: [1, S, H, T]
// - v: [S, 1, H, T]
// - r: [S, 1, H, T]
// - time_f: [1, S, H, 1]
// - time_decay: [1, S, H, 1]
// - state: [S * S * H, 1, 1, 1]
// - result: same as x
// time_f and time_decay must be preprocessed as necessary -- exp() applied, etc.
// state will be written to.
static struct ggml_tensor * rwkv_wkv_v5(
    struct ggml_context * ctx,
    const size_t T,
    const size_t C,
    const size_t H,
    const size_t S,
    struct ggml_tensor * x,
    struct ggml_tensor * k,
    struct ggml_tensor * v,
    struct ggml_tensor * r,
    // time_first for v5.1, time_faaaa for v5.2.
    struct ggml_tensor * time_f,
    struct ggml_tensor * time_decay,
    struct ggml_tensor * state
) {
    GGML_ASSERT(x->type == GGML_TYPE_F32);
    GGML_ASSERT(k->type == GGML_TYPE_F32);
    GGML_ASSERT(v->type == GGML_TYPE_F32);
    GGML_ASSERT(r->type == GGML_TYPE_F32);
    GGML_ASSERT(time_f->type == GGML_TYPE_F32);
    GGML_ASSERT(time_decay->type == GGML_TYPE_F32);
    GGML_ASSERT(state->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(x));
    GGML_ASSERT(ggml_is_contiguous(k));
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(r));
    GGML_ASSERT(ggml_is_contiguous(time_f));
    GGML_ASSERT(ggml_is_contiguous(time_decay));
    GGML_ASSERT(ggml_is_contiguous(state));

    GGML_ASSERT(x->ne[0] == C && x->ne[1] == T && x->ne[2] == 1 && x->ne[3] == 1);
    GGML_ASSERT(k->ne[0] == 1 && k->ne[1] == S && k->ne[2] == H && k->ne[3] == T);
    GGML_ASSERT(v->ne[0] == S && v->ne[1] == 1 && v->ne[2] == H && v->ne[3] == T);
    GGML_ASSERT(r->ne[0] == S && r->ne[1] == 1 && r->ne[2] == H && r->ne[3] == T);
    GGML_ASSERT(ggml_nelements(state) == S * S * H);

    k = ggml_cont_inplace(ctx, ggml_transpose(ctx, k));
    v = ggml_cont_inplace(ctx, ggml_transpose(ctx, v));
    r = ggml_cont_inplace(ctx, ggml_transpose(ctx, r));

    struct ggml_tensor * result = ggml_map_custom1(
        ctx,
        x,
        rwkv_wkv_v5_impl,
        1,
        NULL
    );
    result->src[1] = k;
    result->src[2] = v;
    result->src[3] = r;
    result->src[4] = time_f;
    result->src[5] = time_decay;
    // GGML_MAX_SRC must be increased from 6 to 8 for this.
    result->src[6] = state;

    return result;
}
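
For reference, a minimal call-site sketch for the new operator (an editorial illustration, not part of the commit; it assumes a valid ggml_context * ctx and caller-filled tensor contents, with time_decay already exp()-preprocessed as noted above):

    // Hypothetical dimensions: T = 2 tokens, H = 2 heads of size S = 4, so C = H * S = 8.
    const size_t T = 2, H = 2, S = 4, C = H * S;

    struct ggml_tensor * x     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, C, T);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
    struct ggml_tensor * r     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
    struct ggml_tensor * tf    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, S, H);
    struct ggml_tensor * td    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, S, H);
    struct ggml_tensor * state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, S * S * H);

    // ... fill x, k, v, r, tf, td, state with data, then:
    struct ggml_tensor * out = rwkv_wkv_v5(ctx, T, C, H, S, x, k, v, r, tf, td, state);
    // out has the same shape as x: [C, T]; state is updated in place when the graph runs.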
