Skip to content

Commit

Permalink
Assert contiguity instead of assuming it
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Sep 18, 2023
1 parent b530dcc commit 8d910a5
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,11 @@ bool rwkv_load_model_from_file(const char * file_path, struct rwkv_model & model
void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(dest->type == GGML_TYPE_F32);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dest));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_are_same_shape(src, dest));

// Assuming contiguous 2D tensors.
// Assuming 2D tensors.
int64_t element_count = src->ne[0] * src->ne[1];
float * src_data = (float *) src->data;
float * dest_data = (float *) dest->data;
Expand All @@ -609,9 +611,11 @@ void rwkv_exp_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, in
void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(dest->type == GGML_TYPE_F32);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dest));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_are_same_shape(src, dest));

// Assuming contiguous 2D tensors.
// Assuming 2D tensors.
int64_t element_count = src->ne[0] * src->ne[1];
float * src_data = (float *) src->data;
float * dest_data = (float *) dest->data;
Expand All @@ -629,9 +633,11 @@ void rwkv_1_minus_x_impl(struct ggml_tensor * dest, const struct ggml_tensor * s
void rwkv_sigmoid_impl(struct ggml_tensor * dest, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
GGML_ASSERT(dest->type == GGML_TYPE_F32);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dest));
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_are_same_shape(src, dest));

// Assuming contiguous 2D tensors.
// Assuming 2D tensors.
int64_t element_count = src->ne[0] * src->ne[1];
float * src_data = (float *) src->data;
float * dest_data = (float *) dest->data;
Expand All @@ -657,10 +663,13 @@ void rwkv_max_impl(
GGML_ASSERT(dest->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dest));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_are_same_shape(src0, dest));
GGML_ASSERT(ggml_are_same_shape(src1, dest));

// Assuming contiguous 2D tensors.
// Assuming 2D tensors.
int64_t element_count = src0->ne[0] * src0->ne[1];
float * src0_data = (float *) src0->data;
float * src1_data = (float *) src1->data;
Expand Down

0 comments on commit 8d910a5

Please sign in to comment.