diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index 9466990..718b697 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -90,6 +90,8 @@ def __init__(self, shared_library_path: str): self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p + self.nullptr = ctypes.cast(0, ctypes.c_void_p) + def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: """ Loads the model from a file and prepares it for inference. @@ -232,7 +234,7 @@ def rwkv_free(self, ctx: RWKVContext) -> None: self.library.rwkv_free(ctx.ptr) - ctx.ptr = ctypes.cast(0, ctypes.c_void_p) + ctx.ptr = self.nullptr def rwkv_quantize_model_file(self, model_file_path_in: str, model_file_path_out: str, format_name: str) -> None: """