Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose model version api #170

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_model.py
Expand Up @@ -94,6 +94,14 @@ def gpu_offload_layers(self, layer_count: int) -> bool:

return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)

@property
def arch_version_major(self) -> int:
return self._library.rwkv_get_arch_version_major(self._ctx)

@property
def arch_version_minor(self) -> int:
return self._library.rwkv_get_arch_version_minor(self._ctx)

@property
def n_vocab(self) -> int:
return self._library.rwkv_get_n_vocab(self._ctx)
Expand Down
30 changes: 30 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_shared_library.py
Expand Up @@ -80,6 +80,12 @@ def __init__(self, shared_library_path: str) -> None:
]
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool

self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32

self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32

self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t

Expand Down Expand Up @@ -284,6 +290,30 @@ def rwkv_eval_sequence_in_chunks(
):
raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr')

def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int:
"""
Returns the major version used by the given model.

Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""

return self.library.rwkv_get_arch_version_major(ctx.ptr)

def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int:
"""
Returns the minor version used by the given model.

Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
"""

return self.library.rwkv_get_arch_version_minor(ctx.ptr)

def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
"""
Returns the number of tokens in the given model's vocabulary.
Expand Down
10 changes: 10 additions & 0 deletions rwkv.cpp
Expand Up @@ -104,6 +104,16 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r
return rwkv_get_logits_len(ctx);
}

// API function.
size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx) {
return (size_t) ctx->model->arch_version_major;
}

// API function.
size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx) {
return (size_t) ctx->model->arch_version_minor;
}

// API function.
size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
return (size_t) ctx->model->header.n_vocab;
Expand Down
6 changes: 6 additions & 0 deletions rwkv.h
Expand Up @@ -179,6 +179,12 @@ extern "C" {
float * logits_out
);

// Returns the major version used by the given model.
RWKV_API size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx);

// Returns the minor version used by the given model.
RWKV_API size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx);

// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
Expand Down