From 196c7eec09b3231f15103a65402e0aea505219b1 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 21 Jun 2023 19:58:12 +0400 Subject: [PATCH] Fix copy-pasted tensor validation --- rwkv/rwkv_cpp_model.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 97e8209..1d166ff 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -82,26 +82,20 @@ def eval( assert self._valid, 'Model was freed' - def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: - assert buf.device == torch.device('cpu'), f'{name} is not on CPU' - assert buf.dtype == torch.float32, f'{name} is not of type float32' - assert buf.is_contiguous(), f'{name} is not contiguous' - assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' - if state_in is not None: - validate_buffer(state_in, 'state_in', self._state_buffer_element_count) + validate_tensor(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_buffer(state_out, 'state_out', self._state_buffer_element_count) + validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count) + validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') @@ -145,25 +139,20 @@ def eval_sequence( assert self._valid, 'Model was freed' - def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: - assert buf.dtype == torch.float32, f'{name} is not of type float32' - assert buf.is_contiguous(), f'{name} is not contiguous' - assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' - if state_in is not None: - validate_buffer(state_in, 'state_in', self._state_buffer_element_count) + validate_tensor(state_in, 'state_in', self._state_buffer_element_count) state_in_ptr = state_in.data_ptr() else: state_in_ptr = 0 if state_out is not None: - validate_buffer(state_out, 'state_out', self._state_buffer_element_count) + validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') if logits_out is not None: - validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count) + validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') @@ -194,3 +183,9 @@ def __del__(self): # Free the context on GC in case user forgot to call free() explicitly. if hasattr(self, '_valid') and self._valid: self.free() + +def validate_tensor(buf: torch.Tensor, name: str, size: int) -> None: + assert buf.device == torch.device('cpu'), f'{name} is not on CPU' + assert buf.dtype == torch.float32, f'{name} is not of type float32' + assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' + assert buf.is_contiguous(), f'{name} is not contiguous'