diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 4f089ad..70c4258 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -32,14 +32,14 @@ def __init__( assert os.path.isfile(model_path), f'{model_path} is not a file' assert thread_count > 0, 'Thread count must be positive' - self.library = shared_library + self._library = shared_library - self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) - self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) - self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) + self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx) + self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx) - self.valid = True + self._valid = True def eval( self, @@ -69,7 +69,7 @@ def eval( Logits vector of shape (n_vocab); state for the next step. """ - assert self.valid, 'Model was freed' + assert self._valid, 'Model was freed' def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: assert buf.dtype == torch.float32, f'{name} is not of type float32' @@ -77,24 +77,24 @@ def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' if state_in is not None: - validate_buffer(state_in, 'state_in', self.state_buffer_element_count) + validate_buffer(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.storage().data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_buffer(state_out, 'state_out', self.state_buffer_element_count) + validate_buffer(state_out, 'state_out', self._state_buffer_element_count) else: - state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') + state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) + validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count) else: - logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') + logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') - self.library.rwkv_eval( - self.ctx, + self._library.rwkv_eval( + self._ctx, token, state_in_ptr, state_out.storage().data_ptr(), @@ -110,8 +110,13 @@ def free(self): The object must not be used anymore after calling this method. """ - assert self.valid, 'Already freed' + assert self._valid, 'Already freed' - self.valid = False + self._valid = False - self.library.rwkv_free(self.ctx) + self._library.rwkv_free(self._ctx) + + def __del__(self): + # Free the context on GC in case user forgot to call free() explicitly. + if self._valid: + self.free()