diff --git a/rwkv.cpp b/rwkv.cpp index c3b5443..28a1da3 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -1717,15 +1717,15 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r return rwkv_get_logits_len(ctx); } -size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { +extern "C" RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { return (size_t) ctx->instance->model.header.n_vocab; } -size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { +extern "C" RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx) { return (size_t) ctx->instance->model.header.n_embed; } -size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { +extern "C" RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx) { return (size_t) ctx->instance->model.header.n_layer; } diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 1d166ff..0a8a842 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -52,6 +52,19 @@ def __init__( self._valid = True + @property + def n_vocab(self): + return self._library.rwkv_get_n_vocab(self._ctx) + + @property + def n_embed(self): + return self._library.rwkv_get_n_embed(self._ctx) + + @property + def n_layer(self): + return self._library.rwkv_get_n_layer(self._ctx) + + def eval( self, token: int, diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index ac6b722..9466990 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -63,6 +63,15 @@ def __init__(self, shared_library_path: str): ] self.library.rwkv_eval_sequence.restype = ctypes.c_bool + self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t + + self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_embed.restype = ctypes.c_size_t + + self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_n_layer.restype = ctypes.c_size_t + self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p] self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32 @@ -255,6 +264,27 @@ def rwkv_get_system_info_string(self) -> str: return self.library.rwkv_get_system_info_string().decode('utf-8') + def rwkv_get_n_embed(self, ctx: RWKVContext) -> int: + """ + Returns the size of one embedding vector. + """ + + return self.library.rwkv_get_n_embed(ctx.ptr) + + def rwkv_get_n_layer(self, ctx: RWKVContext) -> int: + """ + Returns the number of layers. + """ + + return self.library.rwkv_get_n_layer(ctx.ptr) + + def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: + """ + Returns vocab size. + """ + + return self.library.rwkv_get_n_vocab(ctx.ptr) + def load_rwkv_shared_library() -> RWKVSharedLibrary: """