Skip to content

Commit

Permalink
Various improvements (#131)
Browse files (browse the repository at this point in the history)
* Implement model head offloading

* Guess the tokenizer from n_vocab

* Make PyTorch optional for inference

* Add function to offload layers

* Add rwkv_eval_sequence_in_chunks
  • Loading branch information
saharNooby committed Sep 23, 2023
1 parent 6caa45e commit 39ed572
Show file tree
Hide file tree
Showing 16 changed files with 567 additions and 122 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -146,7 +146,7 @@ python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169

#### Using the command line

**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
**Requirements**: Python 3.x with [numpy](https://numpy.org/). If using `Pile` or `Raven` models, [tokenizers](https://pypi.org/project/tokenizers/) is also required.

To generate some text, run:

Expand Down
18 changes: 8 additions & 10 deletions python/chat_with_bot.py
Expand Up @@ -8,10 +8,9 @@
import copy
import json
import time
import torch
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List, Dict, Optional

# ======================================== Script settings ========================================
Expand Down Expand Up @@ -41,7 +40,7 @@

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
add_tokenizer_argument(parser)
args = parser.parse_args()

script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
Expand All @@ -53,27 +52,26 @@

assert init_prompt != '', 'Prompt must not be empty'

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

# =================================================================================================

processed_tokens: List[int] = []
logits: Optional[torch.Tensor] = None
state: Optional[torch.Tensor] = None
logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None

def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
global processed_tokens, logits, state

processed_tokens += _tokens
logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)

for _token in _tokens:
logits, state = model.eval(_token, state, state, logits)
processed_tokens += _tokens

logits[END_OF_LINE_TOKEN] += new_line_logit_bias

Expand Down
21 changes: 9 additions & 12 deletions python/generate_completions.py
Expand Up @@ -5,7 +5,7 @@
import time
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List

# ======================================== Script settings ========================================
Expand All @@ -29,43 +29,40 @@

parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
add_tokenizer_argument(parser)
args = parser.parse_args()

assert prompt != '', 'Prompt must not be empty'

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)

prompt_tokens: List[int] = tokenizer_encode(prompt)

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

prompt_tokens: List[int] = tokenizer_encode(prompt)

prompt_token_count: int = len(prompt_tokens)
print(f'{prompt_token_count} tokens in prompt')

init_logits, init_state = None, None

for token in prompt_tokens:
init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
init_logits, init_state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)

for GENERATION in range(generation_count):
print(f'\n--- Generation {GENERATION} ---\n')
print(prompt, end='[')

start: float = time.time()

logits, state = init_logits.clone(), init_state.clone()
logits, state = init_logits.copy(), init_state.copy()

for i in range(tokens_per_generation):
token: int = sampling.sample_logits(logits, temperature, top_p)

print(tokenizer_decode([token]), end='', flush=True)

logits, state = model.eval(token, state, state, logits)
logits, state = model.eval(token, state, state, logits, use_numpy=True)

delay: float = time.time() - start

Expand Down
15 changes: 5 additions & 10 deletions python/inference_example.py
Expand Up @@ -4,33 +4,28 @@
import argparse
import sampling
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List

# Parse received arguments.
parser = argparse.ArgumentParser(description='Generate some text with an RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
add_tokenizer_argument(parser)
args = parser.parse_args()

# Load the model.
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

# Set up the tokenizer.
tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)
tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

# Prepare the prompt.
prompt: str = """One upon a time,"""
prompt_tokens: List[int] = tokenizer_encode(prompt)

# Process the prompt.
init_logits, init_state = None, None

for token in prompt_tokens:
init_logits, init_state = model.eval(token, init_state, init_state, init_logits)

logits, state = init_logits.clone(), init_state.clone()
logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)

# Generate and print the completion.
print(prompt, end='')
Expand All @@ -40,7 +35,7 @@

print(tokenizer_decode([token]), end='', flush=True)

logits, state = model.eval(token, state, state, logits)
logits, state = model.eval(token, state, state, logits, use_numpy=True)

# Don't forget to free the memory after you are done working with the model!
model.free()
18 changes: 10 additions & 8 deletions python/measure_pexplexity.py
Expand Up @@ -5,9 +5,10 @@
import os
import time
import argparse
# TODO Get rid of this PyTorch dependency by writing a cross_entropy impl for numpy
import torch
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from tokenizer_util import add_tokenizer_argument, get_tokenizer
from typing import List

def parse_args():
Expand All @@ -16,15 +17,21 @@ def parse_args():
parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
add_tokenizer_argument(parser)
return parser.parse_args()

args = parse_args()

print('Loading model')
model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
rwkv_cpp_shared_library.load_rwkv_shared_library(),
args.model_path
)

print('Loading text')
text: str = open(args.text_path, encoding='utf-8').read()

_, tokenizer_encode = get_tokenizer(args.tokenizer)
_, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)

tokens: List[int] = tokenizer_encode(text)

Expand Down Expand Up @@ -52,11 +59,6 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:

# ---

model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
rwkv_cpp_shared_library.load_rwkv_shared_library(),
args.model_path
)

logits, state = None, None

loss_sum: torch.Tensor = torch.tensor([0.0])
Expand Down

0 comments on commit 39ed572

Please sign in to comment.