diff --git a/python/chat_with_bot.py b/python/chat_with_bot.py
index e6800ad..b7a630a 100644
--- a/python/chat_with_bot.py
+++ b/python/chat_with_bot.py
@@ -69,10 +69,9 @@
 def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
     global processed_tokens, logits, state
 
-    processed_tokens += _tokens
+    logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)
 
-    for _token in _tokens:
-        logits, state = model.eval(_token, state, state, logits, use_numpy=True)
+    processed_tokens += _tokens
 
     logits[END_OF_LINE_TOKEN] += new_line_logit_bias
diff --git a/python/generate_completions.py b/python/generate_completions.py
index 38a9868..6f585ed 100644
--- a/python/generate_completions.py
+++ b/python/generate_completions.py
@@ -47,10 +47,7 @@
 prompt_token_count: int = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')
 
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits, use_numpy=True)
+init_logits, init_state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 for GENERATION in range(generation_count):
     print(f'\n--- Generation {GENERATION} ---\n')
diff --git a/python/inference_example.py b/python/inference_example.py
index 99b04fe..efd0016 100644
--- a/python/inference_example.py
+++ b/python/inference_example.py
@@ -25,12 +25,7 @@
 prompt_tokens: List[int] = tokenizer_encode(prompt)
 
 # Process the prompt.
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits, use_numpy=True)
-
-logits, state = init_logits.copy(), init_state.copy()
+logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 # Generate and print the completion.
 print(prompt, end='')
diff --git a/python/rwkv_cpp/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py
index 642bcd9..4b78c76 100644
--- a/python/rwkv_cpp/rwkv_cpp_model.py
+++ b/python/rwkv_cpp/rwkv_cpp_model.py
@@ -238,6 +238,81 @@ def eval_sequence(
 
         return logits_out, state_out
 
+    def eval_sequence_in_chunks(
+        self,
+        tokens: List[int],
+        state_in: Optional[NumpyArrayOrPyTorchTensor],
+        state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
+        chunk_size: int = 16,
+        use_numpy: bool = False
+    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
+        """
+        Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+        This function is useful for processing complete prompts and user input in chat & role-playing use cases.
+        It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
+
+        Chunking allows processing sequences of thousands of tokens while staying under ggml's node limit and not consuming too much memory.
+        A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
+        and choose the one that works best for your use case.
+
+        In case of any error, this method will throw an exception.
+
+        Parameters
+        ----------
+        tokens : List[int]
+            Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
+        state_in : Optional[NumpyArrayOrPyTorchTensor]
+            State from a previous call of this method.
+            If this is a first pass, set it to None.
+        state_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
+        logits_out : Optional[NumpyArrayOrPyTorchTensor]
+            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
+        chunk_size : int
+            Size of each chunk in tokens, must be positive.
+        use_numpy : bool
+            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
+            This parameter is ignored if any tensor parameter is not None; in that case,
+            the type of the returned tensors will match the type of the received tensors.
+
+        Returns
+        -------
+        logits, state
+            Logits vector of shape (n_vocab); state for the next step.
+        """
+
+        assert self._valid, 'Model was freed'
+
+        use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
+
+        if state_in is not None:
+            self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
+
+            state_in_ptr = self._get_data_ptr(state_in)
+        else:
+            state_in_ptr = 0
+
+        if state_out is not None:
+            self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
+        else:
+            state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
+
+        if logits_out is not None:
+            self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
+        else:
+            logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
+
+        self._library.rwkv_eval_sequence_in_chunks(
+            self._ctx,
+            tokens,
+            chunk_size,
+            state_in_ptr,
+            self._get_data_ptr(state_out),
+            self._get_data_ptr(logits_out)
+        )
+
+        return logits_out, state_out
+
     def free(self) -> None:
         """
         Frees all allocated resources.
diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py
index 12625c0..3b94fde 100644
--- a/python/rwkv_cpp/rwkv_cpp_shared_library.py
+++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py
@@ -63,6 +63,17 @@ def __init__(self, shared_library_path: str) -> None:
         ]
         self.library.rwkv_eval_sequence.restype = ctypes.c_bool
 
+        self.library.rwkv_eval_sequence_in_chunks.argtypes = [
+            ctypes.c_void_p, # ctx
+            P_INT,           # tokens
+            ctypes.c_size_t, # token count
+            ctypes.c_size_t, # chunk size
+            P_FLOAT,         # state_in
+            P_FLOAT,         # state_out
+            P_FLOAT          # logits_out
+        ]
+        self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool
+
         self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p]
         self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t
 
@@ -142,6 +153,43 @@ def rwkv_eval(
     ) -> None:
         """
         Evaluates the model for a single token.
+        Throws an exception in case of any error. Error messages will be printed to stderr.
+        Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+
+        Parameters
+        ----------
+        ctx : RWKVContext
+            RWKV context obtained from rwkv_init_from_file.
+        token : int
+            Next token index, in range 0 <= token < n_vocab.
+        state_in_address : int
+            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
+        state_out_address : int
+            Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
+        logits_out_address : int
+            Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
+ """ + + assert self.library.rwkv_eval( + ctx.ptr, + ctypes.c_int32(token), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval failed, check stderr' + + def rwkv_eval_sequence( + self, + ctx: RWKVContext, + tokens: List[int], + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a sequence of tokens. + Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so. + Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length. NOTE ON GGML NODE LIMIT @@ -152,14 +200,15 @@ def rwkv_eval( If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. To get rid of the assertion failure, reduce the model size and/or sequence length. + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters ---------- ctx : RWKVContext RWKV context obtained from rwkv_init_from_file. - token : int - Next token index, in range 0 <= token < n_vocab. + tokens : List[int] + Next token indices, in range 0 <= token < n_vocab. state_in_address : int Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. state_out_address : int @@ -168,24 +217,34 @@ def rwkv_eval( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval( + assert self.library.rwkv_eval_sequence( ctx.ptr, - ctypes.c_int32(token), + ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), + ctypes.c_size_t(len(tokens)), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval failed, check stderr' + ), 'rwkv_eval_sequence failed, check stderr' - def rwkv_eval_sequence( + def rwkv_eval_sequence_in_chunks( self, ctx: RWKVContext, tokens: List[int], + chunk_size: int, state_in_address: Optional[int], state_out_address: int, logits_out_address: int ) -> None: """ - Evaluates the model for a sequence of tokens. + Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance. + + Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + and choose one that works the best in your use case. + + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters @@ -194,6 +253,8 @@ def rwkv_eval_sequence( RWKV context obtained from rwkv_init_from_file. 
         tokens : List[int]
             Next token indices, in range 0 <= token < n_vocab.
+        chunk_size : int
+            Size of each chunk in tokens, must be positive.
         state_in_address : int
             Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass.
         state_out_address : int
@@ -202,14 +263,15 @@
             Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
         """
 
-        assert self.library.rwkv_eval_sequence(
+        assert self.library.rwkv_eval_sequence_in_chunks(
             ctx.ptr,
             ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT),
             ctypes.c_size_t(len(tokens)),
+            ctypes.c_size_t(chunk_size),
             ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT),
             ctypes.cast(state_out_address, P_FLOAT),
             ctypes.cast(logits_out_address, P_FLOAT)
-        ), 'rwkv_eval failed, check stderr'
+        ), 'rwkv_eval_sequence_in_chunks failed, check stderr'
 
     def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int:
         """
diff --git a/rwkv.h b/rwkv.h
index 884e1aa..40b9266 100644
--- a/rwkv.h
+++ b/rwkv.h
@@ -122,7 +122,7 @@ extern "C" {
     );
 
     // Evaluates the model for a sequence of tokens.
-    // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
+    // Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
    // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
     //
     // NOTE ON GGML NODE LIMIT
@@ -137,7 +137,7 @@
     // TODO When Metal (MPS) support is implemented, check that large sequence lengths work
     //
     // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
-    // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
+    // Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
     // Returns false on any error.
     // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
     // - sequence_len: number of tokens to read from the array.
@@ -153,6 +153,32 @@
         float * logits_out
     );
 
+    // Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
+    // This function is useful for processing complete prompts and user input in chat & role-playing use cases.
+    // It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
+    //
+    // Chunking allows processing sequences of thousands of tokens while staying under ggml's node limit and not consuming too much memory.
+    // A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
+    // and choose the one that works best for your use case.
+    //
+    // Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
+    // Returns false on any error.
+    // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
+    // - sequence_len: number of tokens to read from the array.
+    // - chunk_size: size of each chunk in tokens, must be positive.
+    // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
+    // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
+    // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
+    RWKV_API bool rwkv_eval_sequence_in_chunks(
+        struct rwkv_context * ctx,
+        const uint32_t * tokens,
+        const size_t sequence_len,
+        const size_t chunk_size,
+        const float * state_in,
+        float * state_out,
+        float * logits_out
+    );
+
     // Returns the number of tokens in the given model's vocabulary.
     // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
     RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
diff --git a/rwkv_eval.inc b/rwkv_eval.inc
index 38bbb33..37f3a9c 100644
--- a/rwkv_eval.inc
+++ b/rwkv_eval.inc
@@ -69,6 +69,17 @@ bool rwkv_eval_sequence(
     RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0");
 
+    if (sequence_len == 1) {
+        // Avoid building a single-token sequence graph: we already have the regular eval for this.
+        return rwkv_eval(
+            ctx,
+            sequence[0],
+            state_in,
+            state_out,
+            logits_out
+        );
+    }
+
     if (sequence) {
         const size_t n_vocab = ctx->model->header.n_vocab;
 
@@ -97,6 +108,72 @@
     return true;
 }
 
+// API function.
+bool rwkv_eval_sequence_in_chunks(
+    struct rwkv_context * ctx,
+    const uint32_t * tokens,
+    const size_t sequence_len,
+    const size_t chunk_size,
+    const float * state_in,
+    float * state_out,
+    float * logits_out
+) {
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0");
+    RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, chunk_size > 0, "Chunk size is 0");
+
+    // Will be deallocated automatically on return.
+    std::unique_ptr<float[]> state{ new(std::nothrow) float[rwkv_get_state_len(ctx)] };
+
+    if (state_in != NULL) {
+        memcpy(state.get(), state_in, rwkv_get_state_len(ctx) * sizeof(float));
+    } else {
+        rwkv_init_state(ctx, state.get());
+    }
+
+    size_t chunk_count = sequence_len / chunk_size;
+    size_t remainder = sequence_len % chunk_size;
+    uint32_t * tokens_offset = (uint32_t *) tokens;
+
+    for (size_t c = 0; c < chunk_count; c++) {
+        bool is_last_eval = c == chunk_count - 1 && remainder == 0;
+
+        bool result = rwkv_eval_sequence(
+            ctx,
+            tokens_offset,
+            chunk_size,
+            state.get(),
+            // On the last eval call, copy the state into the user-provided buffer.
+            is_last_eval ? state_out : state.get(),
+            // If this is not the last call, we have no use for the logits and can skip their calculation.
+            is_last_eval ? logits_out : NULL
+        );
+
+        if (!result) {
+            return false;
+        }
+
+        tokens_offset += chunk_size;
+    }
+
+    if (remainder > 0) {
+        bool result = rwkv_eval_sequence(
+            ctx,
+            tokens_offset,
+            remainder,
+            state.get(),
+            // This eval call is always the last.
+            state_out,
+            logits_out
+        );
+
+        if (!result) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // API function.
 void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
     const struct rwkv_file_header & header = ctx->model->header;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 8705d94..cb32fcd 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -24,4 +24,5 @@ rwkv_add_test(test_quantized_matmul_on_gpu.c)
 rwkv_add_test(test_tiny_rwkv.c)
 rwkv_add_test(test_quantization_format_compatibility.c)
 rwkv_add_test(test_logit_calculation_skipping.c)
+rwkv_add_test(test_eval_sequence_in_chunks.c)
 rwkv_add_test(test_context_cloning.c)
diff --git a/tests/test_eval_sequence_in_chunks.c b/tests/test_eval_sequence_in_chunks.c
new file mode 100644
index 0000000..7c3b10a
--- /dev/null
+++ b/tests/test_eval_sequence_in_chunks.c
@@ -0,0 +1,76 @@
+// Tests that eval_sequence_in_chunks gives results equivalent to serial eval.
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <rwkv.h>
+
+#include "assertions.inc"
+
+void test_on_prompt(const char * prompt, const size_t prompt_length) {
+    fprintf(stderr, "Calculating expected state and logits for prompt of size %zu\n", prompt_length);
+
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+
+    ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL));
+
+    float * expected_state = calloc(rwkv_get_state_len(ctx), sizeof(float));
+    float * expected_logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));
+
+    ASSERT(expected_state != NULL, "Failed to allocate state");
+    ASSERT(expected_logits != NULL, "Failed to allocate logits");
+
+    rwkv_eval(ctx, prompt[0], NULL, expected_state, expected_logits);
+
+    for (size_t i = 1; prompt[i] != 0; i++) {
+        rwkv_eval(ctx, prompt[i], expected_state, expected_state, expected_logits);
+    }
+
+    // ---
+
+    uint32_t * prompt_tokens = calloc(prompt_length, sizeof(uint32_t));
+
+    for (size_t i = 0; i < prompt_length; i++) {
+        prompt_tokens[i] = prompt[i];
+    }
+
+    // ---
+
+    float * state = calloc(rwkv_get_state_len(ctx), sizeof(float));
+    float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));
+
+    ASSERT(state != NULL, "Failed to allocate state");
+    ASSERT(logits != NULL, "Failed to allocate logits");
+
+    const size_t chunk_sizes[4] = {1, 2, 8, 10};
+
+    for (int i = 0; i < 4; i++) {
+        size_t chunk_size = chunk_sizes[i];
+
+        fprintf(stderr, "Testing chunk_size = %zu\n", chunk_size);
+
+        rwkv_eval_sequence_in_chunks(ctx, prompt_tokens, prompt_length, chunk_size, NULL, state, logits);
+
+        ASSERT(memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float)) == 0, "Results are not identical");
+    }
+
+    // ---
+
+    rwkv_free(ctx);
+
+    free(logits);
+    free(state);
+    free(expected_logits);
+    free(expected_state);
+    free(prompt_tokens);
+}
+
+int main(void) {
+    const char prompt1[70 + 1] = "This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM";
+    test_on_prompt(prompt1, 70);
+
+    const char prompt2[1 + 1] = "T";
+    test_on_prompt(prompt2, 1);
+
+    return 0;
+}
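Usage note (not part of the patch above): a minimal sketch of how downstream code migrates from the per-token eval() loop to the new chunked call, mirroring the changes this diff makes to the example scripts. The model path and token ids are placeholders, and the commented-out sampling step is a hypothetical helper, not an API introduced by this change; load_rwkv_shared_library, RWKVModel, eval and eval_sequence_in_chunks are the rwkv_cpp entry points touched by this diff.

    # Minimal migration sketch for the new chunked API (Python wrapper).
    from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

    library = rwkv_cpp_shared_library.load_rwkv_shared_library()
    model = rwkv_cpp_model.RWKVModel(library, 'model.bin')  # placeholder path

    prompt_tokens = [510, 3158, 10805]  # placeholder ids, each in 0 <= token < n_vocab

    # Before this change: one graph execution per prompt token.
    #
    #   logits, state = None, None
    #   for token in prompt_tokens:
    #       logits, state = model.eval(token, state, state, logits, use_numpy=True)

    # After this change: one call; the sequence is split internally into chunks
    # (16 tokens by default), and logits are only computed for the final chunk.
    logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)

    # 'state' now encodes the whole prompt; generated tokens are still fed back one at a time:
    #   next_token = sample(logits)  # hypothetical sampling helper
    #   logits, state = model.eval(next_token, state, state, logits, use_numpy=True)

    model.free()

The default chunk_size of 16 matches the recommendation in the new docstrings, and skipping logits for all non-final chunks (the `is_last_eval ? logits_out : NULL` argument in rwkv_eval.inc) is what recovers the ~10 ms per iteration that the rwkv.h comments attribute to passing NULL for logits_out.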