Various improvements (#47)
* Update ggml

* Pack only rwkv.dll for Windows releases

Test executables are no longer packed.

* Move test code into a separate file

* Remove redundant zeroing

* Refactor chat script
saharNooby committed Apr 30, 2023
1 parent 3621172 commit 5eb8f09
Showing 6 changed files with 140 additions and 137 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -230,7 +230,7 @@ jobs:
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\*
7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\rwkv.dll
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 1 file
+45 −16 src/ggml.c
1 change: 0 additions & 1 deletion rwkv.cpp
@@ -568,7 +568,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float

RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1);

ggml_set_i32(ctx->token_index, 0);
ggml_set_i32_1d(ctx->token_index, 0, token);

if (state_in == NULL) {
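The removed ggml_set_i32 call zeroed the whole ctx->token_index tensor immediately before ggml_set_i32_1d overwrote its only element, so the zeroing had no observable effect. A minimal Python analogy of the redundancy, assuming (as the surrounding code suggests) that token_index is a single-element int32 tensor:

    import numpy as np

    # Hypothetical stand-in for ctx->token_index: a single-element int32 tensor.
    token_index = np.empty(1, dtype=np.int32)
    token = 42  # any valid token id

    # Old sequence: zero the whole tensor, then immediately overwrite its only element.
    token_index[:] = 0      # redundant -- never read before the next line overwrites it
    token_index[0] = token  # the only write that matters; this is what the commit keeps
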
169 changes: 84 additions & 85 deletions rwkv/chat_with_bot.py
@@ -12,22 +12,15 @@
import rwkv_cpp_model
import rwkv_cpp_shared_library
import json
from typing import Optional

# ======================================== Script settings ========================================

# English, Chinese, Japanese
LANGUAGE: str = 'English'
# QA: Question and Answer prompt
# Chat: chat prompt (you need a large model for adequate quality, 7B+)
PROMPT_TYPE: str = "Chat"

PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'

def load_prompt(PROMPT_FILE: str):
with open(PROMPT_FILE, 'r') as json_file:
variables = json.load(json_file)
user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
return user, bot, separator, prompt
# QA: Question and Answer prompt to talk to an AI assistant.
# Chat: chat prompt (need a large model for adequate quality, 7B+).
PROMPT_TYPE: str = 'QA'

MAX_GENERATION_LENGTH: int = 250

@@ -39,6 +39,7 @@ def load_prompt(PROMPT_FILE: str):
PRESENCE_PENALTY: float = 0.2
# Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
FREQUENCY_PENALTY: float = 0.2

END_OF_LINE_TOKEN: int = 187
END_OF_TEXT_TOKEN: int = 0

@@ -48,11 +48,17 @@ def load_prompt(PROMPT_FILE: str):
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r') as json_file:
prompt_data = json.load(json_file)

user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']

assert init_prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer_path = script_dir / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
@@ -64,48 +64,48 @@ def load_prompt(PROMPT_FILE: str):
prompt_tokens = tokenizer.encode(init_prompt).ids
prompt_token_count = len(prompt_tokens)

########################################################################################################

model_tokens: list[int] = []

logits, model_state = None, None
# =================================================================================================

def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
global model_tokens, model_state, logits
processed_tokens: list[int] = []
logits: Optional[torch.Tensor] = None
state: Optional[torch.Tensor] = None

_tokens = [int(x) for x in _tokens]
def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
global processed_tokens, logits, state

model_tokens += _tokens
processed_tokens += _tokens

for _token in _tokens:
logits, model_state = model.eval(_token, model_state, model_state, logits)
logits, state = model.eval(_token, state, state, logits)

logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability

return logits
logits[END_OF_LINE_TOKEN] += new_line_logit_bias

state_by_thread: dict[str, dict] = {}

def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
state_by_thread[_thread] = {}
state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)
def save_thread_state(_thread: str) -> None:
state_by_thread[_thread] = {
'tokens': copy.deepcopy(processed_tokens),
'logits': copy.deepcopy(logits),
'state': copy.deepcopy(state)
}

def load_thread_state(_thread: str) -> None:
global processed_tokens, logits, state

def load_thread_state(_thread: str) -> torch.Tensor:
global model_tokens, model_state
model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
return copy.deepcopy(state_by_thread[_thread]['logits'])
thread_state = state_by_thread[_thread]

########################################################################################################
processed_tokens = copy.deepcopy(thread_state['tokens'])
logits = copy.deepcopy(thread_state['logits'])
state = copy.deepcopy(thread_state['state'])

# =================================================================================================

print(f'Processing {prompt_token_count} prompt tokens, may take a while')

logits = process_tokens(tokenizer.encode(init_prompt).ids)
process_tokens(tokenizer.encode(init_prompt).ids)

save_thread_state('chat_init', logits)
save_thread_state('chat', logits)
save_thread_state('chat_init')
save_thread_state('chat')

print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

@@ -117,7 +117,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
temperature = TEMPERATURE
top_p = TOP_P

if "-temp=" in msg:
if '-temp=' in msg:
temperature = float(msg.split('-temp=')[1].split(' ')[0])

msg = msg.replace('-temp='+f'{temperature:g}', '')
@@ -128,7 +128,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
if temperature >= 5:
temperature = 5

if "-top_p=" in msg:
if '-top_p=' in msg:
top_p = float(msg.split('-top_p=')[1].split(' ')[0])

msg = msg.replace('-top_p='+f'{top_p:g}', '')
@@ -140,20 +140,19 @@ def load_thread_state(_thread: str) -> torch.Tensor:

# + reset --> reset chat
if msg == '+reset':
logits = load_thread_state('chat_init')
save_thread_state('chat', logits)
load_thread_state('chat_init')
save_thread_state('chat')
print(f'{bot}{separator} Chat reset.\n')
continue
elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':

# +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
if msg[:5].lower() == '+gen ':
new = '\n' + msg[5:].strip()
# print(f'### prompt ###\n[{new}]')
model_state = None
model_tokens = []
logits = process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0', logits)
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0')

# +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
elif msg[:3].lower() == '+i ':
@@ -165,87 +164,87 @@ def load_thread_state(_thread: str) -> torch.Tensor:
# Response:
'''
# print(f'### prompt ###\n[{new}]')
model_state = None
model_tokens = []
logits = process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0', logits)
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0')

# +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
elif msg[:4].lower() == '+qq ':
new = '\nQ: ' + msg[4:].strip() + '\nA:'
# print(f'### prompt ###\n[{new}]')
model_state = None
model_tokens = []
logits = process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0', logits)
state = None
processed_tokens = []
process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0')

# +qa YOUR QUESTION --> answer an independent question (regardless of context).
elif msg[:4].lower() == '+qa ':
logits = load_thread_state('chat_init')
load_thread_state('chat_init')

real_msg = msg[4:].strip()
new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
# print(f'### qa ###\n[{new}]')
new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'

logits = process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0', logits)
process_tokens(tokenizer.encode(new).ids)
save_thread_state('gen_0')

# +++ --> continue last free generation (only for +gen / +i)
elif msg.lower() == '+++':
try:
logits = load_thread_state('gen_1')
save_thread_state('gen_0', logits)
load_thread_state('gen_1')
save_thread_state('gen_0')
except Exception as e:
print(e)
continue

# ++ --> retry last free generation (only for +gen / +i)
elif msg.lower() == '++':
try:
logits = load_thread_state('gen_0')
load_thread_state('gen_0')
except Exception as e:
print(e)
continue
thread = "gen_1"
thread = 'gen_1'

else:
# + --> alternate chat reply
if msg.lower() == '+':
try:
logits = load_thread_state('chat_pre')
load_thread_state('chat_pre')
except Exception as e:
print(e)
continue
# chat with bot
else:
logits = load_thread_state('chat')
new = f"{user}{separator} {msg}\n\n{bot}{separator}"
# print(f'### add ###\n[{new}]')
logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
save_thread_state('chat_pre', logits)
load_thread_state('chat')
new = f'{user}{separator} {msg}\n\n{bot}{separator}'
process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
save_thread_state('chat_pre')

thread = 'chat'

# Print bot response
print(f"> {bot}{separator}", end='')
print(f'> {bot}{separator}', end='')

start_index: int = len(model_tokens)
start_index: int = len(processed_tokens)
accumulated_tokens: list[int] = []
occurrence: dict[int, int] = {}
token_counts: dict[int, int] = {}

for i in range(MAX_GENERATION_LENGTH):
for n in occurrence:
logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
for n in token_counts:
logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY

token: int = sampling.sample_logits(logits, temperature, top_p)

if token == END_OF_TEXT_TOKEN:
print()
break
if token not in occurrence:
occurrence[token] = 1

if token not in token_counts:
token_counts[token] = 1
else:
occurrence[token] += 1
logits: torch.Tensor = process_tokens([token])
token_counts[token] += 1

process_tokens([token])

# Avoid UTF-8 display issues
accumulated_tokens += [token]
@@ -258,10 +257,10 @@ def load_thread_state(_thread: str) -> torch.Tensor:
accumulated_tokens = []

if thread == 'chat':
if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
if '\n\n' in tokenizer.decode(processed_tokens[start_index:]):
break

if i == MAX_GENERATION_LENGTH - 1:
print()

save_thread_state(thread, logits)
save_thread_state(thread)
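
For reference, a minimal sketch of the repetition-penalty step the refactored generation loop applies before sampling each token. The constant names and the token_counts dictionary follow the script above; the toy logits tensor and counts here are assumptions for illustration only:

    import torch

    PRESENCE_PENALTY: float = 0.2
    FREQUENCY_PENALTY: float = 0.2

    logits = torch.zeros(10)        # toy logits over a 10-token vocabulary
    token_counts = {3: 2, 7: 1}     # tokens generated so far and how often each appeared

    # Same adjustment as in the generation loop: every previously seen token is
    # penalized once for being present at all (presence), plus once per prior
    # occurrence (frequency), before sampling the next token.
    for n, count in token_counts.items():
        logits[n] -= PRESENCE_PENALTY + count * FREQUENCY_PENALTY

    # logits[3] == -0.6 (0.2 + 2 * 0.2); logits[7] == -0.4 (0.2 + 1 * 0.2)

In the script, sampling.sample_logits is then called on the adjusted logits with the chosen temperature and top_p.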
49 changes: 0 additions & 49 deletions rwkv/convert_pytorch_to_ggml.py
@@ -3,7 +3,6 @@
# Get model checkpoints from https://huggingface.co/BlinkDL
# See FILE_FORMAT.md for the documentation on the file format.

import os
import argparse
import struct
import torch
@@ -97,53 +96,5 @@ def main() -> None:

print('Done')

# --- Tests ---

def test() -> None:
test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'

try:
state_dict: Dict[str, torch.Tensor] = {
'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
}

write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')

with open(test_file_path, 'rb') as input:
actual_bytes: bytes = input.read()

expected_bytes: bytes = struct.pack(
'=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
0x67676d66,
100,
3,
2,
1,
0,
# emb.weight
2,
10,
0,
2, 3,
'emb.weight'.encode('utf-8'),
1.0, 2.0, 3.0,
4.0, 5.0, 6.0,
# blocks.0.ln1.weight
1,
19,
0,
1,
'blocks.0.ln1.weight'.encode('utf-8'),
1.0
)

assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'

print('All tests pass')
finally:
if os.path.isfile(test_file_path):
os.remove(test_file_path)

if __name__ == "__main__":
main()
