Various improvements #131

Merged 5 commits on Sep 23, 2023
README.md (1 addition, 1 deletion)

@@ -146,7 +146,7 @@ python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169

 #### Using the command line
 
-**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
+**Requirements**: Python 3.x with [numpy](https://numpy.org/). If using `Pile` or `Raven` models, [tokenizers](https://pypi.org/project/tokenizers/) is also required.
 
 To generate some text, run:

python/chat_with_bot.py (8 additions, 10 deletions)

@@ -8,10 +8,9 @@
 import copy
 import json
 import time
-import torch
 import sampling
 from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
-from tokenizer_util import get_tokenizer
+from tokenizer_util import add_tokenizer_argument, get_tokenizer
 from typing import List, Dict, Optional
 
 # ======================================== Script settings ========================================
@@ -41,7 +40,7 @@

 parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
-parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
+add_tokenizer_argument(parser)
 args = parser.parse_args()
 
 script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
@@ -53,27 +52,26 @@

 assert init_prompt != '', 'Prompt must not be empty'
 
-tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)
-
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')
 
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
+tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
+
 # =================================================================================================
 
 processed_tokens: List[int] = []
-logits: Optional[torch.Tensor] = None
-state: Optional[torch.Tensor] = None
+logits: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
+state: Optional[rwkv_cpp_model.NumpyArrayOrPyTorchTensor] = None
 
 def process_tokens(_tokens: List[int], new_line_logit_bias: float = 0.0) -> None:
     global processed_tokens, logits, state
 
-    processed_tokens += _tokens
-
-    for _token in _tokens:
-        logits, state = model.eval(_token, state, state, logits)
+    logits, state = model.eval_sequence_in_chunks(_tokens, state, state, logits, use_numpy=True)
+
+    processed_tokens += _tokens
 
     logits[END_OF_LINE_TOKEN] += new_line_logit_bias

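All four scripts trade their copy-pasted positional `tokenizer` argument for a shared `add_tokenizer_argument` helper, and `get_tokenizer` now receives `model.n_vocab` so the tokenizer can be chosen from the model itself. A minimal sketch of what `tokenizer_util` plausibly provides after this change; the `auto` default, the 65536-vocabulary heuristic for world models, and the `20B_tokenizer.json` path are assumptions, not code from this PR:

```python
import argparse
from typing import Callable, List, Tuple

def add_tokenizer_argument(parser: argparse.ArgumentParser) -> None:
    # One shared definition replaces the positional argument previously
    # duplicated in every script; 'auto' defers the choice to get_tokenizer.
    parser.add_argument(
        'tokenizer',
        help='Tokenizer to use; supported tokenizers: auto, 20B, world',
        nargs='?',
        type=str,
        default='auto'
    )

def get_tokenizer(name: str, n_vocab: int) -> Tuple[Callable[[List[int]], str], Callable[[str], List[int]]]:
    if name == 'auto':
        # Assumption: world-tokenizer models are recognized by their 65536-entry vocabulary.
        name = 'world' if n_vocab == 65536 else '20B'

    if name == 'world':
        # The world tokenizer ships with the repo; loading it is elided in this sketch.
        raise NotImplementedError('load the repo-provided world tokenizer here')

    # The 20B tokenizer needs the optional `tokenizers` package, which matches
    # the README requirement change above.
    import tokenizers
    tok = tokenizers.Tokenizer.from_file('20B_tokenizer.json')
    return (lambda ids: tok.decode(ids)), (lambda text: tok.encode(text).ids)
```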
python/generate_completions.py (9 additions, 12 deletions)

@@ -5,7 +5,7 @@
 import time
 import sampling
 from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
-from tokenizer_util import get_tokenizer
+from tokenizer_util import add_tokenizer_argument, get_tokenizer
 from typing import List
 
 # ======================================== Script settings ========================================
@@ -29,43 +29,40 @@

 parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
-parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
+add_tokenizer_argument(parser)
 args = parser.parse_args()
 
 assert prompt != '', 'Prompt must not be empty'
 
-tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)
-
-prompt_tokens: List[int] = tokenizer_encode(prompt)
-
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')
 
 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
+tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
+
+prompt_tokens: List[int] = tokenizer_encode(prompt)
+
 prompt_token_count: int = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')
 
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
+init_logits, init_state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 for GENERATION in range(generation_count):
     print(f'\n--- Generation {GENERATION} ---\n')
     print(prompt, end='[')
 
     start: float = time.time()
 
-    logits, state = init_logits.clone(), init_state.clone()
+    logits, state = init_logits.copy(), init_state.copy()
 
     for i in range(tokens_per_generation):
         token: int = sampling.sample_logits(logits, temperature, top_p)
 
         print(tokenizer_decode([token]), end='', flush=True)
 
-        logits, state = model.eval(token, state, state, logits)
+        logits, state = model.eval(token, state, state, logits, use_numpy=True)
 
     delay: float = time.time() - start

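The recurring change in these scripts replaces one `model.eval` call per prompt token with a single `eval_sequence_in_chunks` call, and `use_numpy=True` makes the returned logits and state numpy arrays, which is why `.clone()` becomes `.copy()`. A usage sketch assuming only the signatures visible in this diff; the model path and token ids are placeholders:

```python
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
model = rwkv_cpp_model.RWKVModel(library, 'rwkv.cpp-169M.bin')  # placeholder path

prompt_tokens = [510, 4245, 369]  # placeholder token ids

# One call feeds the whole prompt through the model in chunks;
# the three None arguments mean "no prior state or logits".
init_logits, init_state = model.eval_sequence_in_chunks(
    prompt_tokens, None, None, None, use_numpy=True
)

# numpy arrays snapshot with .copy(), so several generations can restart
# from the same prompt state without re-evaluating the prompt.
logits, state = init_logits.copy(), init_state.copy()

# Per-token stepping still works; use_numpy=True keeps the types consistent.
logits, state = model.eval(prompt_tokens[-1], state, state, logits, use_numpy=True)

model.free()
```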
python/inference_example.py (5 additions, 10 deletions)

@@ -4,33 +4,28 @@
 import argparse
 import sampling
 from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
-from tokenizer_util import get_tokenizer
+from tokenizer_util import add_tokenizer_argument, get_tokenizer
 from typing import List
 
 # Parse received arguments.
 parser = argparse.ArgumentParser(description='Generate some text with an RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
-parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
+add_tokenizer_argument(parser)
 args = parser.parse_args()
 
 # Load the model.
 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)
 
 # Set up the tokenizer.
-tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer)
+tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
 
 # Prepare the prompt.
 prompt: str = """One upon a time,"""
 prompt_tokens: List[int] = tokenizer_encode(prompt)
 
 # Process the prompt.
-init_logits, init_state = None, None
-
-for token in prompt_tokens:
-    init_logits, init_state = model.eval(token, init_state, init_state, init_logits)
-
-logits, state = init_logits.clone(), init_state.clone()
+logits, state = model.eval_sequence_in_chunks(prompt_tokens, None, None, None, use_numpy=True)
 
 # Generate and print the completion.
 print(prompt, end='')
@@ -40,7 +35,7 @@

     print(tokenizer_decode([token]), end='', flush=True)
 
-    logits, state = model.eval(token, state, state, logits)
+    logits, state = model.eval(token, state, state, logits, use_numpy=True)
 
 # Don't forget to free the memory after you are done working with the model!
 model.free()
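With `use_numpy=True` everywhere, `sampling.sample_logits` now receives plain numpy logits. A self-contained top-p sampling sketch under that assumption; this is an illustration, not the repo's actual `sampling` module:

```python
import numpy as np

def sample_logits(logits: np.ndarray, temperature: float = 1.0, top_p: float = 0.8) -> int:
    # Temperature-scaled softmax, computed stably by shifting by the max logit.
    probs = np.exp((logits - np.max(logits)) / max(temperature, 1e-8))
    probs /= probs.sum()

    if top_p < 1.0:
        # Nucleus sampling: keep the smallest set of tokens whose cumulative
        # probability reaches top_p and renormalize over that set.
        sorted_probs = np.sort(probs)[::-1]
        cutoff_index = min(int(np.searchsorted(np.cumsum(sorted_probs), top_p)),
                           len(sorted_probs) - 1)
        probs[probs < sorted_probs[cutoff_index]] = 0.0
        probs /= probs.sum()

    return int(np.random.choice(len(probs), p=probs))

# Toy usage with a four-token vocabulary:
print(sample_logits(np.array([3.0, 1.0, 0.2, -1.0]), temperature=0.8, top_p=0.9))
```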
python/measure_pexplexity.py (10 additions, 8 deletions)

@@ -5,9 +5,10 @@
 import os
 import time
 import argparse
+# TODO Get rid of this PyTorch dependency by writing a cross_entropy impl for numpy
 import torch
 from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
-from tokenizer_util import get_tokenizer
+from tokenizer_util import add_tokenizer_argument, get_tokenizer
 from typing import List
 
 def parse_args():
@@ -16,15 +17,21 @@ def parse_args():
     parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
     parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
     parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
-    parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='20B')
+    add_tokenizer_argument(parser)
     return parser.parse_args()
 
 args = parse_args()
 
+print('Loading model')
+model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
+    rwkv_cpp_shared_library.load_rwkv_shared_library(),
+    args.model_path
+)
+
 print('Loading text')
 text: str = open(args.text_path, encoding='utf-8').read()
 
-_, tokenizer_encode = get_tokenizer(args.tokenizer)
+_, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab)
 
 tokens: List[int] = tokenizer_encode(text)

@@ -52,11 +59,6 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:

 # ---
 
-model: rwkv_cpp_model.RWKVModel = rwkv_cpp_model.RWKVModel(
-    rwkv_cpp_shared_library.load_rwkv_shared_library(),
-    args.model_path
-)
-
 logits, state = None, None
 
 loss_sum: torch.Tensor = torch.tensor([0.0])
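The TODO added at the top of measure_pexplexity.py asks for a numpy `cross_entropy` so the remaining PyTorch dependency can be dropped. A minimal sketch of what that could look like; this is my own illustration, not part of the PR:

```python
import numpy as np

def cross_entropy(logits: np.ndarray, target: int) -> float:
    # Numerically stable log-softmax: shift by the max logit before exponentiating.
    shifted = logits - np.max(logits)
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))
    # With a one-hot target, cross-entropy reduces to the negative
    # log-probability assigned to the correct token.
    return float(-log_probs[target])

# Perplexity over a sequence is exp of the mean per-token loss:
losses = [cross_entropy(np.array([2.0, 0.5, 0.1]), 0)]  # toy single-token example
print(np.exp(np.mean(losses)))
```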