Commit d8f13ff
Fix: Replace all assertions in Python code with if statements (#167)
* Fix: Replace all assertions in Python code with if statements

Signed-off-by: whitealpa <ben@air.local>
Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com>

* Fixed typos

Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com>

---------

Signed-off-by: whitealpa <ben@air.local>
Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com>
Co-authored-by: whitealpa <ben@air.local>
whitealpa committed Mar 2, 2024
1 parent 2a8735e commit d8f13ff
Showing 9 changed files with 78 additions and 42 deletions.
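For context on why such a change is typically made: assert statements are compiled away when Python runs with the -O flag, so they cannot be relied on to validate user input or library return values, and they raise the generic AssertionError rather than a specific, catchable exception. A minimal before/after sketch of the pattern this commit applies (illustrative only; the validate_prompt_* helpers are hypothetical, not code from this commit):

def validate_prompt_with_assert(prompt: str) -> None:
    # Stripped entirely under 'python -O', so an empty prompt
    # would slip through unnoticed.
    assert prompt != '', 'Prompt must not be empty'

def validate_prompt_with_raise(prompt: str) -> None:
    # Runs regardless of interpreter flags and raises a specific
    # exception type that callers can handle selectively.
    if prompt == '':
        raise ValueError('Prompt must not be empty')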
python/chat_with_bot.py (3 changes: 2 additions & 1 deletion)
@@ -50,7 +50,8 @@

user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']

-assert init_prompt != '', 'Prompt must not be empty'
+if init_prompt == '':
+    raise ValueError('Prompt must not be empty')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
python/generate_completions.py (3 changes: 2 additions & 1 deletion)
@@ -32,7 +32,8 @@
add_tokenizer_argument(parser)
args = parser.parse_args()

-assert prompt != '', 'Prompt must not be empty'
+if prompt == '':
+    raise ValueError('Prompt must not be empty')

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
python/measure_pexplexity.py (6 changes: 4 additions & 2 deletions)
@@ -40,14 +40,16 @@ def parse_args():

token_limit: int = args.token_limit

-assert token_limit == -1 or token_limit > 0, 'Invalid token_limit'
+if not (token_limit == -1 or token_limit > 0):
+    raise ValueError('Invalid token_limit')

if token_limit != -1 and token_count > token_limit:
    tokens = tokens[0:token_limit]
    token_count = token_limit
    print(f'Text was limited to {token_limit} tokens')

-assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation'
+if not (token_count - args.ignore_first_n_tokens > 1):
+    raise ValueError('Need at least 2 tokens for evaluation')

# ---

python/merge_lora_into_ggml.py (13 changes: 8 additions & 5 deletions)
@@ -86,7 +86,8 @@ def main() -> None:

print(f'* {key} {shape}')

-assert data_type == 0 or data_type == 1, 'Only FP32 and FP16 models are supported'
+if not (data_type == 0 or data_type == 1):
+    raise ValueError('Only FP32 and FP16 models are supported')

element_count: int = 1

@@ -126,8 +127,9 @@ def main() -> None:
if parameter.dtype == torch.float16:
    replacement = replacement.half()

-assert replacement.shape == parameter.shape, f'Parameter {key} has shape {parameter.shape} in model file ' \
-                                             f'and shape {replacement.shape} in LoRA file'
+if replacement.shape != parameter.shape:
+    raise ValueError(f'Parameter {key} has shape {parameter.shape} in model file ' \
+                     f'and shape {replacement.shape} in LoRA file')

parameter = replacement

@@ -143,8 +145,9 @@ def main() -> None:
lora_A: torch.Tensor = lora_state_dict[lora_A_key]
lora_B: torch.Tensor = lora_state_dict[lora_B_key]

-assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \
-                                           f'{lora_A.shape}, {lora_B.shape}'
+if lora_B.shape[1] != lora_A.shape[0]:
+    raise ValueError(f'Invalid shape of LoRA matrices for {key}: ' \
+                     f'{lora_A.shape}, {lora_B.shape}')

lora_R: int = lora_B.shape[1]

python/rwkv_cpp/rwkv_cpp_model.py (49 changes: 34 additions & 15 deletions)
@@ -52,9 +52,14 @@ def __init__(
if 'gpu_layers_count' in kwargs:
    gpu_layer_count = kwargs['gpu_layers_count']

-assert os.path.isfile(model_path), f'{model_path} is not a file'
-assert thread_count > 0, 'Thread count must be > 0'
-assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'
+if not os.path.isfile(model_path):
+    raise ValueError(f'{model_path} is not a file')
+
+if not (thread_count > 0):
+    raise ValueError('Thread count must be > 0')
+
+if not (gpu_layer_count >= 0):
+    raise ValueError('GPU layer count must be >= 0')

self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library

@@ -84,7 +89,8 @@ def gpu_offload_layers(self, layer_count: int) -> bool:
Count of layers to offload onto the GPU, must be >= 0.
"""

-assert layer_count >= 0, 'Layer count must be >= 0'
+if not (layer_count >= 0):
+    raise ValueError('Layer count must be >= 0')

return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)

@@ -133,7 +139,8 @@ def eval(
Logits vector of shape (n_vocab); state for the next step.
"""

-assert self._valid, 'Model was freed'
+if not self._valid:
+    raise ValueError('Model was freed')

use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

@@ -207,7 +214,8 @@ def eval_sequence(
Logits vector of shape (n_vocab); state for the next step.
"""

-assert self._valid, 'Model was freed'
+if not self._valid:
+    raise ValueError('Model was freed')

use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

@@ -281,7 +289,8 @@ def eval_sequence_in_chunks(
Logits vector of shape (n_vocab); state for the next step.
"""

-assert self._valid, 'Model was freed'
+if not self._valid:
+    raise ValueError('Model was freed')

use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

@@ -320,7 +329,8 @@ def free(self) -> None:
The object must not be used anymore after calling this method.
"""

-assert self._valid, 'Already freed'
+if not self._valid:
+    raise ValueError('Already freed')

self._valid = False

@@ -344,16 +354,25 @@ def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]]
def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
    if self._is_pytorch_tensor(tensor):
        tensor: torch.Tensor = tensor
-       assert tensor.device == torch.device('cpu'), f'{name} is not on CPU'
-       assert tensor.dtype == torch.float32, f'{name} is not of type float32'
-       assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
-       assert tensor.is_contiguous(), f'{name} is not contiguous'
+
+       if tensor.device != torch.device('cpu'):
+           raise ValueError(f'{name} is not on CPU')
+       if tensor.dtype != torch.float32:
+           raise ValueError(f'{name} is not of type float32')
+       if tensor.shape != (size,):
+           raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
+       if not tensor.is_contiguous():
+           raise ValueError(f'{name} is not contiguous')
    else:
        import numpy as np
        tensor: np.ndarray = tensor
-       assert tensor.dtype == np.float32, f'{name} is not of type float32'
-       assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})'
-       assert tensor.data.contiguous, f'{name} is not contiguous'
+
+       if tensor.dtype != np.float32:
+           raise ValueError(f'{name} is not of type float32')
+       if tensor.shape != (size,):
+           raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
+       if not tensor.data.contiguous:
+           raise ValueError(f'{name} is not contiguous')

def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
    if self._is_pytorch_tensor(tensor):
python/rwkv_cpp/rwkv_cpp_shared_library.py (33 changes: 20 additions & 13 deletions)
@@ -124,7 +124,8 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo

ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count))

-assert ptr is not None, 'rwkv_init_from_file failed, check stderr'
+if ptr is None:
+    raise ValueError('rwkv_init_from_file failed, check stderr')

return RWKVContext(ptr)

@@ -145,7 +146,8 @@ def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool:
Count of layers to offload onto the GPU, must be >= 0.
"""

-assert layer_count >= 0, 'Layer count must be >= 0'
+if not (layer_count >= 0):
+    raise ValueError('Layer count must be >= 0')

return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count))

@@ -176,13 +178,14 @@ def rwkv_eval(
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

-assert self.library.rwkv_eval(
+if not self.library.rwkv_eval(
    ctx.ptr,
    ctypes.c_int32(token),
    ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
    ctypes.cast(state_out_address, P_FLOAT),
    ctypes.cast(logits_out_address, P_FLOAT)
-), 'rwkv_eval failed, check stderr'
+):
+    raise ValueError('rwkv_eval failed, check stderr')

def rwkv_eval_sequence(
self,
@@ -223,14 +226,15 @@ def rwkv_eval_sequence(
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

-assert self.library.rwkv_eval_sequence(
+if not self.library.rwkv_eval_sequence(
    ctx.ptr,
    ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
    ctypes.c_size_t(len(tokens)),
    ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
    ctypes.cast(state_out_address, P_FLOAT),
    ctypes.cast(logits_out_address, P_FLOAT)
-), 'rwkv_eval_sequence failed, check stderr'
+):
+    raise ValueError('rwkv_eval_sequence failed, check stderr')

def rwkv_eval_sequence_in_chunks(
self,
@@ -269,15 +273,16 @@ def rwkv_eval_sequence_in_chunks(
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

-assert self.library.rwkv_eval_sequence_in_chunks(
+if not self.library.rwkv_eval_sequence_in_chunks(
    ctx.ptr,
    ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
    ctypes.c_size_t(len(tokens)),
    ctypes.c_size_t(chunk_size),
    ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
    ctypes.cast(state_out_address, P_FLOAT),
    ctypes.cast(logits_out_address, P_FLOAT)
-), 'rwkv_eval_sequence_in_chunks failed, check stderr'
+):
+    raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')

def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
"""
@@ -373,13 +378,15 @@ def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out:
One of QUANTIZED_FORMAT_NAMES.
"""

-assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}'
+if format_name not in QUANTIZED_FORMAT_NAMES:
+    raise ValueError(f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}')

-assert self.library.rwkv_quantize_model_file(
+if not self.library.rwkv_quantize_model_file(
    model_file_path_in.encode('utf-8'),
    model_file_path_out.encode('utf-8'),
    format_name.encode('utf-8')
-), 'rwkv_quantize_model_file failed, check stderr'
+):
+    raise ValueError('rwkv_quantize_model_file failed, check stderr')

def rwkv_get_system_info_string(self) -> str:
"""
@@ -439,5 +446,5 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary:
if os.path.isfile(full_path):
    return RWKVSharedLibrary(str(full_path))

-assert False, (f'Failed to find {file_name} automatically; '
-               f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
+raise ValueError(f'Failed to find {file_name} automatically; '
+                 f'you need to find the library and create RWKVSharedLibrary specifying the path to it')
python/rwkv_cpp/rwkv_world_tokenizer.py (3 changes: 2 additions & 1 deletion)
@@ -46,7 +46,8 @@ def find_longest(self, key: bytes, idx: int = 0) -> Tuple[int, 'Trie', set]:

ch = key[idx]

-assert ret is not None, 'Entry not found'
+if ret is None:
+    raise ValueError('Entry not found')

return ret

python/sampling.py (6 changes: 4 additions & 2 deletions)
@@ -16,8 +16,10 @@ def sample_logits(out, temperature: float = 1.0, top_p: float = 0.8, logit_bias:
return sample_probs(probs, temperature, top_p, logit_bias)

def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int:
-assert 0.0 <= temperature, 'temperature'
-assert 0.0 <= top_p <= 1.0, 'top_p'
+if not (0.0 <= temperature):
+    raise ValueError('temperature')
+if not (0.0 <= top_p <= 1.0):
+    raise ValueError('top_p')

if top_p == 0.0:
    top_p = 1.0
python/tokenizer_util.py (4 changes: 2 additions & 2 deletions)
@@ -22,7 +22,7 @@ def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
elif n_vocab == 65536:
    tokenizer_name = 'world'
else:
-    assert False, f'Can not guess the tokenizer from n_vocab value of {n_vocab}'
+    raise ValueError(f'Can not guess the tokenizer from n_vocab value of {n_vocab}')

parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

@@ -35,4 +35,4 @@ def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
    tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json'))
    return tokenizer.decode, lambda x: tokenizer.encode(x).ids
else:
-    assert False, f'Unknown tokenizer {tokenizer_name}'
+    raise ValueError(f'Unknown tokenizer {tokenizer_name}')
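
One practical effect of the change: validation and setup failures now surface as ValueError instead of AssertionError (which would vanish entirely under python -O). A hypothetical caller sketch, assuming rwkv_cpp_shared_library is importable the way the repository's own scripts use it; the try/except wrapper is not part of this commit:

import rwkv_cpp_shared_library

try:
    library = rwkv_cpp_shared_library.load_rwkv_shared_library()
except ValueError as e:
    # Raised, for example, when the shared library cannot be located
    # automatically (see load_rwkv_shared_library above).
    print(f'Failed to initialize rwkv.cpp: {e}')
    raise SystemExit(1)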
