From f416df62c01581acec6e2fcd62a327fe59051be7 Mon Sep 17 00:00:00 2001 From: josc146 Date: Sun, 24 Mar 2024 21:43:26 +0800 Subject: [PATCH] expose model version api --- python/rwkv_cpp/rwkv_cpp_model.py | 8 ++++++ python/rwkv_cpp/rwkv_cpp_shared_library.py | 30 ++++++++++++++++++++++ rwkv.cpp | 10 ++++++++ rwkv.h | 6 +++++ 4 files changed, 54 insertions(+) diff --git a/python/rwkv_cpp/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py index 59dd304..3859dbb 100644 --- a/python/rwkv_cpp/rwkv_cpp_model.py +++ b/python/rwkv_cpp/rwkv_cpp_model.py @@ -94,6 +94,14 @@ def gpu_offload_layers(self, layer_count: int) -> bool: return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count) + @property + def arch_version_major(self) -> int: + return self._library.rwkv_get_arch_version_major(self._ctx) + + @property + def arch_version_minor(self) -> int: + return self._library.rwkv_get_arch_version_minor(self._ctx) + @property def n_vocab(self) -> int: return self._library.rwkv_get_n_vocab(self._ctx) diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py index 4c095a7..08eb0a3 100644 --- a/python/rwkv_cpp/rwkv_cpp_shared_library.py +++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -80,6 +80,12 @@ def __init__(self, shared_library_path: str) -> None: ] self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool + self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32 + + self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32 + self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t @@ -284,6 +290,30 @@ def rwkv_eval_sequence_in_chunks( ): raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr') + def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int: + """ + Returns the major version used by the given model. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_arch_version_major(ctx.ptr) + + def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int: + """ + Returns the minor version used by the given model. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_arch_version_minor(ctx.ptr) + def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: """ Returns the number of tokens in the given model's vocabulary. diff --git a/rwkv.cpp b/rwkv.cpp index 6fae152..7dc3e60 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -104,6 +104,16 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r return rwkv_get_logits_len(ctx); } +// API function. +size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx) { + return (size_t) ctx->model->arch_version_major; +} + +// API function. +size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx) { + return (size_t) ctx->model->arch_version_minor; +} + // API function. size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_vocab; diff --git a/rwkv.h b/rwkv.h index 40b9266..276e048 100644 --- a/rwkv.h +++ b/rwkv.h @@ -179,6 +179,12 @@ extern "C" { float * logits_out ); + // Returns the major version used by the given model. + RWKV_API size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx); + + // Returns the minor version used by the given model. + RWKV_API size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx); + // Returns the number of tokens in the given model's vocabulary. // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);