From 39ed572ef506d07f397a44309c86bc17272b4a8d Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 23 Sep 2023 18:18:32 +0500 Subject: [PATCH] Various improvements (#131) * Implement model head offloading * Guess the tokenizer from n_vocab * Make PyTorch optional for inference * Add function to offload layers * Add rwkv_eval_sequence_in_chunks --- README.md | 2 +- python/chat_with_bot.py | 18 +- python/generate_completions.py | 21 +- python/inference_example.py | 15 +- python/measure_pexplexity.py | 18 +- python/rwkv_cpp/rwkv_cpp_model.py | 216 +++++++++++++++++---- python/rwkv_cpp/rwkv_cpp_shared_library.py | 149 +++++++++++--- python/sampling.py | 15 +- python/tokenizer_util.py | 29 ++- rwkv.h | 39 +++- rwkv_eval.inc | 77 ++++++++ rwkv_gpu_offload.inc | 9 +- rwkv_model_loading.inc | 2 + tests/CMakeLists.txt | 1 + tests/logit_difference_validator.inc | 2 +- tests/test_eval_sequence_in_chunks.c | 76 ++++++++ 16 files changed, 567 insertions(+), 122 deletions(-) create mode 100644 tests/test_eval_sequence_in_chunks.c diff --git a/README.md b/README.md index 80be95e..7561380 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169 #### Using the command line -**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/). +**Requirements**: Python 3.x with [numpy](https://numpy.org/). If using `Pile` or `Raven` models, [tokenizers](https://pypi.org/project/tokenizers/) is also required. To generate some text, run: diff --git a/python/chat_with_bot.py b/python/chat_with_bot.py index af2b9af..b7a630a 100644 --- a/python/chat_with_bot.py +++ b/python/chat_with_bot.py @@ -8,10 +8,9 @@ import copy import json import time -import torch import sampling from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model -from tokenizer_util import get_tokenizer +from tokenizer_util import add_tokenizer_argument, get_tokenizer from typing import List, Dict, Optional # ======================================== Script settings ======================================== @@ -41,7 +40,7 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') parser.add_argument('model_path', help='Path to RWKV model in ggml format') -parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B') +add_tokenizer_argument(parser) args = parser.parse_args() script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent @@ -53,27 +52,26 @@ assert init_prompt != '', 'Prompt must not be empty' -tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) - library = rwkv_cpp_shared_library.load_rwkv_shared_library() print(f'System info: {library.rwkv_get_system_info_string()}') print('Loading RWKV model') model = rwkv_cpp_model.RWKVModel(library, args.model_path) +tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) + # ================================================================================================= processed_tokens: List[int] = [] -logits: Optional[torch.Tensor] = None -state: Optional[torch.Tensor] = None +logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None +state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None: global processed_tokens, logits, state - processed_tokens += _tokens + logits, state = 
model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True) - for _token in _tokens: - logits, state = model.eval(_token, state, state, logits) + processed_tokens += _tokens logits[END_OF_LINE_TOKEN] += new_line_logit_bias diff --git a/python/generate_completions.py b/python/generate_completions.py index 89a2e02..6f585ed 100644 --- a/python/generate_completions.py +++ b/python/generate_completions.py @@ -5,7 +5,7 @@ import time import sampling from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model -from tokenizer_util import get_tokenizer +from tokenizer_util import add_tokenizer_argument, get_tokenizer from typing import List # ======================================== Script settings ======================================== @@ -29,28 +29,25 @@ parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt') parser.add_argument('model_path', help='Path to RWKV model in ggml format') -parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B') +add_tokenizer_argument(parser) args = parser.parse_args() assert prompt != '', 'Prompt must not be empty' -tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) - -prompt_tokens: List[int] = tokenizer_encode(prompt) - library = rwkv_cpp_shared_library.load_rwkv_shared_library() print(f'System info: {library.rwkv_get_system_info_string()}') print('Loading RWKV model') model = rwkv_cpp_model.RWKVModel(library, args.model_path) +tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) + +prompt_tokens: List[int] = tokenizer_encode(prompt) + prompt_token_count: int = len(prompt_tokens) print(f'{prompt_token_count} tokens in prompt') -init_logits, init_state = None, None - -for token in prompt_tokens: - init_logits, init_state = model.eval(token, init_state, init_state, init_logits) +init_logits, init_state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True) for GENERATION in range(generation_count): print(f'\n--- Generation {GENERATION} ---\n') @@ -58,14 +55,14 @@ start: float = time.time() - logits, state = init_logits.clone(), init_state.clone() + logits, state = init_logits.copy(), init_state.copy() for i in range(tokens_per_generation): token: int = sampling.sample_logits(logits, temperature, top_p) print(tokenizer_decode([token]), end='', flush=True) - logits, state = model.eval(token, state, state, logits) + logits, state = model.eval(token, state, state, logits, use_numpy=True) delay: float = time.time() - start diff --git a/python/inference_example.py b/python/inference_example.py index 116694c..efd0016 100644 --- a/python/inference_example.py +++ b/python/inference_example.py @@ -4,13 +4,13 @@ import argparse import sampling from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model -from tokenizer_util import get_tokenizer +from tokenizer_util import add_tokenizer_argument, get_tokenizer from typing import List # Parse received arguments. parser = argparse.ArgumentParser(description='Generate some text with an RWKV model') parser.add_argument('model_path', help='Path to RWKV model in ggml format') -parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B') +add_tokenizer_argument(parser) args = parser.parse_args() # Load the model. @@ -18,19 +18,14 @@ model = rwkv_cpp_model.RWKVModel(library, args.model_path) # Set up the tokenizer. 
-tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) +tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) # Prepare the prompt. prompt: str = """One upon a time,""" prompt_tokens: List[int] = tokenizer_encode(prompt) # Process the prompt. -init_logits, init_state = None, None - -for token in prompt_tokens: - init_logits, init_state = model.eval(token, init_state, init_state, init_logits) - -logits, state = init_logits.clone(), init_state.clone() +logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True) # Generate and print the completion. print(prompt, end='') @@ -40,7 +35,7 @@ print(tokenizer_decode([token]), end='', flush=True) - logits, state = model.eval(token, state, state, logits) + logits, state = model.eval(token, state, state, logits, use_numpy=True) # Don't forget to free the memory after you are done working with the model! model.free() diff --git a/python/measure_pexplexity.py b/python/measure_pexplexity.py index f483277..6f6674c 100644 --- a/python/measure_pexplexity.py +++ b/python/measure_pexplexity.py @@ -5,9 +5,10 @@ import os import time import argparse +# TODO Get rid of this PyTorch dependency by writing a cross_entropy impl for numpy import torch from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model -from tokenizer_util import get_tokenizer +from tokenizer_util import add_tokenizer_argument, get_tokenizer from typing import List def parse_args(): @@ -16,15 +17,21 @@ def parse_args(): parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str) parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int) parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1) - parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B') + add_tokenizer_argument(parser) return parser.parse_args() args = parse_args() +print('Loading model') +model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel( + rwkv_cpp_shared_library.load_rwkv_shared_library(), + args.model_path +) + print('Loading text') text: str = open(args.text_path, encoding='utf-8').read() -_, tokenizer_encode = get_tokenizer(args.tokenizer) +_, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) tokens: List[int] = tokenizer_encode(text) @@ -52,11 +59,6 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str: # --- -model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel( - rwkv_cpp_shared_library.load_rwkv_shared_library(), - args.model_path -) - logits, state = None, None loss_sum: torch.Tensor = torch.tensor([0.0]) diff --git a/python/rwkv_cpp/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py index ae3b0d5..4b78c76 100644 --- a/python/rwkv_cpp/rwkv_cpp_model.py +++ b/python/rwkv_cpp/rwkv_cpp_model.py @@ -1,18 +1,27 @@ import os -import torch import multiprocessing +# Pre-import PyTorch, if available. +# This fixes "OSError: [WinError 127] The specified procedure could not be found". +try: + import torch +except ModuleNotFoundError: + pass + # I'm sure this is not strictly correct, but let's keep this crutch for now. try: import rwkv_cpp_shared_library except ModuleNotFoundError: from . 
import rwkv_cpp_shared_library -from typing import Tuple, Optional, List +from typing import TypeVar, Optional, Tuple, List + +# A value of this type is either a numpy's ndarray or a PyTorch's Tensor. +NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor') class RWKVModel: """ - PyTorch wrapper around rwkv.cpp model. + An RWKV model managed by rwkv.cpp library. """ def __init__( @@ -37,6 +46,7 @@ def __init__( Thread count to use. If not set, defaults to CPU count / 2. gpu_layer_count : int Count of layers to offload onto the GPU, must be >= 0. + See documentation of `gpu_offload_layers` for details about layer offloading. """ if 'gpu_layers_count' in kwargs: @@ -51,13 +61,33 @@ def __init__( self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count) if gpu_layer_count > 0: - self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layer_count) + self.gpu_offload_layers(gpu_layer_count) self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx) self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx) self._valid: bool = True + def gpu_offload_layers(self, layer_count: int) -> bool: + """ + Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast. + For the purposes of this function, model head (unembedding matrix) is treated as an additional layer: + - pass `model.n_layer` to offload all layers except model head + - pass `model.n_layer + 1` to offload all layers, including model head + + Returns true if at least one layer was offloaded. + If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false. + + Parameters + ---------- + layer_count : int + Count of layers to offload onto the GPU, must be >= 0. + """ + + assert layer_count >= 0, 'Layer count must be >= 0' + + return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count) + @property def n_vocab(self) -> int: return self._library.rwkv_get_n_vocab(self._ctx) @@ -73,10 +103,11 @@ def n_layer(self) -> int: def eval( self, token: int, - state_in: Optional[torch.Tensor], - state_out: Optional[torch.Tensor] = None, - logits_out: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: """ Evaluates the model for a single token. In case of any error, this method will throw an exception. @@ -85,12 +116,16 @@ def eval( ---------- token : int Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab. - state_in : Optional[torch.Tensor] + state_in : Optional[NumpyArrayOrTorchTensor] State from previous call of this method. If this is a first pass, set it to None. - state_out : Optional[torch.Tensor] + state_out : Optional[NumpyArrayOrTorchTensor] Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). - logits_out : Optional[torch.Tensor] + logits_out : Optional[NumpyArrayOrTorchTensor] Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. 
+ This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. Returns ------- @@ -100,29 +135,31 @@ def eval( assert self._valid, 'Model was freed' + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + if state_in is not None: self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) - state_in_ptr = state_in.data_ptr() + state_in_ptr = self._get_data_ptr(state_in) else: state_in_ptr = 0 if state_out is not None: self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: - state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) if logits_out is not None: self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: - logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) self._library.rwkv_eval( self._ctx, token, state_in_ptr, - state_out.data_ptr(), - logits_out.data_ptr() + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) ) return logits_out, state_out @@ -130,10 +167,11 @@ def eval( def eval_sequence( self, tokens: List[int], - state_in: Optional[torch.Tensor], - state_out: Optional[torch.Tensor] = None, - logits_out: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: """ Evaluates the model for a sequence of tokens. @@ -152,12 +190,16 @@ def eval_sequence( ---------- tokens : List[int] Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab. - state_in : Optional[torch.Tensor] + state_in : Optional[NumpyArrayOrTorchTensor] State from previous call of this method. If this is a first pass, set it to None. - state_out : Optional[torch.Tensor] + state_out : Optional[NumpyArrayOrTorchTensor] Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). - logits_out : Optional[torch.Tensor] + logits_out : Optional[NumpyArrayOrTorchTensor] Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. + This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. 
Returns ------- @@ -167,29 +209,106 @@ def eval_sequence( assert self._valid, 'Model was freed' + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + if state_in is not None: self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) - state_in_ptr = state_in.data_ptr() + state_in_ptr = self._get_data_ptr(state_in) else: state_in_ptr = 0 if state_out is not None: self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) else: - state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu') + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) if logits_out is not None: self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) else: - logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu') + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) self._library.rwkv_eval_sequence( self._ctx, tokens, state_in_ptr, - state_out.data_ptr(), - logits_out.data_ptr() + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) + ) + + return logits_out, state_out + + def eval_sequence_in_chunks( + self, + tokens: List[int], + state_in: Optional[NumpyArrayOrPyTorchTensor], + state_out: Optional[NumpyArrayOrPyTorchTensor] = None, + logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, + chunk_size: int = 16, + use_numpy: bool = False + ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: + """ + Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance. + + Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + and choose one that works the best in your use case. + + In case of any error, this method will throw an exception. + + Parameters + ---------- + tokens : List[int] + Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab. + chunk_size : int + Size of each chunk in tokens, must be positive. + state_in : Optional[NumpyArrayOrTorchTensor] + State from previous call of this method. If this is a first pass, set it to None. + state_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). + logits_out : Optional[NumpyArrayOrTorchTensor] + Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). + use_numpy : bool + If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. + This parameter is ignored if any tensor parameter is not None; in such case, + type of returned tensors will match the type of received tensors. + + Returns + ------- + logits, state + Logits vector of shape (n_vocab); state for the next step. 
+ """ + + assert self._valid, 'Model was freed' + + use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) + + if state_in is not None: + self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) + + state_in_ptr = self._get_data_ptr(state_in) + else: + state_in_ptr = 0 + + if state_out is not None: + self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) + else: + state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) + + if logits_out is not None: + self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) + else: + logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) + + self._library.rwkv_eval_sequence_in_chunks( + self._ctx, + tokens, + chunk_size, + state_in_ptr, + self._get_data_ptr(state_out), + self._get_data_ptr(logits_out) ) return logits_out, state_out @@ -212,8 +331,39 @@ def __del__(self) -> None: if hasattr(self, '_valid') and self._valid: self.free() - def _validate_tensor(self, buf: torch.Tensor, name: str, size: int) -> None: - assert buf.device == torch.device('cpu'), f'{name} is not on CPU' - assert buf.dtype == torch.float32, f'{name} is not of type float32' - assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' - assert buf.is_contiguous(), f'{name} is not contiguous' + def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool: + return hasattr(tensor, '__module__') and tensor.__module__ == 'torch' + + def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool: + for tensor in tensors: + if tensor is not None: + return False if self._is_pytorch_tensor(tensor) else True + + return use_numpy_by_default + + def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None: + if self._is_pytorch_tensor(tensor): + tensor: torch.Tensor = tensor + assert tensor.device == torch.device('cpu'), f'{name} is not on CPU' + assert tensor.dtype == torch.float32, f'{name} is not of type float32' + assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' + assert tensor.is_contiguous(), f'{name} is not contiguous' + else: + import numpy as np + tensor: np.ndarray = tensor + assert tensor.dtype == np.float32, f'{name} is not of type float32' + assert tensor.shape == (size,), f'{name} has invalid shape {tensor.shape}, expected ({size})' + assert tensor.data.contiguous, f'{name} is not contiguous' + + def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor): + if self._is_pytorch_tensor(tensor): + return tensor.data_ptr() + else: + return tensor.ctypes.data + + def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor: + if use_numpy: + import numpy as np + return np.zeros(element_count, dtype=np.float32) + else: + return torch.zeros(element_count, dtype=torch.float32, device='cpu') diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py index 951ed5b..3b94fde 100644 --- a/python/rwkv_cpp/rwkv_cpp_shared_library.py +++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -63,6 +63,17 @@ def __init__(self, shared_library_path: str) -> None: ] self.library.rwkv_eval_sequence.restype = ctypes.c_bool + self.library.rwkv_eval_sequence_in_chunks.argtypes = [ + ctypes.c_void_p, # ctx + P_INT, # tokens + ctypes.c_size_t, # token count + ctypes.c_size_t, # chunk size + P_FLOAT, # state_in + P_FLOAT, # state_out 
+ P_FLOAT # logits_out + ] + self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool + self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t @@ -113,9 +124,12 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo def rwkv_gpu_offload_layers(self, ctx: RWKVContext, layer_count: int) -> bool: """ - Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS. + Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast. + For the purposes of this function, model head (unembedding matrix) is treated as an additional layer: + - pass `rwkv_get_n_layer(ctx)` to offload all layers except model head + - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head Returns true if at least one layer was offloaded. - If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false. + If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false. Parameters ---------- @@ -139,6 +153,43 @@ def rwkv_eval( ) -> None: """ Evaluates the model for a single token. + Throws an exception in case of any error. Error messages would be printed to stderr. + Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + token : int + Next token index, in range 0 <= token < n_vocab. + state_in_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. + state_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to. + logits_out_address : int + Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. + """ + + assert self.library.rwkv_eval( + ctx.ptr, + ctypes.c_int32(token), + ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), + ctypes.cast(state_out_address, P_FLOAT), + ctypes.cast(logits_out_address, P_FLOAT) + ), 'rwkv_eval failed, check stderr' + + def rwkv_eval_sequence( + self, + ctx: RWKVContext, + tokens: List[int], + state_in_address: Optional[int], + state_out_address: int, + logits_out_address: int + ) -> None: + """ + Evaluates the model for a sequence of tokens. + Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so. + Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length. NOTE ON GGML NODE LIMIT @@ -149,14 +200,15 @@ def rwkv_eval( If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit. To get rid of the assertion failure, reduce the model size and/or sequence length. + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters ---------- ctx : RWKVContext RWKV context obtained from rwkv_init_from_file. - token : int - Next token index, in range 0 <= token < n_vocab. 
+ tokens : List[int] + Next token indices, in range 0 <= token < n_vocab. state_in_address : int Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. state_out_address : int @@ -165,24 +217,34 @@ def rwkv_eval( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval( + assert self.library.rwkv_eval_sequence( ctx.ptr, - ctypes.c_int32(token), + ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), + ctypes.c_size_t(len(tokens)), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval failed, check stderr' + ), 'rwkv_eval_sequence failed, check stderr' - def rwkv_eval_sequence( + def rwkv_eval_sequence_in_chunks( self, ctx: RWKVContext, tokens: List[int], + chunk_size: int, state_in_address: Optional[int], state_out_address: int, logits_out_address: int ) -> None: """ - Evaluates the model for a sequence of tokens. + Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance. + + Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + and choose one that works the best in your use case. + + Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters @@ -191,6 +253,8 @@ def rwkv_eval_sequence( RWKV context obtained from rwkv_init_from_file. tokens : List[int] Next token indices, in range 0 <= token < n_vocab. + chunk_size : int + Size of each chunk in tokens, must be positive. state_in_address : int Address of the first element of a FP32 buffer of size rwkv_get_state_buffer_element_count; or None, if this is a first pass. state_out_address : int @@ -199,14 +263,56 @@ def rwkv_eval_sequence( Address of the first element of a FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to. """ - assert self.library.rwkv_eval_sequence( + assert self.library.rwkv_eval_sequence_in_chunks( ctx.ptr, ctypes.cast((ctypes.c_int32 * len(tokens))(*tokens), P_INT), ctypes.c_size_t(len(tokens)), + ctypes.c_size_t(chunk_size), ctypes.cast(0 if state_in_address is None else state_in_address, P_FLOAT), ctypes.cast(state_out_address, P_FLOAT), ctypes.cast(logits_out_address, P_FLOAT) - ), 'rwkv_eval failed, check stderr' + ), 'rwkv_eval_sequence_in_chunks failed, check stderr' + + def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: + """ + Returns the number of tokens in the given model's vocabulary. + Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. 
+ """ + + return self.library.rwkv_get_n_vocab(ctx.ptr) + + def rwkv_get_n_embed(self, ctx: RWKVContext) -> int: + """ + Returns the number of elements in the given model's embedding. + Useful for reading individual fields of a model's hidden state. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_n_embed(ctx.ptr) + + def rwkv_get_n_layer(self, ctx: RWKVContext) -> int: + """ + Returns the number of layers in the given model. + A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model. + Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`. + Useful for always offloading the entire model to GPU. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_n_layer(ctx.ptr) def rwkv_get_state_buffer_element_count(self, ctx: RWKVContext) -> int: """ @@ -276,27 +382,6 @@ def rwkv_get_system_info_string(self) -> str: return self.library.rwkv_get_system_info_string().decode('utf-8') - def rwkv_get_n_embed(self, ctx: RWKVContext) -> int: - """ - Returns the size of one embedding vector. - """ - - return self.library.rwkv_get_n_embed(ctx.ptr) - - def rwkv_get_n_layer(self, ctx: RWKVContext) -> int: - """ - Returns the number of layers. - """ - - return self.library.rwkv_get_n_layer(ctx.ptr) - - def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: - """ - Returns vocab size. - """ - - return self.library.rwkv_get_n_vocab(ctx.ptr) - def load_rwkv_shared_library() -> RWKVSharedLibrary: """ Attempts to find rwkv.cpp shared library and load it. diff --git a/python/sampling.py b/python/sampling.py index 3d86e93..0736638 100644 --- a/python/sampling.py +++ b/python/sampling.py @@ -1,10 +1,17 @@ import numpy as np -import torch from typing import Dict -from torch.nn import functional as F -def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int: - probs: np.ndarray = F.softmax(out.cpu(), dim=-1).numpy() +# https://stackoverflow.com/a/50425683 +def softmax(x: np.ndarray, axis: int): + x -= x.max(axis=axis, keepdims=True) + e: np.ndarray = np.exp(x) + return e / e.sum(axis=axis, keepdims=True) + +def sample_logits(out, temperature: float = 1.0, top_p: float = 0.8, logit_bias: Dict[int, float] = None) -> int: + if hasattr(out, '__module__') and out.__module__ == 'torch': + out = out.cpu().numpy() + + probs: np.ndarray = softmax(out, axis=-1) return sample_probs(probs, temperature, top_p, logit_bias) diff --git a/python/tokenizer_util.py b/python/tokenizer_util.py index 1f5e06d..0471142 100644 --- a/python/tokenizer_util.py +++ b/python/tokenizer_util.py @@ -1,21 +1,38 @@ import os -import tokenizers import pathlib from rwkv_cpp import rwkv_world_tokenizer from typing import List, Tuple, Callable -def get_tokenizer(tokenizer: str = '20B') -> Tuple[ +def add_tokenizer_argument(parser) -> None: + parser.add_argument( + 'tokenizer', + help='Tokenizer to use; supported tokenizers: auto (guess from n_vocab), 20B, world', + nargs='?', + type=str, + default='auto' + ) + +def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[ Callable[[List[int]], str], Callable[[str], List[int]] ]: + if tokenizer_name == 'auto': + if n_vocab == 50277: + tokenizer_name = '20B' + elif n_vocab == 65536: + tokenizer_name = 'world' + else: + assert False, f'Can not guess the tokenizer from n_vocab value of {n_vocab}' + parent: pathlib.Path 
= pathlib.Path(os.path.abspath(__file__)).parent - if tokenizer == 'world': - print('Loading world tokenizer') + if tokenizer_name == 'world': + print('Loading World v20230424 tokenizer') return rwkv_world_tokenizer.get_world_tokenizer_v20230424() - elif tokenizer == '20B': + elif tokenizer_name == '20B': print('Loading 20B tokenizer') + import tokenizers tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json')) return tokenizer.decode, lambda x: tokenizer.encode(x).ids else: - assert False, f'Unknown tokenizer {tokenizer}' + assert False, f'Unknown tokenizer {tokenizer_name}' diff --git a/rwkv.h b/rwkv.h index 4b5ddee..40b9266 100644 --- a/rwkv.h +++ b/rwkv.h @@ -97,9 +97,12 @@ extern "C" { // - n_threads: count of threads to use, must be positive. RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads); - // Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS. + // Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast. + // For the purposes of this function, model head (unembedding matrix) is treated as an additional layer: + // - pass `rwkv_get_n_layer(ctx)` to offload all layers except model head + // - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head // Returns true if at least one layer was offloaded. - // If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false. + // If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false. RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers); // Evaluates the model for a single token. @@ -119,7 +122,7 @@ extern "C" { ); // Evaluates the model for a sequence of tokens. - // Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so. + // Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so. // Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length. // // NOTE ON GGML NODE LIMIT @@ -134,7 +137,7 @@ extern "C" { // TODO When Metal (MPS) support is implemented, check that large sequence lengths work // // You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated. - // Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread. + // Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. // Returns false on any error. // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization. // - sequence_len: number of tokens to read from the array. @@ -150,6 +153,32 @@ extern "C" { float * logits_out ); + // Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks. + // This function is useful for processing complete prompts and user input in chat & role-playing use-cases. + // It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance. 
+ // + // Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory. + // A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64] + // and choose one that works the best in your use case. + // + // Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread. + // Returns false on any error. + // - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization. + // - sequence_len: number of tokens to read from the array. + // - chunk_size: size of each chunk in tokens, must be positive. + // - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass. + // - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL. + // - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL. + RWKV_API bool rwkv_eval_sequence_in_chunks( + struct rwkv_context * ctx, + const uint32_t * tokens, + const size_t sequence_len, + const size_t chunk_size, + const float * state_in, + float * state_out, + float * logits_out + ); + // Returns the number of tokens in the given model's vocabulary. // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx); @@ -159,6 +188,8 @@ extern "C" { RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx); // Returns the number of layers in the given model. + // A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model. + // Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`. // Useful for always offloading the entire model to GPU. RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx); diff --git a/rwkv_eval.inc b/rwkv_eval.inc index 38bbb33..37f3a9c 100644 --- a/rwkv_eval.inc +++ b/rwkv_eval.inc @@ -69,6 +69,17 @@ bool rwkv_eval_sequence( RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0"); + if (sequence_len == 1) { + // Avoid building single-token sequence graph, we already have regular eval for this. + return rwkv_eval( + ctx, + sequence[0], + state_in, + state_out, + logits_out + ); + } + if (sequence) { const size_t n_vocab = ctx->model->header.n_vocab; @@ -97,6 +108,72 @@ bool rwkv_eval_sequence( return true; } +// API function. +bool rwkv_eval_sequence_in_chunks( + struct rwkv_context * ctx, + const uint32_t * tokens, + const size_t sequence_len, + const size_t chunk_size, + const float * state_in, + float * state_out, + float * logits_out +) { + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, sequence_len > 0, "Sequence length is 0"); + RWKV_CTX_ASSERT_FALSE_MSG(ctx, RWKV_ERROR_ARGS, chunk_size > 0, "Chunk size is 0"); + + // Will be de-allocated automatically on return. 
+    std::unique_ptr<float[]> state{ new(std::nothrow) float[rwkv_get_state_len(ctx)] };
+
+    if (state_in != NULL) {
+        memcpy(state.get(), state_in, rwkv_get_state_len(ctx) * sizeof(float));
+    } else {
+        rwkv_init_state(ctx, state.get());
+    }
+
+    size_t chunk_count = sequence_len / chunk_size;
+    size_t remainder = sequence_len % chunk_size;
+    uint32_t * tokens_offset = (uint32_t *) tokens;
+
+    for (size_t c = 0; c < chunk_count; c++) {
+        bool is_last_eval = c == chunk_count - 1 && remainder == 0;
+
+        bool result = rwkv_eval_sequence(
+            ctx,
+            tokens_offset,
+            chunk_size,
+            state.get(),
+            // On the last eval call, copy the state into the user-provided buffer.
+            is_last_eval ? state_out : state.get(),
+            // If this is not the last call, we have no use for the logits and can skip their calculation.
+            is_last_eval ? logits_out : NULL
+        );
+
+        if (!result) {
+            return false;
+        }
+
+        tokens_offset += chunk_size;
+    }
+
+    if (remainder > 0) {
+        bool result = rwkv_eval_sequence(
+            ctx,
+            tokens_offset,
+            remainder,
+            state.get(),
+            // This eval call is always the last.
+            state_out,
+            logits_out
+        );
+
+        if (!result) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // API function.
 void rwkv_init_state(const struct rwkv_context * ctx, float * state) {
     const struct rwkv_file_header & header = ctx->model->header;
diff --git a/rwkv_gpu_offload.inc b/rwkv_gpu_offload.inc
index 3564d78..cc47c64 100644
--- a/rwkv_gpu_offload.inc
+++ b/rwkv_gpu_offload.inc
@@ -18,13 +18,20 @@ bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers)
 #endif
     };
 
-    const size_t n_gpu = std::min(n_layers, ctx->model->header.n_layer);
+    const size_t n_gpu = std::min(n_layers, ctx->model->header.n_layer + 1);
 
     if (ctx->model->offloaded_layer_count >= n_gpu) {
         return false;
     }
 
     for (size_t & i = ctx->model->offloaded_layer_count; i < n_gpu; i++) {
+        if (i == ctx->model->header.n_layer) {
+            // This is the index of the model head.
+            offload(ctx->model->head);
+
+            continue;
+        }
+
         const struct rwkv_layer & layer = ctx->model->layers[i];
 
         // TODO Also offload other supported operations to GPU
diff --git a/rwkv_model_loading.inc b/rwkv_model_loading.inc
index e19af58..3cf0392 100644
--- a/rwkv_model_loading.inc
+++ b/rwkv_model_loading.inc
@@ -46,6 +46,8 @@ struct rwkv_model {
     struct ggml_tensor * head;
 
     // How many layers were offloaded to the GPU.
+    // Model head is counted as an additional layer,
+    // so the max value for this field is n_layers + 1.
     size_t offloaded_layer_count;
 
     // How many RWKV contexts reference this model.
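The head-offloading behaviour described above can be exercised from Python through the new `gpu_offload_layers` method and the `n_layer` property. Below is a minimal sketch, not part of the patch itself; the model path is a placeholder, and a cuBLAS- or CLBlast-enabled build of the shared library is assumed:

# Minimal sketch: offload every layer plus the model head (unembedding matrix) to the GPU.
# '/path/to/model.bin' is a placeholder path; adjust it to your own ggml model file.
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, '/path/to/model.bin')

# n_layer counts only the stacked RWKV/FFN blocks, so n_layer + 1 also covers the head.
# gpu_offload_layers returns False if the library was built without cuBLAS/CLBlast support.
offloaded = model.gpu_offload_layers(model.n_layer + 1)
print('Offloaded at least one layer' if offloaded else 'Running fully on CPU')

model.free()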
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 8705d94..cb32fcd 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -24,4 +24,5 @@ rwkv_add_test(test_quantized_matmul_on_gpu.c)
 rwkv_add_test(test_tiny_rwkv.c)
 rwkv_add_test(test_quantization_format_compatibility.c)
 rwkv_add_test(test_logit_calculation_skipping.c)
+rwkv_add_test(test_eval_sequence_in_chunks.c)
 rwkv_add_test(test_context_cloning.c)
diff --git a/tests/logit_difference_validator.inc b/tests/logit_difference_validator.inc
index 9ff0591..8907a75 100644
--- a/tests/logit_difference_validator.inc
+++ b/tests/logit_difference_validator.inc
@@ -29,7 +29,7 @@ void test_model(const char * model_path, const float * expected_logits, const fl
     ASSERT(error == 0, "Unexpected error %d", error);
 
 #if defined(GGML_USE_CUBLAS)
-    ASSERT(rwkv_gpu_offload_layers(model, rwkv_get_n_layer(model)), "Failed to offload layers to GPU");
+    ASSERT(rwkv_gpu_offload_layers(model, rwkv_get_n_layer(model) + 1), "Failed to offload layers to GPU");
 #endif
 
     const size_t n_vocab = rwkv_get_logits_len(model);
diff --git a/tests/test_eval_sequence_in_chunks.c b/tests/test_eval_sequence_in_chunks.c
new file mode 100644
index 0000000..7c3b10a
--- /dev/null
+++ b/tests/test_eval_sequence_in_chunks.c
@@ -0,0 +1,76 @@
+// Tests that eval_sequence_in_chunks gives results equivalent to serial eval.
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+#include <rwkv.h>
+
+#include "assertions.inc"
+
+void test_on_prompt(const char * prompt, const size_t prompt_length) {
+    fprintf(stderr, "Calculating expected state and logits for prompt of size %zd\n", prompt_length);
+
+    struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2);
+
+    ASSERT(ctx != NULL, "Unexpected error 0x%.8X", rwkv_get_last_error(NULL));
+
+    float * expected_state = calloc(rwkv_get_state_len(ctx), sizeof(float));
+    float * expected_logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));
+
+    ASSERT(expected_state != NULL, "Failed to allocate state");
+    ASSERT(expected_logits != NULL, "Failed to allocate logits");
+
+    rwkv_eval(ctx, prompt[0], NULL, expected_state, expected_logits);
+
+    for (size_t i = 1; prompt[i] != 0; i++) {
+        rwkv_eval(ctx, prompt[i], expected_state, expected_state, expected_logits);
+    }
+
+    // ---
+
+    uint32_t * prompt_tokens = calloc(prompt_length, sizeof(uint32_t));
+
+    for (int i = 0; i < prompt_length; i++) {
+        prompt_tokens[i] = prompt[i];
+    }
+
+    // ---
+
+    float * state = calloc(rwkv_get_state_len(ctx), sizeof(float));
+    float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));
+
+    ASSERT(state != NULL, "Failed to allocate state");
+    ASSERT(logits != NULL, "Failed to allocate logits");
+
+    const size_t chunk_sizes[4] = {1, 2, 8, 10};
+
+    for (int i = 0; i < 4; i++) {
+        size_t chunk_size = chunk_sizes[i];
+
+        fprintf(stderr, "Testing chunk_size = %zd\n", chunk_size);
+
+        rwkv_eval_sequence_in_chunks(ctx, prompt_tokens, prompt_length, chunk_size, NULL, state, logits);
+
+        ASSERT(memcmp(expected_state, state, rwkv_get_state_len(ctx) * sizeof(float)) == 0, "Results are not identical");
+    }
+
+    // ---
+
+    rwkv_free(ctx);
+
+    free(logits);
+    free(state);
+    free(expected_logits);
+    free(expected_state);
+    free(prompt_tokens);
+}
+
+int main(void) {
+    const char prompt1[70 + 1] = "This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM";
+    test_on_prompt(prompt1, 70);
+
+    const char prompt2[1 + 1] = "T";
+    test_on_prompt(prompt2, 1);
+
+    return 0;
+}
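For reference, the new Python surface introduced by this patch (tokenizer guessing from n_vocab, chunked prompt evaluation, and PyTorch-free numpy buffers) can be exercised end to end with a sketch along these lines. The model path, prompt text, and sampling parameters are placeholders, and the script is assumed to run from the repository's python/ directory so that `sampling` and `tokenizer_util` are importable:

# Sketch: process a prompt with eval_sequence_in_chunks, then sample a single token.
# '/path/to/model.bin' and the prompt are placeholders; run from the python/ directory.
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, '/path/to/model.bin')

# 'auto' picks the 20B tokenizer for n_vocab == 50277 and the World tokenizer for n_vocab == 65536.
tokenizer_decode, tokenizer_encode = get_tokenizer('auto', model.n_vocab)

prompt_tokens = tokenizer_encode('Hello, world!')

# The whole prompt is evaluated in chunks of 16 tokens (the default chunk_size);
# use_numpy=True returns numpy arrays, so PyTorch is not required.
logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)

token = sampling.sample_logits(logits, temperature=0.8, top_p=0.5)
print(tokenizer_decode([token]))

model.free()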