From 5eb8f09c146ea8124633ab041d9ea0b1f1db4459 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sun, 30 Apr 2023 20:27:14 +0500
Subject: [PATCH] Various improvements (#47)

* Update ggml

* Pack only rwkv.dll for Windows releases

Test executables would not be packed anymore.

* Move test code into a separate file

* Remove redundant zeroing

* Refactor chat script
---
 .github/workflows/build.yml          |   2 +-
 ggml                                 |   2 +-
 rwkv.cpp                             |   1 -
 rwkv/chat_with_bot.py                | 169 +++++++++++++--------------
 rwkv/convert_pytorch_to_ggml.py      |  49 --------
 rwkv/convert_pytorch_to_ggml.test.py |  54 +++++++++
 6 files changed, 140 insertions(+), 137 deletions(-)
 create mode 100644 rwkv/convert_pytorch_to_ggml.test.py

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 787a44a..a9b1ecf 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -230,7 +230,7 @@ jobs:
         id: pack_artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
         run: |
-          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\*
+          7z a rwkv-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip .\build\bin\Release\rwkv.dll

       - name: Upload artifacts
         if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
diff --git a/ggml b/ggml
index b237714..9d7974c 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit b237714db49cc09b63a372aeb33ca83bc56b3977
+Subproject commit 9d7974c3cf1284b4ddb926d94552e9fe4c4ad483
diff --git a/rwkv.cpp b/rwkv.cpp
index fccef74..9ba7786 100644
--- a/rwkv.cpp
+++ b/rwkv.cpp
@@ -568,7 +568,6 @@ bool rwkv_eval(struct rwkv_context * ctx, int32_t token, float * state_in, float
     RWKV_ASSERT_FALSE(token >= 0 && token < n_vocab, "Token is out of range 0..%d", n_vocab - 1);

-    ggml_set_i32(ctx->token_index, 0);
     ggml_set_i32_1d(ctx->token_index, 0, token);

     if (state_in == NULL) {
diff --git a/rwkv/chat_with_bot.py b/rwkv/chat_with_bot.py
index f06ffa3..0ff772b 100644
--- a/rwkv/chat_with_bot.py
+++ b/rwkv/chat_with_bot.py
@@ -12,22 +12,15 @@ import rwkv_cpp_model
 import rwkv_cpp_shared_library
 import json
+from typing import Optional

 # ======================================== Script settings ========================================

 # English, Chinese, Japanese
 LANGUAGE: str = 'English'

-# QA: Question and Answer prompt
-# Chat: chat prompt (you need a large model for adequate quality, 7B+)
-PROMPT_TYPE: str = "Chat"
-
-PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'
-
-def load_prompt(PROMPT_FILE: str):
-    with open(PROMPT_FILE, 'r') as json_file:
-        variables = json.load(json_file)
-        user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
-        return user, bot, separator, prompt
+# QA: Question and Answer prompt to talk to an AI assistant.
+# Chat: chat prompt (need a large model for adequate quality, 7B+).
+PROMPT_TYPE: str = 'QA'

 MAX_GENERATION_LENGTH: int = 250

@@ -39,6 +32,7 @@ def load_prompt(PROMPT_FILE: str):
 PRESENCE_PENALTY: float = 0.2
 # Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
 FREQUENCY_PENALTY: float = 0.2
+
 END_OF_LINE_TOKEN: int = 187
 END_OF_TEXT_TOKEN: int = 0

@@ -48,11 +42,17 @@ def load_prompt(PROMPT_FILE: str):
 parser.add_argument('model_path', help='Path to RWKV model in ggml format')
 args = parser.parse_args()

-user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
+script_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent
+
+with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r') as json_file:
+    prompt_data = json.load(json_file)
+
+    user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt']
+
 assert init_prompt != '', 'Prompt must not be empty'

 print('Loading 20B tokenizer')
-tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
+tokenizer_path = script_dir / '20B_tokenizer.json'
 tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

 library = rwkv_cpp_shared_library.load_rwkv_shared_library()
@@ -64,48 +64,48 @@ def load_prompt(PROMPT_FILE: str):
 prompt_tokens = tokenizer.encode(init_prompt).ids
 prompt_token_count = len(prompt_tokens)

-########################################################################################################
-
-model_tokens: list[int] = []
-
-logits, model_state = None, None
+# =================================================================================================

-def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
-    global model_tokens, model_state, logits
+processed_tokens: list[int] = []
+logits: Optional[torch.Tensor] = None
+state: Optional[torch.Tensor] = None

-    _tokens = [int(x) for x in _tokens]
+def process_tokens(_tokens: list[int], new_line_logit_bias: float = 0.0) -> None:
+    global processed_tokens, logits, state

-    model_tokens += _tokens
+    processed_tokens += _tokens

     for _token in _tokens:
-        logits, model_state = model.eval(_token, model_state, model_state, logits)
+        logits, state = model.eval(_token, state, state, logits)

-    logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability
-
-    return logits
+    logits[END_OF_LINE_TOKEN] += new_line_logit_bias

 state_by_thread: dict[str, dict] = {}

-def save_thread_state(_thread: str, _logits: torch.Tensor) -> None:
-    state_by_thread[_thread] = {}
-    state_by_thread[_thread]['logits'] = copy.deepcopy(_logits)
-    state_by_thread[_thread]['rnn'] = copy.deepcopy(model_state)
-    state_by_thread[_thread]['token'] = copy.deepcopy(model_tokens)
+def save_thread_state(_thread: str) -> None:
+    state_by_thread[_thread] = {
+        'tokens': copy.deepcopy(processed_tokens),
+        'logits': copy.deepcopy(logits),
+        'state': copy.deepcopy(state)
+    }
+
+def load_thread_state(_thread: str) -> None:
+    global processed_tokens, logits, state

-def load_thread_state(_thread: str) -> torch.Tensor:
-    global model_tokens, model_state
-    model_state = copy.deepcopy(state_by_thread[_thread]['rnn'])
-    model_tokens = copy.deepcopy(state_by_thread[_thread]['token'])
-    return copy.deepcopy(state_by_thread[_thread]['logits'])
+    thread_state = state_by_thread[_thread]

-########################################################################################################
+    processed_tokens = copy.deepcopy(thread_state['tokens'])
+    logits = copy.deepcopy(thread_state['logits'])
+    state = copy.deepcopy(thread_state['state'])
+
+# =================================================================================================

 print(f'Processing {prompt_token_count} prompt tokens, may take a while')

-logits = process_tokens(tokenizer.encode(init_prompt).ids)
+process_tokens(tokenizer.encode(init_prompt).ids)

-save_thread_state('chat_init', logits)
-save_thread_state('chat', logits)
+save_thread_state('chat_init')
+save_thread_state('chat')

 print(f'\nChat initialized! Your name is {user}. Write something and press Enter. Use \\n to add line breaks to your message.')

@@ -117,7 +117,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
     temperature = TEMPERATURE
     top_p = TOP_P

-    if "-temp=" in msg:
+    if '-temp=' in msg:
         temperature = float(msg.split('-temp=')[1].split(' ')[0])

         msg = msg.replace('-temp='+f'{temperature:g}', '')
@@ -128,7 +128,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
         if temperature >= 5:
             temperature = 5

-    if "-top_p=" in msg:
+    if '-top_p=' in msg:
         top_p = float(msg.split('-top_p=')[1].split(' ')[0])

         msg = msg.replace('-top_p='+f'{top_p:g}', '')
@@ -140,8 +140,8 @@ def load_thread_state(_thread: str) -> torch.Tensor:

     # + reset --> reset chat
     if msg == '+reset':
-        logits = load_thread_state('chat_init')
-        save_thread_state('chat', logits)
+        load_thread_state('chat_init')
+        save_thread_state('chat')
         print(f'{bot}{separator} Chat reset.\n')
         continue
     elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++':
@@ -149,11 +149,10 @@ def load_thread_state(_thread: str) -> torch.Tensor:
         # +gen YOUR PROMPT --> free single-round generation with any prompt. Requires Novel model.
         if msg[:5].lower() == '+gen ':
             new = '\n' + msg[5:].strip()
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +i YOUR INSTRUCT --> free single-round generation with any instruct. Requires Raven model.
         elif msg[:3].lower() == '+i ':
@@ -165,37 +164,34 @@ def load_thread_state(_thread: str) -> torch.Tensor:

 # Response:
 '''
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +qq YOUR QUESTION --> answer an independent question with more creativity (regardless of context).
         elif msg[:4].lower() == '+qq ':
             new = '\nQ: ' + msg[4:].strip() + '\nA:'
-            # print(f'### prompt ###\n[{new}]')
-            model_state = None
-            model_tokens = []
-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            state = None
+            processed_tokens = []
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +qa YOUR QUESTION --> answer an independent question (regardless of context).
         elif msg[:4].lower() == '+qa ':
-            logits = load_thread_state('chat_init')
+            load_thread_state('chat_init')

             real_msg = msg[4:].strip()
-            new = f"{user}{separator} {real_msg}\n\n{bot}{separator}"
-            # print(f'### qa ###\n[{new}]')
+            new = f'{user}{separator} {real_msg}\n\n{bot}{separator}'

-            logits = process_tokens(tokenizer.encode(new).ids)
-            save_thread_state('gen_0', logits)
+            process_tokens(tokenizer.encode(new).ids)
+            save_thread_state('gen_0')

         # +++ --> continue last free generation (only for +gen / +i)
         elif msg.lower() == '+++':
             try:
-                logits = load_thread_state('gen_1')
-                save_thread_state('gen_0', logits)
+                load_thread_state('gen_1')
+                save_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue
@@ -203,49 +199,52 @@ def load_thread_state(_thread: str) -> torch.Tensor:
         # ++ --> retry last free generation (only for +gen / +i)
         elif msg.lower() == '++':
             try:
-                logits = load_thread_state('gen_0')
+                load_thread_state('gen_0')
             except Exception as e:
                 print(e)
                 continue

-        thread = "gen_1"
+        thread = 'gen_1'
     else:
         # + --> alternate chat reply
         if msg.lower() == '+':
             try:
-                logits = load_thread_state('chat_pre')
+                load_thread_state('chat_pre')
             except Exception as e:
                 print(e)
                 continue
         # chat with bot
         else:
-            logits = load_thread_state('chat')
-            new = f"{user}{separator} {msg}\n\n{bot}{separator}"
-            # print(f'### add ###\n[{new}]')
-            logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
-            save_thread_state('chat_pre', logits)
+            load_thread_state('chat')
+            new = f'{user}{separator} {msg}\n\n{bot}{separator}'
+            process_tokens(tokenizer.encode(new).ids, new_line_logit_bias=-999999999)
+            save_thread_state('chat_pre')

         thread = 'chat'

     # Print bot response
-    print(f"> {bot}{separator}", end='')
+    print(f'> {bot}{separator}', end='')

-    start_index: int = len(model_tokens)
+    start_index: int = len(processed_tokens)
     accumulated_tokens: list[int] = []
-    occurrence: dict[int, int] = {}
+    token_counts: dict[int, int] = {}

     for i in range(MAX_GENERATION_LENGTH):
-        for n in occurrence:
-            logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
+        for n in token_counts:
+            logits[n] -= PRESENCE_PENALTY + token_counts[n] * FREQUENCY_PENALTY
+
         token: int = sampling.sample_logits(logits, temperature, top_p)
+
         if token == END_OF_TEXT_TOKEN:
             print()
             break
-        if token not in occurrence:
-            occurrence[token] = 1
+
+        if token not in token_counts:
+            token_counts[token] = 1
         else:
-            occurrence[token] += 1
-        logits: torch.Tensor = process_tokens([token])
+            token_counts[token] += 1
+
+        process_tokens([token])

         # Avoid UTF-8 display issues
         accumulated_tokens += [token]
@@ -258,10 +257,10 @@ def load_thread_state(_thread: str) -> torch.Tensor:
             accumulated_tokens = []

         if thread == 'chat':
-            if '\n\n' in tokenizer.decode(model_tokens[start_index:]):
+            if '\n\n' in tokenizer.decode(processed_tokens[start_index:]):
                 break

         if i == MAX_GENERATION_LENGTH - 1:
             print()

-    save_thread_state(thread, logits)
+    save_thread_state(thread)
diff --git a/rwkv/convert_pytorch_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py
index de4b5d6..e132fd5 100644
--- a/rwkv/convert_pytorch_to_ggml.py
+++ b/rwkv/convert_pytorch_to_ggml.py
@@ -3,7 +3,6 @@
 # Get model checkpoints from https://huggingface.co/BlinkDL
 # See FILE_FORMAT.md for the documentation on the file format.

-import os
 import argparse
 import struct
 import torch
@@ -97,53 +96,5 @@ def main() -> None:

     print('Done')

-# --- Tests ---
-
-def test() -> None:
-    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
-
-    try:
-        state_dict: Dict[str, torch.Tensor] = {
-            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
-            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
-        }
-
-        write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
-
-        with open(test_file_path, 'rb') as input:
-            actual_bytes: bytes = input.read()
-
-        expected_bytes: bytes = struct.pack(
-            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
-            0x67676d66,
-            100,
-            3,
-            2,
-            1,
-            0,
-            # emb.weight
-            2,
-            10,
-            0,
-            2, 3,
-            'emb.weight'.encode('utf-8'),
-            1.0, 2.0, 3.0,
-            4.0, 5.0, 6.0,
-            # blocks.0.ln1.weight
-            1,
-            19,
-            0,
-            1,
-            'blocks.0.ln1.weight'.encode('utf-8'),
-            1.0
-        )
-
-        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
-
-        print('All tests pass')
-    finally:
-        if os.path.isfile(test_file_path):
-            os.remove(test_file_path)
-
 if __name__ == "__main__":
     main()
diff --git a/rwkv/convert_pytorch_to_ggml.test.py b/rwkv/convert_pytorch_to_ggml.test.py
new file mode 100644
index 0000000..5578506
--- /dev/null
+++ b/rwkv/convert_pytorch_to_ggml.test.py
@@ -0,0 +1,54 @@
+import os
+import struct
+import torch
+import convert_pytorch_to_ggml
+from typing import Dict
+
+def test() -> None:
+    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
+
+    try:
+        state_dict: Dict[str, torch.Tensor] = {
+            'emb.weight': torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32),
+            'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32)
+        }
+
+        convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32')
+
+        with open(test_file_path, 'rb') as input:
+            actual_bytes: bytes = input.read()
+
+        expected_bytes: bytes = struct.pack(
+            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
+            0x67676d66,
+            100,
+            3,
+            2,
+            1,
+            0,
+            # emb.weight
+            2,
+            10,
+            0,
+            2, 3,
+            'emb.weight'.encode('utf-8'),
+            1.0, 2.0, 3.0,
+            4.0, 5.0, 6.0,
+            # blocks.0.ln1.weight
+            1,
+            19,
+            0,
+            1,
+            'blocks.0.ln1.weight'.encode('utf-8'),
+            1.0
+        )
+
+        assert list(actual_bytes) == list(expected_bytes), f'\nActual: {list(actual_bytes)}\nExpected: {list(expected_bytes)}'
+
+        print('All tests pass')
+    finally:
+        if os.path.isfile(test_file_path):
+            os.remove(test_file_path)
+
+if __name__ == "__main__":
+    test()
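
For reference, a hedged usage sketch (not taken from the patch itself): the relocated test imports convert_pytorch_to_ggml directly and ends with an `if __name__ == "__main__"` guard, so it can presumably be run standalone from inside the rwkv/ directory, for example:

    cd rwkv
    python convert_pytorch_to_ggml.test.py

The exact invocation, and whether CI picks the new file up, are assumptions; the only workflow change in this patch is which Windows release artifacts get packed.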