From eec20c56d890aefbc4b2177cc1d7451602110dec Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 17 Jul 2023 08:25:29 +0800 Subject: [PATCH 01/10] Add API server --- rwkv/api.py | 305 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 rwkv/api.py diff --git a/rwkv/api.py b/rwkv/api.py new file mode 100644 index 0000000..7efe687 --- /dev/null +++ b/rwkv/api.py @@ -0,0 +1,305 @@ +import time +import json +import logging +import argparse +import uvicorn +import sampling +from functools import partial +import rwkv_cpp_model +import rwkv_cpp_shared_library +from rwkv_tokenizer import get_tokenizer +from fastapi import FastAPI, Request +from flask import Flask, request, Response, stream_with_context, jsonify +from threading import Lock +from werkzeug.local import Local +from typing import List, Dict, Optional +from pydantic import BaseModel, Field +from sse_starlette.sse import EventSourceResponse +from contextlib import asynccontextmanager + + +# ----------- +END_OF_LINE_TOKEN: int = 187 +DOUBLE_END_OF_LINE_TOKEN: int = 535 +END_OF_TEXT_TOKEN: int = 0 +DEFAULT_PROMPT = 'Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it' +DEFAULT_STOP = '\n\nUser' + +parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') +parser.add_argument('model_path', help='Path to RWKV model in ggml format') +parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='world') +args = parser.parse_args() + +completion_lock = Lock() +requests_num = 0 + + +async def run_with_lock(func, request): + global requests_num + requests_num = requests_num + 1 + logging.debug("Start Waiting. RequestsNum: %r", requests_num) + while completion_lock.locked(): + if await request.is_disconnected(): + logging.debug("Stop Waiting (Lock). RequestsNum: %r", requests_num) + return + # 等待 + logging.debug("Waiting. RequestsNum: %r", requests_num) + time.sleep(0.1) + else: + with completion_lock: + if await request.is_disconnected(): + logging.debug("Stop Waiting (Lock). RequestsNum: %r", requests_num) + return + # if await request.is_disconnected(): + # new = f'{user}{separator} {msg}\n\n{bot}{separator}' + # process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999) + return func() + + +def generate_completions( + model, + prompt, + max_tokens=256, # 这个是不是不应该用? 
+ temperature=0.8, + top_p=0.5, + presence_penalty=0.2, # [控制主题的重复度] + frequency_penalty=0.2, # [重复度惩罚因子] + stop=DEFAULT_STOP, + usage=dict(), + **kwargs, +): + state = Local() + state.logits = None + state.state = None + # logits, state = None, None + prompt_tokens = tokenizer_encode(prompt) + prompt_token_count = len(prompt_tokens) + usage['prompt_tokens'] = prompt_token_count + logging.debug(f'{prompt_token_count} tokens in prompt') + for token in prompt_tokens: + state.logits, state.state = model.eval(token, state.state, state.state, state.logits) + logging.debug('end eval prompt_tokens') + + accumulated_tokens: List[int] = [] # 用于处理UTF8字符问题 + completion_tokens = [] + token_counts: Dict[int, int] = {} + result = '' + while True: + for n in token_counts: + state.logits[n] -= presence_penalty + token_counts[n] * frequency_penalty + token = sampling.sample_logits(state.logits, temperature, top_p) + completion_tokens.append(token) + # 退出生成 + if token == END_OF_TEXT_TOKEN: + break + if token not in token_counts: + token_counts[token] = 1 + else: + token_counts[token] += 1 + + decoded = tokenizer_decode([token]) + # Avoid UTF-8 display issues + accumulated_tokens += [token] + decoded: str = tokenizer_decode(accumulated_tokens) + if '\uFFFD' not in decoded: + # 退出生成 + result += decoded + if stop in result: + break + # 输出 + print(decoded, end='', flush=True) + yield decoded + accumulated_tokens = [] + + if len(completion_tokens) >= max_tokens: + break + state.logits, state.state = model.eval(token, state.state, state.state, state.logits) + usage['prompt_tokens'] = prompt_token_count + usage['completion_tokens'] = len(completion_tokens) + + +def format_message(response, delta, chunk=False, chat_model=False, model_name='rwkv', finish_reason=None): + if chat_model: + object = 'text_completion' + else: + if chunk: + object = 'chat.completion.chunk' + else: + object = 'chat.completion' + + return { + 'object': object, + 'response': response, + 'model': model_name, + 'choices': [{ + 'delta': {'content': delta}, + 'index': 0, + 'finish_reason': finish_reason, + } if chat_model else { + 'text': delta, + 'index': 0, + 'finish_reason': finish_reason, + }] + } + + +tokenizer_decode, tokenizer_encode, model = None, None, None +app = FastAPI() + + +@app.on_event("startup") +async def startup_event(): + # 只初始化一次 + global tokenizer_decode, tokenizer_encode, model + # get world tokenizer + tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) + library = rwkv_cpp_shared_library.load_rwkv_shared_library() + logging.info('System info: %r', library.rwkv_get_system_info_string()) + logging.info('Start Loading RWKV model') + model = rwkv_cpp_model.RWKVModel(library, args.model_path) + logging.info('End Loading RWKV model') + + +@app.on_event("shutdown") +def shutdown_event(): + model.free() + + +async def process_generate(prompt, stop, stream, body, request): + usage = {} + func = partial( + generate_completions, + model, f'User: {prompt}\n\nBot: ', + max_tokens=body.max_tokens or 1000, + temperature=body.temperature, + top_p=body.top_p, + presence_penalty=body.presence_penalty, + frequency_penalty=body.frequency_penalty, + stop=stop, usage=usage, + ) + + async def generate(): + response = '' + for delta in await run_with_lock(func, request): + response += delta + if stream: + chunk = format_message('', delta, chunk=True) + yield json.dumps(chunk) + if stream: + result = format_message('', '', chunk=True, finish_reason='stop') + result.update(usage=usage) + yield json.dumps(result) + else: + result = 
format_message(response, '', chunk=False, finish_reason='stop') + result.update(usage=usage) + yield result + + if stream: + return EventSourceResponse(generate()) + return await generate().__anext__() + + +class ModelConfigBody(BaseModel): + max_tokens: int = Field(default=1000, gt=0, le=102400) + temperature: float = Field(default=0.8, ge=0, le=2) + top_p: float = Field(default=0.5, ge=0, le=1) + presence_penalty: float = Field(default=0.2, ge=-2, le=2) + frequency_penalty: float = Field(default=0.2, ge=-2, le=2) + + class Config: + schema_extra = { + "example": { + "max_tokens": 1000, + "temperature": 1.2, + "top_p": 0.5, + "presence_penalty": 0.4, + "frequency_penalty": 0.4, + } + } + + +class Message(BaseModel): + role: str + content: str + + +class ChatCompletionBody(ModelConfigBody): + messages: List[Message] + model: str = "rwkv" + stream: bool = False + stop: str = DEFAULT_STOP + + class Config: + schema_extra = { + "example": { + "messages": [{"role": "user", "content": "hello"}], + "model": "rwkv", + "stream": False, + "stop": None, + "max_tokens": 1000, + "temperature": 1.2, + "top_p": 0.5, + "presence_penalty": 0.4, + "frequency_penalty": 0.4, + } + } + + +class CompletionBody(ModelConfigBody): + prompt: str or List[str] + model: str = "rwkv" + stream: bool = False + stop: str = DEFAULT_STOP + + class Config: + schema_extra = { + "example": { + "prompt": "The following is an epic science fiction masterpiece that is immortalized, " + + "with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\n", + "model": "rwkv", + "stream": False, + "stop": None, + "max_tokens": 100, + "temperature": 1.2, + "top_p": 0.5, + "presence_penalty": 0.4, + "frequency_penalty": 0.4, + } + } + + + +@app.post('/v1/completions') +@app.post('/completions') +async def completions(body: CompletionBody, request: Request): + return await process_generate(body.prompt, body.stop, body.stream, body, request) + + +@app.post('/v1/chat/completions') +@app.post('/chat/completions') +async def chat_completions(body: ChatCompletionBody, request: Request): + usage = {} + + if len(body.messages) == 0 or body.messages[-1].role != 'user': + raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found") + + system_role = DEFAULT_PROMPT + for message in body.messages: + if message.role == 'system': + system_role = message.content + + completion_text = f'User: {DEFAULT_PROMPT}\n\n' + for message in body.messages: + if message.role == 'user': + content = message.content.replace("\\n", "\n").replace("\r\n", "\n").replace("\n\n", "\n").strip() + completion_text += f'User: {content}\n\n' + elif message.role == 'assistant': + content = message.content.replace("\\n", "\n").replace("\r\n", "\n").replace("\n\n", "\n").strip() + completion_text += f'Bot: {content}\n\n' + completion_text += f"Bot: " + + return await process_generate(completion_text, body.stop, body.stream, body, request) + + +if __name__ == "__main__": + uvicorn.run("api:app", workers=0) From 347e0c9409aeb5b9f281715cf19583d66420f540 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 17 Jul 2023 08:28:10 +0800 Subject: [PATCH 02/10] Add API server --- rwkv/api.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index 7efe687..97ca722 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -9,9 +9,7 @@ import rwkv_cpp_shared_library from rwkv_tokenizer import get_tokenizer from fastapi import FastAPI, Request -from flask import Flask, request, Response, stream_with_context, 
jsonify from threading import Lock -from werkzeug.local import Local from typing import List, Dict, Optional from pydantic import BaseModel, Field from sse_starlette.sse import EventSourceResponse @@ -68,16 +66,13 @@ def generate_completions( usage=dict(), **kwargs, ): - state = Local() - state.logits = None - state.state = None - # logits, state = None, None + logits, state = None, None prompt_tokens = tokenizer_encode(prompt) prompt_token_count = len(prompt_tokens) usage['prompt_tokens'] = prompt_token_count logging.debug(f'{prompt_token_count} tokens in prompt') for token in prompt_tokens: - state.logits, state.state = model.eval(token, state.state, state.state, state.logits) + logits, state = model.eval(token, state, state, logits) logging.debug('end eval prompt_tokens') accumulated_tokens: List[int] = [] # 用于处理UTF8字符问题 @@ -86,8 +81,8 @@ def generate_completions( result = '' while True: for n in token_counts: - state.logits[n] -= presence_penalty + token_counts[n] * frequency_penalty - token = sampling.sample_logits(state.logits, temperature, top_p) + logits[n] -= presence_penalty + token_counts[n] * frequency_penalty + token = sampling.sample_logits(logits, temperature, top_p) completion_tokens.append(token) # 退出生成 if token == END_OF_TEXT_TOKEN: @@ -113,7 +108,7 @@ def generate_completions( if len(completion_tokens) >= max_tokens: break - state.logits, state.state = model.eval(token, state.state, state.state, state.logits) + logits, state = model.eval(token, state, state, logits) usage['prompt_tokens'] = prompt_token_count usage['completion_tokens'] = len(completion_tokens) From 5dd7aa9e0246c4e08fae9500d98de1f5028e2ff1 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 17 Jul 2023 08:33:44 +0800 Subject: [PATCH 03/10] format message --- rwkv/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index 97ca722..7ea9bde 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -181,11 +181,11 @@ async def generate(): chunk = format_message('', delta, chunk=True) yield json.dumps(chunk) if stream: - result = format_message('', '', chunk=True, finish_reason='stop') + result = format_message(response, '', chunk=True, finish_reason='stop') result.update(usage=usage) yield json.dumps(result) else: - result = format_message(response, '', chunk=False, finish_reason='stop') + result = format_message(response, response, chunk=False, finish_reason='stop') result.update(usage=usage) yield result From 8b2977caa9103ed509ef47c6cbd7ca57d9fab680 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Mon, 17 Jul 2023 08:48:58 +0800 Subject: [PATCH 04/10] Add API server --- rwkv/api.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index 7ea9bde..ad7e8d8 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -26,6 +26,8 @@ parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') parser.add_argument('model_path', help='Path to RWKV model in ggml format') parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='world') +parser.add_argument('host', help='host', nargs='?', type=str, default='0.0.0.0') +parser.add_argument('port', help='port', nargs='?', type=int, default=8000) args = parser.parse_args() completion_lock = Lock() @@ -114,7 +116,7 @@ def generate_completions( def format_message(response, delta, chunk=False, chat_model=False, model_name='rwkv', finish_reason=None): - if chat_model: + if not chat_model: 
object = 'text_completion' else: if chunk: @@ -160,7 +162,7 @@ def shutdown_event(): model.free() -async def process_generate(prompt, stop, stream, body, request): +async def process_generate(prompt, stop, stream, chat_model, body, request): usage = {} func = partial( generate_completions, @@ -178,14 +180,14 @@ async def generate(): for delta in await run_with_lock(func, request): response += delta if stream: - chunk = format_message('', delta, chunk=True) + chunk = format_message('', delta, chunk=True, chat_model=chat_model) yield json.dumps(chunk) if stream: - result = format_message(response, '', chunk=True, finish_reason='stop') + result = format_message(response, '', chunk=True, chat_model=chat_model, finish_reason='stop') result.update(usage=usage) yield json.dumps(result) else: - result = format_message(response, response, chunk=False, finish_reason='stop') + result = format_message(response, response, chunk=False, chat_model=chat_model, finish_reason='stop') result.update(usage=usage) yield result @@ -267,7 +269,7 @@ class Config: @app.post('/v1/completions') @app.post('/completions') async def completions(body: CompletionBody, request: Request): - return await process_generate(body.prompt, body.stop, body.stream, body, request) + return await process_generate(body.prompt, body.stop, body.stream, False, body, request) @app.post('/v1/chat/completions') @@ -293,8 +295,8 @@ async def chat_completions(body: ChatCompletionBody, request: Request): completion_text += f'Bot: {content}\n\n' completion_text += f"Bot: " - return await process_generate(completion_text, body.stop, body.stream, body, request) + return await process_generate(completion_text, body.stop, body.stream, True, body, request) if __name__ == "__main__": - uvicorn.run("api:app", workers=0) + uvicorn.run("api:app", host=args.host, port=args.port) From 3c25a3533a0dbbc5c51a82c0724ea5e5fa06d920 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 18 Jul 2023 17:27:15 +0800 Subject: [PATCH 05/10] add cors --- rwkv/api.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/rwkv/api.py b/rwkv/api.py index ad7e8d8..2c19796 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -8,12 +8,13 @@ import rwkv_cpp_model import rwkv_cpp_shared_library from rwkv_tokenizer import get_tokenizer -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, HTTPException, status from threading import Lock from typing import List, Dict, Optional from pydantic import BaseModel, Field from sse_starlette.sse import EventSourceResponse from contextlib import asynccontextmanager +from fastapi.middleware.cors import CORSMiddleware # ----------- @@ -142,6 +143,13 @@ def format_message(response, delta, chunk=False, chat_model=False, model_name='r tokenizer_decode, tokenizer_encode, model = None, None, None app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) @app.on_event("startup") From fbd525b0f209be59cc6d22cbb85c69f9cb1c8a5c Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 18 Jul 2023 23:10:45 +0800 Subject: [PATCH 06/10] hotfix --- rwkv/api.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index 2c19796..c2ab061 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -51,9 +51,6 @@ async def run_with_lock(func, request): if await request.is_disconnected(): logging.debug("Stop Waiting (Lock). 
RequestsNum: %r", requests_num) return - # if await request.is_disconnected(): - # new = f'{user}{separator} {msg}\n\n{bot}{separator}' - # process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999) return func() @@ -145,7 +142,7 @@ def format_message(response, delta, chunk=False, chat_model=False, model_name='r app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -293,7 +290,7 @@ async def chat_completions(body: ChatCompletionBody, request: Request): if message.role == 'system': system_role = message.content - completion_text = f'User: {DEFAULT_PROMPT}\n\n' + completion_text = f'{system_role}\n\n' for message in body.messages: if message.role == 'user': content = message.content.replace("\\n", "\n").replace("\r\n", "\n").replace("\n\n", "\n").strip() @@ -301,7 +298,6 @@ async def chat_completions(body: ChatCompletionBody, request: Request): elif message.role == 'assistant': content = message.content.replace("\\n", "\n").replace("\r\n", "\n").replace("\n\n", "\n").strip() completion_text += f'Bot: {content}\n\n' - completion_text += f"Bot: " return await process_generate(completion_text, body.stop, body.stream, True, body, request) From 24d91c35b73956401e0a0d672499562876b370fb Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Tue, 18 Jul 2023 23:40:57 +0800 Subject: [PATCH 07/10] using pydantic setting from env --- rwkv/api.py | 101 +++++++++++++++++++++++++++++----------------------- 1 file changed, 57 insertions(+), 44 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index c2ab061..b29e33f 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -1,7 +1,6 @@ import time import json import logging -import argparse import uvicorn import sampling from functools import partial @@ -11,25 +10,25 @@ from fastapi import FastAPI, Request, HTTPException, status from threading import Lock from typing import List, Dict, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, BaseSettings from sse_starlette.sse import EventSourceResponse from contextlib import asynccontextmanager from fastapi.middleware.cors import CORSMiddleware -# ----------- +# ----- constant ---- END_OF_LINE_TOKEN: int = 187 DOUBLE_END_OF_LINE_TOKEN: int = 535 END_OF_TEXT_TOKEN: int = 0 -DEFAULT_PROMPT = 'Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it' -DEFAULT_STOP = '\n\nUser' -parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') -parser.add_argument('model_path', help='Path to RWKV model in ggml format') -parser.add_argument('tokenizer', help='Tokenizer to use; supported tokenizers: 20B, world', nargs='?', type=str, default='world') -parser.add_argument('host', help='host', nargs='?', type=str, default='0.0.0.0') -parser.add_argument('port', help='port', nargs='?', type=int, default=8000) -args = parser.parse_args() + +# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end. 
+# See https://github.com/BlinkDL/ChatRWKV/pull/110/files +def split_last_end_of_line(tokens): + if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN: + tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN] + return tokens + completion_lock = Lock() requests_num = 0 @@ -62,15 +61,16 @@ def generate_completions( top_p=0.5, presence_penalty=0.2, # [控制主题的重复度] frequency_penalty=0.2, # [重复度惩罚因子] - stop=DEFAULT_STOP, + stop='', usage=dict(), **kwargs, ): logits, state = None, None - prompt_tokens = tokenizer_encode(prompt) + prompt_tokens = split_last_end_of_line(tokenizer_encode(prompt)) prompt_token_count = len(prompt_tokens) usage['prompt_tokens'] = prompt_token_count logging.debug(f'{prompt_token_count} tokens in prompt') + for token in prompt_tokens: logits, state = model.eval(token, state, state, logits) logging.debug('end eval prompt_tokens') @@ -113,32 +113,20 @@ def generate_completions( usage['completion_tokens'] = len(completion_tokens) -def format_message(response, delta, chunk=False, chat_model=False, model_name='rwkv', finish_reason=None): - if not chat_model: - object = 'text_completion' - else: - if chunk: - object = 'chat.completion.chunk' - else: - object = 'chat.completion' - - return { - 'object': object, - 'response': response, - 'model': model_name, - 'choices': [{ - 'delta': {'content': delta}, - 'index': 0, - 'finish_reason': finish_reason, - } if chat_model else { - 'text': delta, - 'index': 0, - 'finish_reason': finish_reason, - }] - } +class Settings(BaseSettings): + server_name: str = "RWKV API Server" + default_prompt: str = "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it" + default_stop: str = '\n\nUser' + user_name: str = 'User' + bot_name: str = 'Bot' + model_path: str = '' # Path to RWKV model in ggml format + tokenizer: str = 'world' # Tokenizer to use; supported tokenizers: 20B, world + host: str = '0.0.0.0' + port: int = 8000 tokenizer_decode, tokenizer_encode, model = None, None, None +settings = Settings() app = FastAPI() app.add_middleware( CORSMiddleware, @@ -154,11 +142,11 @@ async def startup_event(): # 只初始化一次 global tokenizer_decode, tokenizer_encode, model # get world tokenizer - tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer) + tokenizer_decode, tokenizer_encode = get_tokenizer(settings.tokenizer) library = rwkv_cpp_shared_library.load_rwkv_shared_library() logging.info('System info: %r', library.rwkv_get_system_info_string()) logging.info('Start Loading RWKV model') - model = rwkv_cpp_model.RWKVModel(library, args.model_path) + model = rwkv_cpp_model.RWKVModel(library, settings.model_path) logging.info('End Loading RWKV model') @@ -201,6 +189,31 @@ async def generate(): return await generate().__anext__() +def format_message(response, delta, chunk=False, chat_model=False, model_name='rwkv', finish_reason=None): + if not chat_model: + object = 'text_completion' + else: + if chunk: + object = 'chat.completion.chunk' + else: + object = 'chat.completion' + + return { + 'object': object, + 'response': response, + 'model': model_name, + 'choices': [{ + 'delta': {'content': delta}, + 'index': 0, + 'finish_reason': finish_reason, + } if chat_model else { + 'text': delta, + 'index': 0, + 'finish_reason': finish_reason, + }] + } + + class ModelConfigBody(BaseModel): max_tokens: int = Field(default=1000, gt=0, le=102400) temperature: float = Field(default=0.8, ge=0, le=2) @@ -229,7 +242,7 @@ class 
ChatCompletionBody(ModelConfigBody): messages: List[Message] model: str = "rwkv" stream: bool = False - stop: str = DEFAULT_STOP + stop: str = '' class Config: schema_extra = { @@ -251,7 +264,7 @@ class CompletionBody(ModelConfigBody): prompt: str or List[str] model: str = "rwkv" stream: bool = False - stop: str = DEFAULT_STOP + stop: str = '' class Config: schema_extra = { @@ -274,7 +287,7 @@ class Config: @app.post('/v1/completions') @app.post('/completions') async def completions(body: CompletionBody, request: Request): - return await process_generate(body.prompt, body.stop, body.stream, False, body, request) + return await process_generate(body.prompt, body.stop or settings.default_stop, body.stream, False, body, request) @app.post('/v1/chat/completions') @@ -285,7 +298,7 @@ async def chat_completions(body: ChatCompletionBody, request: Request): if len(body.messages) == 0 or body.messages[-1].role != 'user': raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found") - system_role = DEFAULT_PROMPT + system_role = settings.default_prompt for message in body.messages: if message.role == 'system': system_role = message.content @@ -299,8 +312,8 @@ async def chat_completions(body: ChatCompletionBody, request: Request): content = message.content.replace("\\n", "\n").replace("\r\n", "\n").replace("\n\n", "\n").strip() completion_text += f'Bot: {content}\n\n' - return await process_generate(completion_text, body.stop, body.stream, True, body, request) + return await process_generate(completion_text, body.stop or settings.default_stop, body.stream, True, body, request) if __name__ == "__main__": - uvicorn.run("api:app", host=args.host, port=args.port) + uvicorn.run("api:app", host=settings.host, port=settings.port) From 2228fa472ea43ca77716a402fc17ace47f0b9080 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Wed, 19 Jul 2023 01:26:27 +0800 Subject: [PATCH 08/10] add Dockerfile --- .dockerignore | 3 +++ Dockerfile | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 .dockerignore create mode 100644 Dockerfile diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..314af24 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +CMakeCache.txt +*.bin +data/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..629c413 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,29 @@ +FROM python:3.8 as builder + +RUN sed -i "s@http://deb.debian.org@http://mirrors.aliyun.com@g" /etc/apt/sources.list +RUN apt-get update && apt-get install -y g++ cmake + +ADD . /work + +RUN cd /work && cmake . && cmake --build . 
--config Release + + +FROM python:3.8 + +COPY --from=builder /work/librwkv.so /librwkv.so + +ADD rwkv/rwkv_cpp_model.py /rwkv/rwkv_cpp_model.py +ADD rwkv/rwkv_cpp_shared_library.py /rwkv/rwkv_cpp_shared_library.py +ADD rwkv/rwkv_tokenizer.py /rwkv/rwkv_tokenizer.py +ADD rwkv/sampling.py /rwkv/sampling.py +ADD rwkv/20B_tokenizer.json /rwkv/20B_tokenizer.json +ADD rwkv/rwkv_vocab_v20230424.txt /rwkv/rwkv_vocab_v20230424.txt +ADD rwkv/api.py /rwkv/api.py + +RUN pip3 install uvicorn numpy tokenizers fastapi==0.92.0 sse_starlette -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn --no-cache-dir +RUN pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + +WORKDIR /rwkv + +CMD ["python", "api.py"] + From c0d317be9153a5523422369199e2c56fa1b3f620 Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Wed, 19 Jul 2023 10:06:32 +0800 Subject: [PATCH 09/10] reduce docker image size --- Dockerfile | 9 ++++----- rwkv/api.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index 629c413..a617746 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.8 as builder +FROM python:3.9 as builder RUN sed -i "s@http://deb.debian.org@http://mirrors.aliyun.com@g" /etc/apt/sources.list RUN apt-get update && apt-get install -y g++ cmake @@ -8,7 +8,9 @@ ADD . /work RUN cd /work && cmake . && cmake --build . --config Release -FROM python:3.8 +FROM python:3.9 + +RUN pip3 install uvicorn numpy tokenizers fastapi==0.92.0 sse_starlette torch==2.0.1+cpu -f https://download.pytorch.org/whl/torch_stable.html -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn --no-cache-dir && strip /usr/local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so COPY --from=builder /work/librwkv.so /librwkv.so @@ -20,9 +22,6 @@ ADD rwkv/20B_tokenizer.json /rwkv/20B_tokenizer.json ADD rwkv/rwkv_vocab_v20230424.txt /rwkv/rwkv_vocab_v20230424.txt ADD rwkv/api.py /rwkv/api.py -RUN pip3 install uvicorn numpy tokenizers fastapi==0.92.0 sse_starlette -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn --no-cache-dir -RUN pip3 install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - WORKDIR /rwkv CMD ["python", "api.py"] diff --git a/rwkv/api.py b/rwkv/api.py index b29e33f..98e0176 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -119,7 +119,7 @@ class Settings(BaseSettings): default_stop: str = '\n\nUser' user_name: str = 'User' bot_name: str = 'Bot' - model_path: str = '' # Path to RWKV model in ggml format + model_path: str = '/model.bin' # Path to RWKV model in ggml format tokenizer: str = 'world' # Tokenizer to use; supported tokenizers: 20B, world host: str = '0.0.0.0' port: int = 8000 From cae32b7b43620e7c8abf67fcb075e71db21b044b Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Wed, 19 Jul 2023 17:47:56 +0800 Subject: [PATCH 10/10] change run_with_lock to decorate --- rwkv/api.py | 59 +++++++++++++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/rwkv/api.py b/rwkv/api.py index 98e0176..6b325a2 100644 --- a/rwkv/api.py +++ b/rwkv/api.py @@ -3,7 +3,7 @@ import logging import uvicorn import sampling -from functools import partial +import functools import rwkv_cpp_model import rwkv_cpp_shared_library from rwkv_tokenizer import get_tokenizer @@ -34,26 +34,32 @@ def split_last_end_of_line(tokens): requests_num = 0 -async def run_with_lock(func, request): - global requests_num - requests_num = 
requests_num + 1 - logging.debug("Start Waiting. RequestsNum: %r", requests_num) - while completion_lock.locked(): - if await request.is_disconnected(): - logging.debug("Stop Waiting (Lock). RequestsNum: %r", requests_num) - return - # 等待 - logging.debug("Waiting. RequestsNum: %r", requests_num) - time.sleep(0.1) - else: - with completion_lock: +def run_with_lock(method): + @functools.wraps(method) + async def wrapper(request, *args, **kwargs): + global requests_num + requests_num = requests_num + 1 + logging.debug("Start Waiting. RequestsNum: %r", requests_num) + while completion_lock.locked(): if await request.is_disconnected(): logging.debug("Stop Waiting (Lock). RequestsNum: %r", requests_num) return - return func() + # 等待 + logging.debug("Waiting. RequestsNum: %r", requests_num) + time.sleep(0.1) + else: + with completion_lock: + if await request.is_disconnected(): + logging.debug("Stop Waiting (Lock). RequestsNum: %r", requests_num) + return + return method(request, *args, **kwargs) + + return wrapper -def generate_completions( +@run_with_lock +async def generate_completions( + request, # using in run_with_lock model, prompt, max_tokens=256, # 这个是不是不应该用? @@ -157,20 +163,19 @@ def shutdown_event(): async def process_generate(prompt, stop, stream, chat_model, body, request): usage = {} - func = partial( - generate_completions, - model, f'User: {prompt}\n\nBot: ', - max_tokens=body.max_tokens or 1000, - temperature=body.temperature, - top_p=body.top_p, - presence_penalty=body.presence_penalty, - frequency_penalty=body.frequency_penalty, - stop=stop, usage=usage, - ) async def generate(): response = '' - for delta in await run_with_lock(func, request): + async for delta in await generate_completions( + request, + model, f'User: {prompt}\n\nBot: ', + max_tokens=body.max_tokens or 1000, + temperature=body.temperature, + top_p=body.top_p, + presence_penalty=body.presence_penalty, + frequency_penalty=body.frequency_penalty, + stop=stop, usage=usage, + ): response += delta if stream: chunk = format_message('', delta, chunk=True, chat_model=chat_model)
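
After the final patch the server takes no command-line arguments: `Settings(BaseSettings)` fills `model_path`, `tokenizer`, `host`, `port` and the prompt/stop defaults from environment variables of the same names (pydantic v1 BaseSettings matches them to field names case-insensitively). A minimal launch sketch follows; the model path and bind address are illustrative assumptions, while the defaults named in the comments come from the Settings class itself.

    # Launch sketch for the post-series api.py (run from the directory containing it).
    # MODEL_PATH below is an assumed local path; Settings defaults are
    # tokenizer='world', host='0.0.0.0', port=8000, model_path='/model.bin'.
    import os

    os.environ["MODEL_PATH"] = "/models/rwkv-world.ggml.bin"  # assumed path
    os.environ["TOKENIZER"] = "world"                          # or "20B"

    import uvicorn

    # api.py instantiates Settings() at import time, so the environment must be
    # prepared before uvicorn imports the module.
    uvicorn.run("api:app", host="127.0.0.1", port=8000)

The Docker image from patches 08-09 works the same way: its CMD is `python api.py`, `model_path` defaults to `/model.bin`, and any Settings field can be overridden with `-e`, so a typical run looks like `docker run -p 8000:8000 -v $PWD/model.bin:/model.bin <image>` (image tag assumed).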
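
For exercising the endpoints themselves, the request and response shapes are fixed by `ChatCompletionBody`/`CompletionBody` and `format_message`. Below is a minimal non-streaming client sketch, assuming the server is reachable at http://localhost:8000; with `stream=True` the same endpoint instead returns Server-Sent Events whose `data:` payloads carry one `format_message` chunk per decoded fragment, ending with a chunk that has `finish_reason='stop'` and the `usage` counters.

    # Non-streaming chat request against the endpoints added by this series.
    # localhost:8000 is an assumption; the body fields mirror ChatCompletionBody.
    import json
    import urllib.request

    payload = {
        "messages": [{"role": "user", "content": "hello"}],
        "model": "rwkv",
        "stream": False,
        "max_tokens": 200,
        "temperature": 0.8,
        "top_p": 0.5,
        "presence_penalty": 0.2,
        "frequency_penalty": 0.2,
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        result = json.load(resp)

    # For chat requests format_message places the text under choices[0].delta.content;
    # the full text is also mirrored in the top-level 'response' field.
    print(result["choices"][0]["delta"]["content"])
    print(result["usage"])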