diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 9d6b9d3..97e8209 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -15,6 +15,7 @@ def __init__( model_path: str, thread_count: int = max(1, multiprocessing.cpu_count() // 2), gpu_layer_count: int = 0, + **kwargs ): """ Loads the model and prepares it for inference. @@ -32,6 +33,9 @@ def __init__( Count of layers to offload onto the GPU, must be >= 0. """ + if 'gpu_layers_count' in kwargs: + gpu_layer_count = kwargs['gpu_layers_count'] + assert os.path.isfile(model_path), f'{model_path} is not a file' assert thread_count > 0, 'Thread count must be > 0' assert gpu_layer_count >= 0, 'GPU layer count must be >= 0'