Skip to content

Commit

Permalink
Free ggml context when model is garbage collected
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Apr 5, 2023
1 parent 10b71d7 commit 1992a37
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions rwkv/rwkv_cpp_model.py
Expand Up @@ -32,14 +32,14 @@ def __init__(
assert os.path.isfile(model_path), f'{model_path} is not a file'
assert thread_count > 0, 'Thread count must be positive'

self.library = shared_library
self._library = shared_library

self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)

self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)

self.valid = True
self._valid = True

def eval(
self,
Expand Down Expand Up @@ -69,32 +69,32 @@ def eval(
Logits vector of shape (n_vocab); state for the next step.
"""

assert self.valid, 'Model was freed'
assert self._valid, 'Model was freed'

def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
assert buf.dtype == torch.float32, f'{name} is not of type float32'
assert buf.is_contiguous(), f'{name} is not contiguous'
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'

if state_in is not None:
validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)

state_in_ptr = state_in.storage().data_ptr()
else:
state_in_ptr = 0

if state_out is not None:
validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')

if logits_out is not None:
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')

self.library.rwkv_eval(
self.ctx,
self._library.rwkv_eval(
self._ctx,
token,
state_in_ptr,
state_out.storage().data_ptr(),
Expand All @@ -110,8 +110,13 @@ def free(self):
The object must not be used anymore after calling this method.
"""

assert self.valid, 'Already freed'
assert self._valid, 'Already freed'

self.valid = False
self._valid = False

self.library.rwkv_free(self.ctx)
self._library.rwkv_free(self._ctx)

def __del__(self):
    """Free the native context on garbage collection if free() was not called.

    This is a safety net for callers that forget to call free() explicitly.
    It does nothing when the model was already freed or was never fully
    initialized.
    """
    # Python runs __del__ even when __init__ raised part-way through. In that
    # case _valid was never assigned (it is set as the last step of __init__),
    # so read it defensively instead of raising AttributeError during GC.
    if getattr(self, '_valid', False):
        self.free()

0 comments on commit 1992a37

Please sign in to comment.