Fix cuBLAS
Properly set each tensor's backend to GGML_BACKEND_GPU, and only then call ggml_cuda_transform_tensor to upload its data.
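For illustration, a minimal caller-side sketch of the path this commit fixes, assuming the public API declared in rwkv.h (the model path, thread count, and layer count below are placeholders, not part of the commit):

    #include <cstdio>
    #include "rwkv.h"

    int main(void) {
        // Load the model on the CPU first; GPU offload is a separate, explicit step.
        struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 4);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // Offload the heavy matrices of up to 24 layers. With this fix, each
        // tensor's backend is set to GGML_BACKEND_GPU before
        // ggml_cuda_transform_tensor copies its data into VRAM.
        if (!rwkv_gpu_offload_layers(ctx, 24)) {
            fprintf(stderr, "offload failed or cuBLAS support not compiled in\n");
        }

        // ... run inference, then clean up.
        rwkv_free(ctx);
        return 0;
    }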
LoganDark committed Jun 20, 2023
1 parent 4c7c74c commit f9ad712
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions rwkv.cpp
@@ -1507,25 +1507,28 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32

 bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_gpu_layers) {
 #ifdef GGML_USE_CUBLAS
-    size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);
-
-    size_t & gpu_layers = ctx->gpu_layers;
     size_t & vram_total = ctx->vram_total;
+    const auto offload = [&](struct ggml_tensor * tensor) {
+        // TODO support split-GPU
+        tensor->backend = GGML_BACKEND_GPU;
+        ggml_cuda_transform_tensor(tensor->data, tensor);
+        vram_total += ggml_nbytes(tensor);
+    };
+
+    const size_t n_gpu = std::min(n_gpu_layers, ctx->instance->model.header.n_layer);
 
-    for (size_t i = gpu_layers; i < n_gpu; i++) {
+    for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) {
         const struct rwkv_layer & layer = ctx->instance->model.layers[i];
 
         // Use cuBLAS only for heavy matrices; other operations are not supported for GPU at the moment
-        ggml_cuda_assign_buffers(layer.att_key); vram_total += ggml_nbytes(layer.att_key);
-        ggml_cuda_assign_buffers(layer.att_value); vram_total += ggml_nbytes(layer.att_value);
-        ggml_cuda_assign_buffers(layer.att_receptance); vram_total += ggml_nbytes(layer.att_receptance);
-        ggml_cuda_assign_buffers(layer.att_output); vram_total += ggml_nbytes(layer.att_output);
-
-        ggml_cuda_assign_buffers(layer.ffn_key); vram_total += ggml_nbytes(layer.ffn_key);
-        ggml_cuda_assign_buffers(layer.ffn_value); vram_total += ggml_nbytes(layer.ffn_value);
-        ggml_cuda_assign_buffers(layer.ffn_receptance); vram_total += ggml_nbytes(layer.ffn_receptance);
-
-        gpu_layers++;
+        offload(layer.att_key);
+        offload(layer.att_value);
+        offload(layer.att_receptance);
+        offload(layer.att_output);
+
+        offload(layer.ffn_key);
+        offload(layer.ffn_value);
+        offload(layer.ffn_receptance);
     }
 #endif

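A note on the design: the offload lambda centralizes the three steps that must stay in sync (set the backend, transform the tensor, account the VRAM), and binding the loop counter as a reference to ctx->gpu_layers makes the persistent per-context count the induction variable itself, so a second call resumes where the first stopped instead of re-offloading layers. A standalone sketch of that loop idiom, using a hypothetical Ctx type as a stand-in for rwkv_context:

    #include <cstdio>

    struct Ctx { size_t gpu_layers = 0; };  // hypothetical stand-in for rwkv_context

    static void offload_layers(Ctx & ctx, size_t n_gpu) {
        // `i` aliases ctx.gpu_layers: each increment is immediately visible in
        // the context, replacing the separate `gpu_layers++` bookkeeping.
        for (size_t & i = ctx.gpu_layers; i < n_gpu; i++) {
            printf("offloading layer %zu\n", i);
        }
    }

    int main(void) {
        Ctx ctx;
        offload_layers(ctx, 2);  // offloads layers 0 and 1
        offload_layers(ctx, 4);  // resumes at layer 2; nothing is offloaded twice
        printf("total offloaded: %zu\n", ctx.gpu_layers);  // prints 4
        return 0;
    }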
