punish repetitions & break if END_OF_TEXT & decouple prompts from chat script (#37)

* punish repetitions & break if END_OF_TEXT

* decouple prompts from chat_with_bot.py

* improve code style

* Update rwkv/chat_with_bot.py

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* Update rwkv/chat_with_bot.py

Co-authored-by: Alex <saharNooby@users.noreply.github.com>

* add types

* JSON prompt

---------

Co-authored-by: Alex <saharNooby@users.noreply.github.com>
L-M-Sherlock and saharNooby committed Apr 30, 2023
1 parent 06dac0f commit 3621172
Showing 7 changed files with 70 additions and 89 deletions.
123 changes: 34 additions & 89 deletions rwkv/chat_with_bot.py
@@ -11,107 +11,44 @@
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
import json

# ======================================== Script settings ========================================

# English, Chinese
# English, Chinese, Japanese
LANGUAGE: str = 'English'
# QA: Question and Answer prompt
# Chat: chat prompt (you need a large model for adequate quality, 7B+)
PROMPT_TYPE: str = "Chat"

# True: Q&A prompt
# False: chat prompt (you need a large model for adequate quality, 7B+)
QA_PROMPT: bool = False
PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'

def load_prompt(PROMPT_FILE: str):
    with open(PROMPT_FILE, 'r') as json_file:
        variables = json.load(json_file)
        user, bot, separator, prompt = variables['user'], variables['bot'], variables['separator'], variables['prompt']
        return user, bot, separator, prompt

MAX_GENERATION_LENGTH: int = 250

# Sampling temperature. It could be a good idea to increase temperature when top_p is low.
TEMPERATURE: float = 0.8
# For better Q&A accuracy and less diversity, reduce top_p (to 0.5, 0.2, 0.1 etc.)
TOP_P: float = 0.5

if LANGUAGE == 'English':
    separator: str = ':'

    if QA_PROMPT:
        user: str = 'User'
        bot: str = 'Bot'
        init_prompt: str = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
polite.
{user}{separator} french revolution what year
{bot}{separator} The French Revolution started in 1789, and lasted 10 years until 1799.
{user}{separator} 3+5=?
{bot}{separator} The answer is 8.
{user}{separator} guess i marry who ?
{bot}{separator} Only if you tell me more about yourself - what are your interests?
{user}{separator} solve for a: 9-a=2
{bot}{separator} The answer is a = 7, because 9 - 7 = 2.
{user}{separator} what is lhc
{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
'''
    else:
        user: str = 'Bob'
        bot: str = 'Alice'
        init_prompt: str = f'''
The following is a verbose detailed conversation between {user} and a young girl {bot}. {bot} is intelligent, friendly and cute. {bot} is likely to agree with {user}.
{user}{separator} Hello {bot}, how are you doing?
{bot}{separator} Hi {user}! Thanks, I'm fine. What about you?
{user}{separator} I am very good! It's nice to see you. Would you mind me chatting with you for a while?
{bot}{separator} Not at all! I'm listening.
'''

elif LANGUAGE == 'Chinese':
    separator: str = ':'

    if QA_PROMPT:
        user: str = 'Q'
        bot: str = 'A'
        init_prompt: str = f'''
Expert Questions & Helpful Answers
Ask Research Experts
'''
    else:
        user: str = 'Bob'
        bot: str = 'Alice'
        init_prompt: str = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and \
polite.
{user}{separator} what is lhc
{bot}{separator} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
{user}{separator} 企鹅会飞吗
{bot}{separator} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。
'''
else:
    assert False, f'Invalid language {LANGUAGE}'
# Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
PRESENCE_PENALTY: float = 0.2
# Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
FREQUENCY_PENALTY: float = 0.2
END_OF_LINE_TOKEN: int = 187
END_OF_TEXT_TOKEN: int = 0

# =================================================================================================

parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

user, bot, separator, init_prompt = load_prompt(PROMPT_FILE)
assert init_prompt != '', 'Prompt must not be empty'

print('Loading 20B tokenizer')
@@ -133,7 +70,7 @@

logits, model_state = None, None

def process_tokens(_tokens: list[int]) -> torch.Tensor:
def process_tokens(_tokens: list[int], newline_adj: int = 0) -> torch.Tensor:
    global model_tokens, model_state, logits

    _tokens = [int(x) for x in _tokens]
@@ -143,6 +80,8 @@ def process_tokens(_tokens: list[int]) -> torch.Tensor:
    for _token in _tokens:
        logits, model_state = model.eval(_token, model_state, model_state, logits)

    logits[END_OF_LINE_TOKEN] += newline_adj # adjust \n probability

    return logits

state_by_thread: dict[str, dict] = {}
@@ -163,10 +102,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:

print(f'Processing {prompt_token_count} prompt tokens, may take a while')

for token in prompt_tokens:
    logits, model_state = model.eval(token, model_state, model_state, logits)

    model_tokens.append(token)
logits = process_tokens(tokenizer.encode(init_prompt).ids)

save_thread_state('chat_init', logits)
save_thread_state('chat', logits)
@@ -286,7 +222,7 @@ def load_thread_state(_thread: str) -> torch.Tensor:
logits = load_thread_state('chat')
new = f"{user}{separator} {msg}\n\n{bot}{separator}"
# print(f'### add ###\n[{new}]')
logits = process_tokens(tokenizer.encode(new).ids)
logits = process_tokens(tokenizer.encode(new).ids, newline_adj=-999999999)
save_thread_state('chat_pre', logits)

thread = 'chat'
@@ -296,10 +232,19 @@

start_index: int = len(model_tokens)
accumulated_tokens: list[int] = []
occurrence: dict[int, int] = {}

for i in range(MAX_GENERATION_LENGTH):
    for n in occurrence:
        logits[n] -= (PRESENCE_PENALTY + occurrence[n] * FREQUENCY_PENALTY)
    token: int = sampling.sample_logits(logits, temperature, top_p)

    if token == END_OF_TEXT_TOKEN:
        print()
        break
    if token not in occurrence:
        occurrence[token] = 1
    else:
        occurrence[token] += 1
    logits: torch.Tensor = process_tokens([token])

# Avoid UTF-8 display issues
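For reference, here is a minimal, self-contained sketch of the generation behaviour this commit introduces: every emitted token is counted in an occurrence map, each already-seen token's logit is reduced by PRESENCE_PENALTY plus its count times FREQUENCY_PENALTY, and generation stops early once the end-of-text token (id 0) is sampled. The names generate, next_logits and sample_greedy are illustrative stand-ins, not functions from this repository.

import torch

PRESENCE_PENALTY: float = 0.2
FREQUENCY_PENALTY: float = 0.2
END_OF_TEXT_TOKEN: int = 0
MAX_GENERATION_LENGTH: int = 250


def sample_greedy(logits: torch.Tensor) -> int:
    # Stand-in for sampling.sample_logits(logits, temperature, top_p).
    return int(torch.argmax(logits).item())


def generate(next_logits, initial_logits: torch.Tensor) -> list[int]:
    # next_logits(token) stands in for process_tokens([token]) and must
    # return the logits tensor for the next step.
    occurrence: dict[int, int] = {}
    generated: list[int] = []
    logits = initial_logits.clone()

    for _ in range(MAX_GENERATION_LENGTH):
        # Penalize every token generated so far; frequent tokens are
        # penalized more strongly, which discourages verbatim repetition.
        for n, count in occurrence.items():
            logits[n] -= PRESENCE_PENALTY + count * FREQUENCY_PENALTY

        token = sample_greedy(logits)

        # Stop as soon as the model emits the end-of-text token.
        if token == END_OF_TEXT_TOKEN:
            break

        occurrence[token] = occurrence.get(token, 0) + 1
        generated.append(token)
        logits = next_logits(token).clone()

    return generated

In chat_with_bot.py itself the same adjustment is applied inline before sampling.sample_logits, and the occurrence map is re-created for each reply, so penalties do not carry over between turns.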
6 changes: 6 additions & 0 deletions rwkv/prompt/Chinese-Chat.json
@@ -0,0 +1,6 @@
{
"user": "Bob",
"bot": "Alice",
"separator": ":",
"prompt": "\nThe following is a coherent verbose detailed conversation between a Chinese girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: lhc\n\nAlice: LHC是指大型强子对撞机(Large Hadron Collider),是世界最大最强的粒子加速器,由欧洲核子中心(CERN)在瑞士日内瓦地下建造。LHC的原理是加速质子(氢离子)并让它们相撞,让科学家研究基本粒子和它们之间的相互作用,并在2012年证实了希格斯玻色子的存在。\n\nBob: 企鹅会飞吗\n\nAlice: 企鹅是不会飞的。企鹅的翅膀短而扁平,更像是游泳时的一对桨。企鹅的身体结构和羽毛密度也更适合在水中游泳,而不是飞行。\n\n"
}
6 changes: 6 additions & 0 deletions rwkv/prompt/Chinese-QA.json
@@ -0,0 +1,6 @@
{
"user": "Q",
"bot": "A",
"separator": ":",
"prompt": "\nExpert Questions & Helpful Answers\n\nAsk Research Experts\n\n"
}
6 changes: 6 additions & 0 deletions rwkv/prompt/English-Chat.json
@@ -0,0 +1,6 @@
{
"user": "Bob",
"bot": "Alice",
"separator": ":",
"prompt": "\nThe following is a coherent verbose detailed conversation between a girl named Alice and her friend Bob. Alice is very intelligent, creative and friendly. Alice is unlikely to disagree with Bob, and Alice doesn't like to ask Bob questions. Alice likes to tell Bob a lot about herself and her opinions. Alice usually gives Bob kind, helpful and informative advices.\n\nBob: Hello Alice, how are you doing?\n\nAlice: Hi! Thanks, I'm fine. What about you?\n\nBob: I am fine. It's nice to see you. Look, here is a store selling tea and juice.\n\nAlice: Sure. Let's go inside. I would like to have some Mocha latte, which is my favourite!\n\nBob: What is it?\n\nAlice: Mocha latte is usually made with espresso, milk, chocolate, and frothed milk. Its flavors are frequently sweet.\n\nBob: Sounds tasty. I'll try it next time. Would you like to chat with me for a while?\n\nAlice: Of course! I'm glad to answer your questions or give helpful advices. You know, I am confident with my expertise. So please go ahead!\n\n"
}
6 changes: 6 additions & 0 deletions rwkv/prompt/English-QA.json
@@ -0,0 +1,6 @@
{
"user": "User",
"bot": "Bot",
"separator": ":",
"prompt": "\nThe following is a verbose and detailed conversation between an AI assistant called Bot, and a human user called User. Bot is intelligent, knowledgeable, wise and polite.\n\nUser: french revolution what year\n\nBot: The French Revolution started in 1789, and lasted 10 years until 1799.\n\nUser: 3+5=?\n\nBot: The answer is 8.\n\nUser: guess i marry who ?\n\nBot: Only if you tell me more about yourself - what are your interests?\n\nUser: solve for a: 9-a=2\n\nBot: The answer is a = 7, because 9 - 7 = 2.\n\nUser: wat is lhc\n\nBot: LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.\n\n"
}
6 changes: 6 additions & 0 deletions rwkv/prompt/Japanese-Chat.json
@@ -0,0 +1,6 @@
{
"user": "Bob",
"bot": "Alice",
"separator": ":",
"prompt": "\n以下は、Aliceという女の子とその友人Bobの間で行われた会話です。 Aliceはとても賢く、想像力があり、友好的です。 AliceはBobに反対することはなく、AliceはBobに質問するのは苦手です。 AliceはBobに自分のことや自分の意見をたくさん伝えるのが好きです。 AliceはいつもBobに親切で役に立つ、有益なアドバイスをしてくれます。\n\nBob: こんにちはAlice、調子はどうですか?\n\nAlice: こんにちは!元気ですよ。あたなはどうですか?\n\nBob: 元気ですよ。君に会えて嬉しいよ。見て、この店ではお茶とジュースが売っているよ。\n\nAlice: 本当ですね。中に入りましょう。大好きなモカラテを飲んでみたいです!\n\nBob: モカラテって何ですか?\n\nAlice: モカラテはエスプレッソ、ミルク、チョコレート、泡立てたミルクから作られた飲み物です。香りはとても甘いです。\n\nBob: それは美味しそうですね。今度飲んでみます。しばらく私とおしゃべりしてくれますか?\n\nAlice: もちろん!ご質問やアドバイスがあれば、喜んでお答えします。専門的な知識には自信がありますよ。どうぞよろしくお願いいたします!\n\n"
}
6 changes: 6 additions & 0 deletions rwkv/prompt/Japanese-QA.json
@@ -0,0 +1,6 @@
{
"user": "User",
"bot": "Bot",
"separator": ":",
"prompt": "\n以下は、Botと呼ばれるAIアシスタントとUserと呼ばれる人間との間で行われた会話です。Botは知的で、知識が豊富で、賢くて、礼儀正しいです。\n\nUser: フランス革命は何年に起きましたか?\n\nBot: フランス革命は1789年に始まり、1799年まで10年間続きました。\n\nUser: 3+5=?\n\nBot: 答えは8です。\n\nUser: 私は誰と結婚すると思いますか?\n\nBot: あなたのことをもっと教えていただけないとお答えすることができません。\n\nUser: aの値を求めてください: 9-a=2\n\nBot: a = 7です、なぜなら 9 - 7 = 2だからです。\n\nUser: lhcって何ですか?\n\nBot: LHCは、CERNが建設し、2008年に完成した高エネルギー粒子衝突型加速器です。2012年にヒッグス粒子の存在を確認するために使用されました。\n\n"
}
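
All six prompt files share the same four-key schema (user, bot, separator, prompt), which is what makes the decoupling work: adding a new persona or language only means dropping another JSON file into rwkv/prompt/ named {LANGUAGE}-{PROMPT_TYPE}.json. The sketch below illustrates that flow; French-Chat.json is a hypothetical example and does not ship with this commit.

import json

# Hypothetical new prompt file; not part of this commit.
LANGUAGE: str = 'French'
PROMPT_TYPE: str = 'Chat'
PROMPT_FILE: str = f'./rwkv/prompt/{LANGUAGE}-{PROMPT_TYPE}.json'

# Mirrors load_prompt() in chat_with_bot.py: every prompt file is expected
# to provide exactly these four keys.
with open(PROMPT_FILE, 'r') as json_file:
    variables = json.load(json_file)

user: str = variables['user']
bot: str = variables['bot']
separator: str = variables['separator']
init_prompt: str = variables['prompt']

# The chat script feeds init_prompt to the model once, then appends every
# user message in the same "{user}{separator} ..." turn format.
print(init_prompt + f'{user}{separator} Bonjour !\n\n{bot}{separator}')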
