Skip to content

Commit

Permalink
Add function to offload layers
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Sep 23, 2023
1 parent d6a9b24 commit 8260553
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions python/rwkv_cpp/rwkv_cpp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
Thread count to use. If not set, defaults to CPU count / 2.
gpu_layer_count : int
Count of layers to offload onto the GPU, must be >= 0.
See documentation of `RWKVSharedLibrary.rwkv_gpu_offload_layers` for details about layer offloading.
See documentation of `gpu_offload_layers` for details about layer offloading.
"""

if 'gpu_layers_count' in kwargs:
Expand All @@ -61,13 +61,33 @@ def __init__(
self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)

if gpu_layer_count > 0:
self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layer_count)
self.gpu_offload_layers(gpu_layer_count)

self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)

self._valid: bool = True

def gpu_offload_layers(self, layer_count: int) -> bool:
"""
Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
- pass `model.n_layer` to offload all layers except model head
- pass `model.n_layer + 1` to offload all layers, including model head
Returns true if at least one layer was offloaded.
If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
Parameters
----------
layer_count : int
Count of layers to offload onto the GPU, must be >= 0.
"""

assert layer_count >= 0, 'Layer count must be >= 0'

return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)

@property
def n_vocab(self) -> int:
return self._library.rwkv_get_n_vocab(self._ctx)
Expand Down

0 comments on commit 8260553

Please sign in to comment.