Add rwkv_eval_sequence_in_chunks
saharNooby committed Sep 23, 2023
1 parent 8260553 commit 3caaa1d
Showing 9 changed files with 332 additions and 24 deletions.
5 changes: 2 additions & 3 deletions python/chat_with_bot.py
@@ -69,10 +69,9 @@
 def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
     global processed_tokens, logits, state
 
-    processed_tokens += _tokens
+    logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)
 
-    for _token in _tokens:
-        logits, state = model.eval(_token, state, state, logits, use_numpy=True)
+    processed_tokens += _tokens
 
     logits[END_OF_LINE_TOKEN] += new_line_logit_bias

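For illustration, the chunked call is a drop-in replacement for the per-token loop it removes: both should produce the same logits and final state, up to floating-point accumulation differences. A minimal sketch of the equivalence, assuming a model loaded through this repository's Python wrapper (the model path and token ids below are placeholders):

# Equivalence sketch: per-token loop vs. chunked evaluation.
import numpy as np

from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'path/to/model.bin')  # placeholder path

tokens = [100, 200, 300, 400, 500]  # arbitrary example token ids

# Old approach: one graph evaluation per token.
logits_a, state_a = None, None
for token in tokens:
    logits_a, state_a = model.eval(token, state_a, state_a, logits_a, use_numpy=True)

# New approach: a single chunked sequence evaluation.
logits_b, state_b = model.eval_sequence_in_chunks(tokens, None, None, None, use_numpy=True)

# Results should agree up to floating-point accumulation differences.
assert np.allclose(logits_a, logits_b, atol=1e-3)
assert np.allclose(state_a, state_b, atol=1e-3)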
5 changes: 1 addition & 4 deletions python/generate_completions.py
@@ -47,10 +47,7 @@
 prompt_token_count: int = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')
 
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits, use_numpy=True)
+init_logits, init_state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 for GENERATION in range(generation_count):
     print(f'\n--- Generation {GENERATION} ---\n')
7 changes: 1 addition & 6 deletions python/inference_example.py
@@ -25,12 +25,7 @@
 prompt_tokens: List[int] = tokenizer_encode(prompt)
 
 # Process the prompt.
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits, use_numpy=True)
-
-logits, state = init_logits.copy(), init_state.copy()
+logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 # Generate and print the completion.
 print(prompt, end='')
75 changes: 75 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_model.py
@@ -238,6 +238,81 @@ def eval_sequence

return logits_out, state_out

def eval_sequence_in_chunks(
self,
tokens: List[int],
state_in: Optional[NumpyArrayOrPyTorchTensor],
state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
chunk_size: int = 16,
use_numpy: bool = False
) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
"""
Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
and choose the one that works best in your use case.
In case of any error, this method will throw an exception.
Parameters
----------
tokens : List[int]
Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
state_in : Optional[NumpyArrayOrPyTorchTensor]
State from the previous call of this method. If this is the first pass, set it to None.
state_out : Optional[NumpyArrayOrPyTorchTensor]
Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
logits_out : Optional[NumpyArrayOrPyTorchTensor]
Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
chunk_size : int
Size of each chunk in tokens, must be positive.
use_numpy : bool
If set to True, numpy ndarrays will be created instead of PyTorch Tensors.
This parameter is ignored if any tensor parameter is not None; in that case,
the type of the returned tensors will match the type of the received tensors.
Returns
-------
logits, state
Logits vector of shape (n_vocab); state for the next step.
"""

assert self._valid, 'Model was freed'

use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

if state_in is not None:
self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)

state_in_ptr = self._get_data_ptr(state_in)
else:
state_in_ptr = 0

if state_out is not None:
self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
else:
state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)

if logits_out is not None:
self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
else:
logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)

self._library.rwkv_eval_sequence_in_chunks(
self._ctx,
tokens,
chunk_size,
state_in_ptr,
self._get_data_ptr(state_out),
self._get_data_ptr(logits_out)
)

return logits_out, state_out

def free(self) -> None:
"""
Frees all allocated resources.
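Conceptually, the chunked evaluation is equivalent to splitting the token list in Python and threading the state through `eval_sequence`; the actual chunk loop lives in the C library. A hypothetical pure-Python reference, assuming an already-loaded `RWKVModel`:

# Reference sketch of what chunking does; not the actual implementation.
from typing import List

def eval_in_chunks_reference(model, tokens: List[int], chunk_size: int = 16):
    logits, state = None, None
    for i in range(0, len(tokens), chunk_size):
        chunk = tokens[i:i + chunk_size]
        # The state produced by one chunk is fed into the next chunk.
        logits, state = model.eval_sequence(chunk, state, state, logits, use_numpy=True)
    return logits, state

Note that a trailing chunk shorter than chunk_size has a different sequence length, so it presumably triggers one extra graph build in `eval_sequence`, which caches graphs per sequence length.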
80 changes: 71 additions & 9 deletions python/rwkv_cpp/rwkv_cpp_shared_library.py
@@ -63,6 +63,17 @@ def __init__(self, shared_library_path: str) -> None:
]
self.library.rwkv_eval_sequence.restype = ctypes.c_bool

self.library.rwkv_eval_sequence_in_chunks.argtypes = [
ctypes.c_void_p, # ctx
P_INT, # tokens
ctypes.c_size_t, # token count
ctypes.c_size_t, # chunk size
P_FLOAT, # state_in
P_FLOAT, # state_out
P_FLOAT # logits_out
]
self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool

self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t

@@ -142,6 +153,43 @@ def rwkv_eval
) -> None:
"""
Evaluates the model for a single token.
Throws an exception in case of any error. Error messages will be printed to stderr.
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
token : int
Next token index, in range 0 <= token < n_vocab.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
logits_out_address : int
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

assert self.library.rwkv_eval(
ctx.ptr,
ctypes.c_int32(token),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
), 'rwkv_eval failed, check stderr'

def rwkv_eval_sequence(
self,
ctx: RWKVContext,
tokens: List[int],
state_in_address: Optional[int],
state_out_address: int,
logits_out_address: int
) -> None:
"""
Evaluates the model for a sequence of tokens.
Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
NOTE ON GGML NODE LIMIT
@@ -152,14 +200,15 @@ def rwkv_eval
If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
To get rid of the assertion failure, reduce the model size and/or sequence length.
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+Throws an exception in case of any error. Error messages will be printed to stderr.
Parameters
----------
ctx : RWKVContext
RWKV context obtained from rwkv_init_from_file.
-token : int
-Next token index, in range 0 <= token < n_vocab.
+tokens : List[int]
+Next token indices, in range 0 <= token < n_vocab.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
@@ -168,24 +217,34 @@
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

-assert self.library.rwkv_eval(
+assert self.library.rwkv_eval_sequence(
ctx.ptr,
-ctypes.c_int32(token),
+ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
+ctypes.c_size_t(len(tokens)),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
-), 'rwkv_eval failed, check stderr'
+), 'rwkv_eval_sequence failed, check stderr'

-def rwkv_eval_sequence(
+def rwkv_eval_sequence_in_chunks(
self,
ctx: RWKVContext,
tokens: List[int],
+chunk_size: int,
state_in_address: Optional[int],
state_out_address: int,
logits_out_address: int
) -> None:
"""
-Evaluates the model for a sequence of tokens.
+Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
+It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
+Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
+A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
+and choose the one that works best in your use case.
Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
Throws an exception in case of any error. Error messages will be printed to stderr.
Parameters
@@ -194,6 +253,8 @@
RWKV context obtained from rwkv_init_from_file.
tokens : List[int]
Next token indices, in range 0 <= token < n_vocab.
+chunk_size : int
+Size of each chunk in tokens, must be positive.
state_in_address : int
Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
state_out_address : int
@@ -202,14 +263,15 @@
Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
"""

-assert self.library.rwkv_eval_sequence(
+assert self.library.rwkv_eval_sequence_in_chunks(
ctx.ptr,
ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
ctypes.c_size_t(len(tokens)),
+ctypes.c_size_t(chunk_size),
ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
ctypes.cast(state_out_address, P_FLOAT),
ctypes.cast(logits_out_address, P_FLOAT)
-), 'rwkv_eval failed, check stderr'
+), 'rwkv_eval_sequence_in_chunks failed, check stderr'

def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
"""
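For completeness, a sketch of driving this low-level wrapper directly with numpy-backed buffers. It assumes the buffer-sizing helpers defined elsewhere in this file (`rwkv_get_state_buffer_element_count`, `rwkv_get_logits_buffer_element_count`) and an `rwkv_init_from_file(model_path, thread_count)` initializer; the model path and token ids are placeholders:

import numpy as np

from rwkv_cpp import rwkv_cpp_shared_library

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
ctx = library.rwkv_init_from_file('path/to/model.bin', 4)  # placeholder path, 4 threads

# Allocate output buffers of the sizes reported for this context.
state = np.zeros(library.rwkv_get_state_buffer_element_count(ctx), dtype=np.float32)
logits = np.zeros(library.rwkv_get_logits_buffer_element_count(ctx), dtype=np.float32)

library.rwkv_eval_sequence_in_chunks(
    ctx,
    [100, 200, 300],    # tokens (arbitrary example ids)
    16,                 # chunk_size
    None,               # state_in_address: None on the first pass
    state.ctypes.data,  # state_out_address
    logits.ctypes.data  # logits_out_address
)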
30 changes: 28 additions & 2 deletions rwkv.h
@@ -122,7 +122,7 @@ extern "C" {
);

// Evaluates the model for a sequence of tokens.
-// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
+// Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
//
// NOTE ON GGML NODE LIMIT
@@ -137,7 +137,7 @@ extern "C" {
// TODO When Metal (MPS) support is implemented, check that large sequence lengths work
//
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
-// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
+// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
@@ -153,6 +153,32 @@ extern "C" {
float * logits_out
);

// Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
// This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
// It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
//
// Chunking allows processing sequences of thousands of tokens, while not reaching ggml's node limit and not consuming too much memory.
// A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
// and choose the one that works best in your use case.
//
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - chunk_size: size of each chunk in tokens, must be positive.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence_in_chunks(
struct rwkv_context * ctx,
const uint32_t * tokens,
const size_t sequence_len,
const size_t chunk_size,
const float * state_in,
float * state_out,
float * logits_out
);

// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
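The advice to try chunk sizes in range [2..64] is easy to automate from the Python side. A hypothetical timing helper, assuming a loaded `RWKVModel`; each candidate gets a warm-up run first, because computation graphs are cached per sequence length:

import time

def pick_chunk_size(model, prompt_tokens, candidates=(2, 4, 8, 16, 32, 64)):
    # Returns the candidate chunk size with the fastest timed evaluation.
    best_size, best_time = None, float('inf')
    for chunk_size in candidates:
        # Warm-up run: builds and caches the graphs for this chunk size.
        model.eval_sequence_in_chunks(prompt_tokens, None, None, None, chunk_size=chunk_size, use_numpy=True)
        start = time.perf_counter()
        model.eval_sequence_in_chunks(prompt_tokens, None, None, None, chunk_size=chunk_size, use_numpy=True)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_size, best_time = chunk_size, elapsed
    return best_size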
