Skip to content

Commit

Permalink
Expose n_vocab, n_embed, n_layer to the Python interface (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
mczk77 committed Jul 18, 2023
1 parent 84634c0 commit 25ee75e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rwkv.cpp
Expand Up @@ -1717,15 +1717,15 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r
return rwkv_get_logits_len(ctx);
}

size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
extern "C" RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) {
return (size_t) ctx->instance->model.header.n_vocab;
}

size_t rwkv_get_n_embed(const struct rwkv_context * ctx) {
extern "C" RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx) {
return (size_t) ctx->instance->model.header.n_embed;
}

size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
extern "C" RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx) {
return (size_t) ctx->instance->model.header.n_layer;
}

Expand Down
13 changes: 13 additions & 0 deletions rwkv/rwkv_cpp_model.py
Expand Up @@ -52,6 +52,19 @@ def __init__(

self._valid = True

@property
def n_vocab(self):
return self._library.rwkv_get_n_vocab(self._ctx)

@property
def n_embed(self):
return self._library.rwkv_get_n_embed(self._ctx)

@property
def n_layer(self):
return self._library.rwkv_get_n_layer(self._ctx)


def eval(
self,
token: int,
Expand Down
30 changes: 30 additions & 0 deletions rwkv/rwkv_cpp_shared_library.py
Expand Up @@ -63,6 +63,15 @@ def __init__(self, shared_library_path: str):
]
self.library.rwkv_eval_sequence.restype = ctypes.c_bool

self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t

self.library.rwkv_get_n_embed.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_embed.restype = ctypes.c_size_t

self.library.rwkv_get_n_layer.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_layer.restype = ctypes.c_size_t

self.library.rwkv_get_state_buffer_element_count.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_state_buffer_element_count.restype = ctypes.c_uint32

Expand Down Expand Up @@ -255,6 +264,27 @@ def rwkv_get_system_info_string(self) -> str:

return self.library.rwkv_get_system_info_string().decode('utf-8')

def rwkv_get_n_embed(self, ctx: RWKVContext) -> int:
"""
Returns the size of one embedding vector.
"""

return self.library.rwkv_get_n_embed(ctx.ptr)

def rwkv_get_n_layer(self, ctx: RWKVContext) -> int:
"""
Returns the number of layers.
"""

return self.library.rwkv_get_n_layer(ctx.ptr)

def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
"""
Returns vocab size.
"""

return self.library.rwkv_get_n_vocab(ctx.ptr)


def load_rwkv_shared_library() -> RWKVSharedLibrary:
"""
Expand Down

0 comments on commit 25ee75e

Please sign in to comment.