Add support for the world tokenizer (#86)
* Add support for the world tokenizer

* Move tokenizer logic to rwkv_tokenizer.py

* Added test for the tokenizer
Mathmagician8191 committed Jun 8, 2023
1 parent 09ec314 commit 82c4ac7
Showing 6 changed files with 66,394 additions and 30 deletions.
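
In short: chat_with_bot.py, generate_completions.py, and measure_pexplexity.py now accept an optional positional tokenizer argument ('20B' by default, or 'world'), and tokenizer selection moves into the new rwkv/rwkv_tokenizer.py; the bundled vocabulary file rwkv_vocab_v20230424.txt accounts for the bulk of the added lines. Example invocation (hypothetical model path; run from the rwkv/ directory so the world vocabulary file is found by its relative path):

    python chat_with_bot.py ~/rwkv-world.bin world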
24 changes: 11 additions & 13 deletions rwkv/chat_with_bot.py
@@ -8,9 +8,9 @@
 import copy
 import torch
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
 import json
 from typing import List, Dict, Optional
 import time
@@ -42,6 +42,7 @@

 parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()

 script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
@@ -53,19 +54,14 @@

 assert init_prompt != '', 'Prompt must not be empty'

-print('Loading 20B tokenizer')
-tokenizer_path = script_dir / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)

 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')

 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)

-prompt_tokens = tokenizer.encode(init_prompt).ids
-prompt_token_count = len(prompt_tokens)
-
 # =================================================================================================

 processed_tokens: List[int] = []
@@ -110,9 +106,11 @@ def split_last_end_of_line(tokens):

 # =================================================================================================
 T1 = time.time()
+prompt_tokens = tokenizer_encode(init_prompt)
+prompt_token_count = len(prompt_tokens)
 print(f'Processing {prompt_token_count} prompt tokens, may take a while')

-process_tokens(split_last_end_of_line(tokenizer.encode(init_prompt).ids))
+process_tokens(split_last_end_of_line(prompt_tokens))
 T2 = time.time()
 print(f'Process time :{((T2 - T1)*1000)} ms')
 print(f'Process time per token :{(((T2 - T1)*1000)) / prompt_token_count} ms')
@@ -164,7 +162,7 @@ def split_last_end_of_line(tokens):
         new = '\n' + msg[5:].strip()
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')

     # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
@@ -179,15 +177,15 @@
 '''
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')

     # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
     elif msg[:4].lower() == '+qq ':
         new = '\nQ: ' + msg[4:].strip() + '\nA:'
         state = None
         processed_tokens = []
-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')

     # +qa YOUR QUESTION --> answer an independent question (regardless of context).
@@ -197,7 +195,7 @@
         real_msg = msg[4:].strip()
         new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'

-        process_tokens(tokenizer.encode(new).ids)
+        process_tokens(tokenizer_encode(new))
         save_thread_state('gen_0')

     # +++ --> continue last free generation (only for +gen / +i)
@@ -230,7 +228,7 @@
     else:
         load_thread_state('chat')
         new = f'{user}{separator} {msg}\n\n{bot}{separator}'
-        process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+        process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999)
         save_thread_state('chat_pre')

     thread = 'chat'
13 changes: 6 additions & 7 deletions rwkv/generate_completions.py
@@ -2,13 +2,12 @@

 import argparse
 import os
-import pathlib
 import time
 import sampling
-import tokenizers
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
+from rwkv_tokenizer import get_tokenizer
 from typing import List

 # ======================================== Script settings ========================================

@@ -31,21 +30,21 @@

 parser = argparse.ArgumentParser(description='Generate completions from RWKV model based on a prompt')
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
+parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
 args = parser.parse_args()

 assert prompt != '', 'Prompt must not be empty'

-print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+
+prompt_tokens = tokenizer_encode(prompt)

 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 print(f'System info: {library.rwkv_get_system_info_string()}')

 print('Loading RWKV model')
 model = rwkv_cpp_model.RWKVModel(library, args.model_path)

-prompt_tokens = tokenizer.encode(prompt).ids
 prompt_token_count = len(prompt_tokens)
 print(f'{prompt_token_count} tokens in prompt')

17 changes: 7 additions & 10 deletions rwkv/measure_pexplexity.py
@@ -4,33 +4,30 @@

 import os
 import time
-import pathlib
 import argparse
-import tokenizers
 import torch
 import rwkv_cpp_model
 import rwkv_cpp_shared_library
 from typing import List
+from rwkv_tokenizer import get_tokenizer

 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
     parser.add_argument('model_path', help='Path to model checkpoint file', type=str)
     parser.add_argument('text_path', help='Path to text file in UTF-8 encoding', type=str)
     parser.add_argument('ignore_first_n_tokens', help='How many tokens should be skipped before loss is measured', type=int)
     parser.add_argument('token_limit', help='How many tokens to process; set to -1 to process all text', nargs='?', type=int, default=-1)
+    parser.add_argument('tokenizer', help='Which tokenizer to use', nargs='?', type=str, default="20B")
     return parser.parse_args()

 args = parse_args()

 # ---

-print('Loading 20B tokenizer')
-tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
-tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
-
 print('Loading text')
 text: str = open(args.text_path, encoding='utf-8').read()
-tokens: List[int] = tokenizer.encode(text).ids
+
+tokenizer, tokenizer_encode = get_tokenizer(args.tokenizer)
+
+tokens = tokenizer_encode(text)
+
 token_count: int = len(tokens)
 print(f'{token_count} tokens in the text')

119 changes: 119 additions & 0 deletions rwkv/rwkv_tokenizer.py
@@ -0,0 +1,119 @@
import os
import tokenizers
import pathlib

########################################################################################################
# Taken from https://github.com/BlinkDL/ChatRWKV/tree/main/tokenizer/rwkv_tokenizer.py
########################################################################################################

class TRIE:
    __slots__ = tuple("ch,to,values,front".split(","))
    to:list
    values:set
    def __init__(self, front=None, ch=None):
        self.ch = ch
        self.to = [None for ch in range(256)]
        self.values = set()
        self.front = front

    def __repr__(self):
        fr = self
        ret = []
        while(fr!=None):
            if(fr.ch!=None):
                ret.append(fr.ch)
            fr = fr.front
        return "<TRIE %s %s>"%(ret[::-1], self.values)

    def add(self, key:bytes, idx:int=0, val=None):
        if(idx == len(key)):
            if(val is None):
                val = key
            self.values.add(val)
            return self
        ch = key[idx]
        if(self.to[ch] is None):
            self.to[ch] = TRIE(front=self, ch=ch)
        return self.to[ch].add(key, idx=idx+1, val=val)

    def find_longest(self, key:bytes, idx:int=0):
        u:TRIE = self
        ch:int = key[idx]

        while(u.to[ch] is not None):
            u = u.to[ch]
            idx += 1
            if(u.values):
                ret = idx, u, u.values
            if(idx==len(key)):
                break
            ch = key[idx]
        return ret

class TRIE_TOKENIZER():
    def __init__(self, file_name):
        self.idx2token = {}
        sorted = [] # must be already sorted
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k,v in self.idx2token.items():
            self.token2idx[v] = int(k)

        self.root = TRIE()
        for t, i in self.token2idx.items():
            _ = self.root.add(t, val=(t, i))

    def encodeBytes(self, src:bytes) -> list[int]:
        idx:int = 0
        tokens:list[int] = []
        while (idx < len(src)):
            _idx:int = idx
            idx, _, values = self.root.find_longest(src, idx)
            assert(idx != _idx)
            _, token = next(iter(values))
            tokens.append(token)
        return tokens

    def decodeBytes(self, tokens):
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except:
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()
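
# ------------------------------------------------------------------------------------------------------
# Illustration (editor's sketch, not in the commit): find_longest walks the trie one byte at a time and
# remembers the last node that carried values, i.e. the longest vocabulary entry prefixing src[idx:];
# encodeBytes applies this greedily until the input is consumed. With invented token ids:
#
#     root = TRIE()
#     for token, idx in ((b'a', 1), (b'ab', 2), (b'abc', 3)):
#         root.add(token, val=(token, idx))
#     end, _, values = root.find_longest(b'abcb', 0)
#     assert end == 3 and (b'abc', 3) in values    # 'abc' wins over 'a' and 'ab'
# ------------------------------------------------------------------------------------------------------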

def get_tokenizer(tokenizer="20B"):
    if tokenizer == "world":
        print('Loading world tokenizer')
        tokenizer = TRIE_TOKENIZER('rwkv_vocab_v20230424.txt')
        tokenizer_encode = lambda prompt: tokenizer.encode(prompt)
    elif tokenizer == "20B":
        print('Loading 20B tokenizer')
        tokenizer_path: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
        tokenizer_encode = lambda prompt: tokenizer.encode(prompt).ids
    else:
        print(f'Unknown tokenizer: {tokenizer}')
        quit()
    return tokenizer, tokenizer_encode
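
A quick smoke test of the new helper (a sketch, not part of the commit; it assumes the world vocabulary file is reachable, i.e. the snippet is run from the rwkv/ directory, since TRIE_TOKENIZER loads rwkv_vocab_v20230424.txt by a path relative to the working directory, while the 20B branch resolves 20B_tokenizer.json next to rwkv_tokenizer.py):

    from rwkv_tokenizer import get_tokenizer

    tokenizer, tokenizer_encode = get_tokenizer('world')
    tokens = tokenizer_encode('Hello world!')
    assert tokenizer.decode(tokens) == 'Hello world!'
    tokenizer.printTokens(tokens)

Returning a (tokenizer, encode-callable) pair keeps the call sites tokenizer-agnostic: the Hugging Face tokenizer returns an Encoding whose .ids must be unpacked, while TRIE_TOKENIZER.encode already returns a plain list of ints, and the lambdas in get_tokenizer hide that difference.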
