Skip to content

Commit

Permalink
Use wdata for storing dequantized weights
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Apr 4, 2023
1 parent 09131a9 commit 2f6b28b
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions ggml.c
@@ -6843,16 +6843,15 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

- // TODO Alloc somewhere else, or maybe use wdata
- float * dequantized = calloc(ne00, sizeof(float));
-
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);

- dequantize_row_q((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), dequantized, ne00);
+ float * const wdata = (float *) ((char *) params->wdata + (i01 * nb01 + i02 * nb02 + i03 * nb03));
+
+ dequantize_row_q((char *) src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03), wdata, ne00);

for (int ic = 0; ic < ne11; ++ic) {
// src1 indices
@@ -6869,13 +6868,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
ggml_vec_dot_f32(
ne00,
(float *) ((char *) dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)),
- dequantized,
+ wdata,
(float *) ((char *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13)));
}
}

- free(dequantized);
-
//int64_t t1 = ggml_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
@@ -9546,7 +9543,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} else
#endif
{
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
}
} else {
GGML_ASSERT(false);
Expand Down

0 comments on commit 2f6b28b

Please sign in to comment.