From d8f13ffe231712c11427b180cce2fed76757b38d Mon Sep 17 00:00:00 2001 From: whitealpa <132802748+whitealpa@users.noreply.github.com> Date: Sat, 2 Mar 2024 13:42:32 +0700 Subject: [PATCH] Fix: Replace all assertions in Python code with if statements (#167) * Fix: Replace all assertions in Python code with if statements Signed-off-by: whitealpa Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com> * Fixed typos Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com> --------- Signed-off-by: whitealpa Signed-off-by: whitealpa <132802748+whitealpa@users.noreply.github.com> Co-authored-by: whitealpa --- python/chat_with_bot.py | 3 +- python/generate_completions.py | 3 +- python/measure_pexplexity.py | 6 ++- python/merge_lora_into_ggml.py | 13 +++--- python/rwkv_cpp/rwkv_cpp_model.py | 49 +++++++++++++++------- python/rwkv_cpp/rwkv_cpp_shared_library.py | 33 +++++++++------ python/rwkv_cpp/rwkv_world_tokenizer.py | 3 +- python/sampling.py | 6 ++- python/tokenizer_util.py | 4 +- 9 files changed, 78 insertions(+), 42 deletions(-) diff --git a/python/chat_with_bot.py b/python/chat_with_bot.py index b7a630a..17fd1c1 100644 --- a/python/chat_with_bot.py +++ b/python/chat_with_bot.py @@ -50,7 +50,8 @@ user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt'] -assert init_prompt != '', 'Prompt must not be empty' +if init_prompt == '': + raise ValueError('Prompt must not be empty') library = rwkv_cpp_shared_library.load_rwkv_shared_library() print(f'System info: {library.rwkv_get_system_info_string()}') diff --git a/python/generate_completions.py b/python/generate_completions.py index 6f585ed..4685720 100644 --- a/python/generate_completions.py +++ b/python/generate_completions.py @@ -32,7 +32,8 @@ add_tokenizer_argument(parser) args = parser.parse_args() -assert prompt != '', 'Prompt must not be empty' +if prompt == '': + raise ValueError('Prompt must not be empty') library = rwkv_cpp_shared_library.load_rwkv_shared_library() print(f'System info: {library.rwkv_get_system_info_string()}') diff --git a/python/measure_pexplexity.py b/python/measure_pexplexity.py index 6f6674c..80fd7e5 100644 --- a/python/measure_pexplexity.py +++ b/python/measure_pexplexity.py @@ -40,14 +40,16 @@ def parse_args(): token_limit: int = args.token_limit -assert token_limit == -1 or token_limit > 0, 'Invalid token_limit' +if not (token_limit == -1 or token_limit > 0): + raise ValueError('Invalid token_limit') if token_limit != -1 and token_count > token_limit: tokens = tokens[0:token_limit] token_count = token_limit print(f'Text was limited to {token_limit} tokens') -assert token_count - args.ignore_first_n_tokens > 1, 'Need at least 2 tokens for evaluation' +if not (token_count - args.ignore_first_n_tokens > 1): + raise ValueError('Need at least 2 tokens for evaluation') # --- diff --git a/python/merge_lora_into_ggml.py b/python/merge_lora_into_ggml.py index 5754ddf..3988697 100644 --- a/python/merge_lora_into_ggml.py +++ b/python/merge_lora_into_ggml.py @@ -86,7 +86,8 @@ def main() -> None: print(f'* {key} {shape}') - assert data_type == 0 or data_type == 1, 'Only FP32 and FP16 models are supported' + if not (data_type == 0 or data_type == 1): + raise ValueError('Only FP32 and FP16 models are supported') element_count: int = 1 @@ -126,8 +127,9 @@ def main() -> None: if parameter.dtype == torch.float16: replacement = replacement.half() - assert replacement.shape == parameter.shape, f'Parameter {key} has shape {parameter.shape} in model file ' \ - f'and shape {replacement.shape} in LoRA file' + if replacement.shape != parameter.shape: + raise ValueError(f'Parameter {key} has shape {parameter.shape} in model file ' \ + f'and shape {replacement.shape} in LoRA file') parameter = replacement @@ -143,8 +145,9 @@ def main() -> None: lora_A: torch.Tensor = lora_state_dict[lora_A_key] lora_B: torch.Tensor = lora_state_dict[lora_B_key] - assert lora_B.shape[1] == lora_A.shape[0], f'Invalid shape of LoRA matrices for {key}: ' \ - f'{lora_A.shape}, {lora_B.shape}' + if lora_B.shape[1] != lora_A.shape[0]: + raise ValueError(f'Invalid shape of LoRA matrices for {key}: ' \ + f'{lora_A.shape}, {lora_B.shape}') lora_R: int = lora_B.shape[1] diff --git a/python/rwkv_cpp/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py index 4b78c76..59dd304 100644 --- a/python/rwkv_cpp/rwkv_cpp_model.py +++ b/python/rwkv_cpp/rwkv_cpp_model.py @@ -52,9 +52,14 @@ def __init__( if 'gpu_layers_count' in kwargs: gpu_layer_count = kwargs['gpu_layers_count'] - assert os.path.isfile(model_path), f'{model_path} is not a file' - assert thread_count > 0, 'Thread count must be > 0' - assert gpu_layer_count >= 0, 'GPU layer count must be >= 0' + if not os.path.isfile(model_path): + raise ValueError(f'{model_path} is not a file') + + if not (thread_count > 0): + raise ValueError('Thread count must be > 0') + + if not (gpu_layer_count >= 0): + raise ValueError('GPU layer count must be >= 0') self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library @@ -84,7 +89,8 @@ def gpu_offload_layers(self, layer_count: int) -> bool: Count of layers to offload onto the GPU, must be >= 0. """ - assert layer_count >= 0, 'Layer count must be >= 0' + if not (layer_count >= 0): + raise ValueError('Layer count must be >= 0') return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count) @@ -133,7 +139,8 @@ def eval( Logits vector of shape (n_vocab); state for the next step. """ - assert self._valid, 'Model was freed' + if not self._valid: + raise ValueError('Model was freed') use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) @@ -207,7 +214,8 @@ def eval_sequence( Logits vector of shape (n_vocab); state for the next step. """ - assert self._valid, 'Model was freed' + if not self._valid: + raise ValueError('Model was freed') use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) @@ -281,7 +289,8 @@ def eval_sequence_in_chunks( Logits vector of shape (n_vocab); state for the next step. """ - assert self._valid, 'Model was freed' + if not self._valid: + raise ValueError('Model was freed') use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) @@ -320,7 +329,8 @@ def free(self) -> None: The object must not be used anymore after calling this method. """ - assert self._valid, 'Already freed' + if not self._valid: + raise ValueError('Already freed') self._valid = False @@ -344,16 +354,25 @@ def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]] def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None: if self._is_pytorch_tensor(tensor): tensor: torch.Tensor = tensor - assert tensor.device == torch.device('cpu'), f'{name} is not on CPU' - assert tensor.dtype == torch.float32, f'{name} is not of type float32' - assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' - assert tensor.is_contiguous(), f'{name} is not contiguous' + + if tensor.device != torch.device('cpu'): + raise ValueError(f'{name} is not on CPU') + if tensor.dtype != torch.float32: + raise ValueError(f'{name} is not of type float32') + if tensor.shape != (size,): + raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})') + if not tensor.is_contiguous(): + raise ValueError(f'{name} is not contiguous') else: import numpy as np tensor: np.ndarray = tensor - assert tensor.dtype == np.float32, f'{name} is not of type float32' - assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' - assert tensor.data.contiguous, f'{name} is not contiguous' + + if tensor.dtype != np.float32: + raise ValueError(f'{name} is not of type float32') + if tensor.shape != (size,): + raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})') + if not tensor.data.contiguous: + raise ValueError(f'{name} is not contiguous') def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor): if self._is_pytorch_tensor(tensor): diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py index fe82673..4c095a7 100644 --- a/python/rwkv_cpp/rwkv_cpp_shared_library.py +++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -124,7 +124,8 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) - assert ptr is not None, 'rwkv_init_from_file failed, check stderr' + if ptr is None: + raise ValueError('rwkv_init_from_file failed, check stderr') return RWKVContext(ptr) @@ -145,7 +146,8 @@ def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool: Count of layers to offload onto the GPU, must be >= 0. """ - assert layer_count >= 0, 'Layer count must be >= 0' + if not (layer_count >= 0): + raise ValueError('Layer count must be >= 0') return self.library.rwkv_gpu_offload_layers(ctx.ptr, ctypes.c_uint32(layer_count)) @@ -176,13 +178,14 @@ def rwkv_eval( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval( + if not self.library.rwkv_eval( ctx.ptr, ctypes.c_int32(token), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval failed, check stderr' + ): + raise ValueError('rwkv_eval failed, check stderr') def rwkv_eval_sequence( self, @@ -223,14 +226,15 @@ def rwkv_eval_sequence( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval_sequence( + if not self.library.rwkv_eval_sequence( ctx.ptr, ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), ctypes.c_size_t(len(tokens)), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval_sequence failed, check stderr' + ): + raise ValueError('rwkv_eval_sequence failed, check stderr') def rwkv_eval_sequence_in_chunks( self, @@ -269,7 +273,7 @@ def rwkv_eval_sequence_in_chunks( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval_sequence_in_chunks( + if not self.library.rwkv_eval_sequence_in_chunks( ctx.ptr, ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), ctypes.c_size_t(len(tokens)), @@ -277,7 +281,8 @@ def rwkv_eval_sequence_in_chunks( ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval_sequence_in_chunks failed, check stderr' + ): + raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr') def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: """ @@ -373,13 +378,15 @@ def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: One of QUANTIZED_FORMAT_NAMES. """ - assert format_name in QUANTIZED_FORMAT_NAMES, f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}' + if format_name not in QUANTIZED_FORMAT_NAMES: + raise ValueError(f'Unknown format name {format_name}, use one of {QUANTIZED_FORMAT_NAMES}') - assert self.library.rwkv_quantize_model_file( + if not self.library.rwkv_quantize_model_file( model_file_path_in.encode('utf-8'), model_file_path_out.encode('utf-8'), format_name.encode('utf-8') - ), 'rwkv_quantize_model_file failed, check stderr' + ): + raise ValueError('rwkv_quantize_model_file failed, check stderr') def rwkv_get_system_info_string(self) -> str: """ @@ -439,5 +446,5 @@ def load_rwkv_shared_library() -> RWKVSharedLibrary: if os.path.isfile(full_path): return RWKVSharedLibrary(str(full_path)) - assert False, (f'Failed to find {file_name} automatically; ' - f'you need to find the library and create RWKVSharedLibrary specifying the path to it') + raise ValueError(f'Failed to find {file_name} automatically; ' + f'you need to find the library and create RWKVSharedLibrary specifying the path to it') \ No newline at end of file diff --git a/python/rwkv_cpp/rwkv_world_tokenizer.py b/python/rwkv_cpp/rwkv_world_tokenizer.py index ca864ef..e76d272 100644 --- a/python/rwkv_cpp/rwkv_world_tokenizer.py +++ b/python/rwkv_cpp/rwkv_world_tokenizer.py @@ -46,7 +46,8 @@ def find_longest(self, key: bytes, idx: int = 0) -> Tuple[int, 'Trie', set]: ch = key[idx] - assert ret is not None, 'Entry not found' + if ret is None: + raise ValueError('Entry not found') return ret diff --git a/python/sampling.py b/python/sampling.py index 0736638..7cc04b2 100644 --- a/python/sampling.py +++ b/python/sampling.py @@ -16,8 +16,10 @@ def sample_logits(out, temperature: float = 1.0, top_p: float = 0.8, logit_bias: return sample_probs(probs, temperature, top_p, logit_bias) def sample_probs(probs: np.ndarray, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int: - assert 0.0 <= temperature, 'temperature' - assert 0.0 <= top_p <= 1.0, 'top_p' + if not (0.0 <= temperature): + raise ValueError('temperature') + if not (0.0 <= top_p <= 1.0): + raise ValueError('top_p') if top_p == 0.0: top_p = 1.0 diff --git a/python/tokenizer_util.py b/python/tokenizer_util.py index 0471142..909bcb1 100644 --- a/python/tokenizer_util.py +++ b/python/tokenizer_util.py @@ -22,7 +22,7 @@ def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[ elif n_vocab == 65536: tokenizer_name = 'world' else: - assert False, f'Can not guess the tokenizer from n_vocab value of {n_vocab}' + raise ValueError(f'Can not guess the tokenizer from n_vocab value of {n_vocab}') parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent @@ -35,4 +35,4 @@ def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[ tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json')) return tokenizer.decode, lambda x: tokenizer.encode(x).ids else: - assert False, f'Unknown tokenizer {tokenizer_name}' + raise ValueError(f'Unknown tokenizer {tokenizer_name}')