Skip to content

Commit

Permalink
Fix copy-pasted tensor validation
Browse files Browse the repository at this point in the history
  • Loading branch information
saharNooby committed Jun 21, 2023
1 parent a226519 commit 196c7ee
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions rwkv/rwkv_cpp_model.py
Expand Up @@ -82,26 +82,20 @@ def eval(

assert self._valid, 'Model was freed'

def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
assert buf.device == torch.device('cpu'), f'{name} is not on CPU'
assert buf.dtype == torch.float32, f'{name} is not of type float32'
assert buf.is_contiguous(), f'{name} is not contiguous'
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'

if state_in is not None:
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
validate_tensor(state_in, 'state_in', self._state_buffer_element_count)

state_in_ptr = state_in.data_ptr()
else:
state_in_ptr = 0

if state_out is not None:
validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')

if logits_out is not None:
validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')

Expand Down Expand Up @@ -145,25 +139,20 @@ def eval_sequence(

assert self._valid, 'Model was freed'

def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
assert buf.dtype == torch.float32, f'{name} is not of type float32'
assert buf.is_contiguous(), f'{name} is not contiguous'
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'

if state_in is not None:
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
validate_tensor(state_in, 'state_in', self._state_buffer_element_count)

state_in_ptr = state_in.data_ptr()
else:
state_in_ptr = 0

if state_out is not None:
validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')

if logits_out is not None:
validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')

Expand Down Expand Up @@ -194,3 +183,9 @@ def __del__(self):
# Free the context on GC in case user forgot to call free() explicitly.
if hasattr(self, '_valid') and self._valid:
self.free()

def validate_tensor(buf: torch.Tensor, name: str, size: int) -> None:
assert buf.device == torch.device('cpu'), f'{name} is not on CPU'
assert buf.dtype == torch.float32, f'{name} is not of type float32'
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
assert buf.is_contiguous(), f'{name} is not contiguous'

0 comments on commit 196c7ee

Please sign in to comment.