Python API restructuring & code style improvements #130

Merged · 15 commits · Sep 20, 2023
67 changes: 23 additions & 44 deletions README.md
@@ -4,13 +4,15 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [ggerganov/ggml](https://github.com/ggerganov/ggml).

Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported.

This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.

[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.
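
Concretely, generation reduces to a recurrence over tokens. A conceptual sketch using the Python wrapper's `eval()` call (`model` and the token IDs are placeholders; the full wrapper setup appears later in this README):

```python
logits, state = None, None  # a state of None means "start from scratch"

# Each step consumes one token plus the previous state; there is no attention
# over the whole history, so per-token cost stays flat as the context grows.
for token in [1, 2, 3]:  # hypothetical token IDs
    logits, state = model.eval(token, state)
```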

Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

### Quality and performance
⚠️ **The Python API was restructured on 2023-09-20**; you may need to change paths/package names in your code when updating `rwkv.cpp`.
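
As a quick illustration of the kind of change required, drawn from this PR's own edits to `chat_with_bot.py` (your exact import list may differ):

```python
# Before the restructure: top-level modules under the rwkv/ directory
import rwkv_cpp_model
import rwkv_cpp_shared_library

# After the restructure: an rwkv_cpp package under python/
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
```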

## Quality and performance

If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

@@ -26,7 +28,7 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
| `FP16` | **15.623** | 117 | 2.82 |
| `FP32` | **15.623** | 198 | 5.64 |

#### With cuBLAS
### With cuBLAS

Measurements were made on an Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token is shown in milliseconds.
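
As a hedged pointer for reproducing such a setup: a cuBLAS-enabled build is typically produced by switching on the project's cuBLAS CMake option. The flag name below is an assumption to verify against the project's CMakeLists.txt:

```commandline
cmake . -DRWKV_CUBLAS=ON
cmake --build . --config Release
```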

@@ -124,80 +126,64 @@ This option would require a little more manual work, but you can use it with any

```commandline
# Windows
python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16
python python\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16

# Linux / MacOS
python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16
python python/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16
```

**Optionally**, quantize the model into one of the quantized formats from the table above:

```commandline
# Windows
python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
python python\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1

# Linux / MacOS
python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
```

### 4. Run the model

**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
#### Using the command line

**Note**: to use the full-weights model, change the model path to the non-quantized file.
**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).

To generate some text, run:

```commandline
# Windows
python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin
python python\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin

# Linux / MacOS
python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
python python/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

To chat with a bot, run:

```commandline
# Windows
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin
python python\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin

# Linux / MacOS
python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
python python/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.

---

Example of using `rwkv.cpp` in your custom Python script:

```python
import rwkv_cpp_model
import rwkv_cpp_shared_library

# Change to the model path used above (quantized or full weights)
model_path = r'C:\rwkv.cpp-169M.bin'

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    model_path,
    thread_count=4,    # Needs adjusting when using cuBLAS
    gpu_layers_count=5 # Only used when cuBLAS is enabled
)

logits, state = None, None

for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

print(f'Output logits: {logits}')

# Don't forget to free the memory after you are done working with the model
model.free()
```

#### Using in your own code

The short and simple script [inference_example.py](python%2Finference_example.py) demonstrates the use of `rwkv.cpp` in Python.

To use `rwkv.cpp` in C/C++, include the header [rwkv.h](rwkv.h); see the sketch after this section.

To use `rwkv.cpp` in any other language, see the [Bindings](#Bindings) section below. If your language is missing, you can try to bind to the C API using the tooling provided by your language.

## Bindings

These projects wrap `rwkv.cpp` for easier use in other languages/frameworks:

* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv)
* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node)
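
Below is a minimal, hedged sketch of calling the C API directly. Apart from `rwkv_get_last_error` and `rwkv_get_system_info_string`, which appear verbatim elsewhere in this diff, every name and signature used here (`rwkv_init_from_file`, `rwkv_eval`, `rwkv_get_state_len`, `rwkv_get_logits_len`, `rwkv_free`) is an assumption to check against the `rwkv.h` shipped with your version:

```c
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

#include <rwkv.h>

int main(void) {
    // Assumed API below; verify every call against your rwkv.h.
    struct rwkv_context * ctx = rwkv_init_from_file("rwkv.cpp-169M-Q5_1.bin", 4 /* threads */);

    if (ctx == NULL) {
        fprintf(stderr, "Error: 0x%.8X\n", rwkv_get_last_error(NULL));
        return EXIT_FAILURE;
    }

    // Buffers for the recurrent state and the output logits.
    float * state  = calloc(rwkv_get_state_len(ctx), sizeof(float));
    float * logits = calloc(rwkv_get_logits_len(ctx), sizeof(float));

    // Feed a few hypothetical token IDs; a NULL state_in means "start fresh".
    const uint32_t tokens[] = {1, 2, 3};

    for (size_t i = 0; i < sizeof(tokens) / sizeof(tokens[0]); i++) {
        rwkv_eval(ctx, tokens[i], i == 0 ? NULL : state, state, logits);
    }

    printf("First logit: %f\n", logits[0]);

    free(state);
    free(logits);
    rwkv_free(ctx);

    return EXIT_SUCCESS;
}
```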

## Compatibility

@@ -214,13 +200,6 @@ For reference only, here is a list of latest versions of `rwkv.cpp` that have su

See also [docs/FILE_FORMAT.md](docs/FILE_FORMAT.md) for version numbers of `rwkv.cpp` model files and their changelog.

## Bindings

These projects wrap `rwkv.cpp` for easier use in other languages/frameworks.

* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv)
* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node)

## Contributing

Please follow the code style described in [docs/CODE_STYLE.md](docs/CODE_STYLE.md).
4 changes: 4 additions & 0 deletions docs/CODE_STYLE.md
@@ -27,6 +27,10 @@ Overall, keep code in similar style as it was before.
- Place braces on the same line as the statement.
- Always add braces to `if`, `for`, `while`, `do` and other similar statements.
- Prefix top-level function and struct names with `rwkv_`.
- Mark all functions that are not part of the public API as `static`.
  - The public API is the set of functions defined in `rwkv.h`.
- Mark all immutable function arguments as `const` (a combined example follows this list).
  - This is not required for local variables.
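
A small, hypothetical example (the function name is invented for illustration) that combines the rules above:

```c
#include <stdbool.h>
#include <stdint.h>

// Not part of the public API defined in rwkv.h, hence static;
// the immutable argument is marked const.
static bool rwkv_is_power_of_two(const uint32_t value) {
    if (value == 0) {
        return false;
    }

    // Local variables do not need to be const.
    uint32_t mask = value - 1;

    return (value & mask) == 0;
}
```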

## Python

6 changes: 2 additions & 4 deletions extras/CMakeLists.txt
@@ -9,7 +9,5 @@ function(rwkv_add_extra source)
    endif()
endfunction()

file(GLOB extras *.c)
foreach (extra ${extras})
    rwkv_add_extra(${extra})
endforeach()
rwkv_add_extra(cpu_info.c)
rwkv_add_extra(quantize.c)
8 changes: 5 additions & 3 deletions extras/cpu_info.c
@@ -1,7 +1,9 @@
#include "rwkv.h"

#include <stdio.h>

int main() {
#include <rwkv.h>

int main(void) {
    printf("%s", rwkv_get_system_info_string());

    return 0;
}
23 changes: 14 additions & 9 deletions extras/quantize.c
@@ -1,11 +1,11 @@
#include "ggml.h"
#include "rwkv.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#ifdef _WIN32
#include <ggml.h>
#include <rwkv.h>

#if defined(_WIN32)
bool QueryPerformanceFrequency(uint64_t* lpFrequency);
bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);

@@ -22,7 +22,7 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
#define TIME_DIFF(freq, start, end) (double) ((end.tv_nsec - start.tv_nsec) / 1000000) / 1000
#endif

enum ggml_type type_from_string(const char* string) {
static enum ggml_type type_from_string(const char * string) {
    if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
    if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
    if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
@@ -31,9 +31,10 @@ enum ggml_type type_from_string(const char* string) {
    return GGML_TYPE_COUNT;
}

int main(int argc, char * argv[]) {
int main(const int argc, const char * argv[]) {
    if (argc != 4 || type_from_string(argv[3]) == GGML_TYPE_COUNT) {
        fprintf(stderr, "Usage: %s INPUT OUTPUT FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);
        fprintf(stderr, "Usage: %s INPUT_FILE OUTPUT_FILE FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);

        return EXIT_FAILURE;
    }

@@ -48,11 +49,15 @@ int main(int argc, char * argv[]) {

    double diff = TIME_DIFF(freq, start, end);

    fprintf(stderr, "Took %.3f s\n", diff);

    if (success) {
        fprintf(stderr, "Succeeded in %.3fs\n", diff);
        fprintf(stderr, "Success\n");

        return EXIT_SUCCESS;
    } else {
        fprintf(stderr, "Error in %.3fs: 0x%.8X\n", diff, rwkv_get_last_error(NULL));
        fprintf(stderr, "Error: 0x%.8X\n", rwkv_get_last_error(NULL));

        return EXIT_FAILURE;
    }
}
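
For reference, a hypothetical invocation of the built tool, matching the usage string above (the binary's name and output location depend on your build setup):

```commandline
./quantize rwkv.cpp-169M.bin rwkv.cpp-169M-Q5_1.bin Q5_1
```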
File renamed without changes.
23 changes: 11 additions & 12 deletions rwkv/chat_with_bot.py → python/chat_with_bot.py
@@ -6,14 +6,13 @@
import argparse
import pathlib
import copy
import json
import time
import torch
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
import json
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from typing import List, Dict, Optional
import time

# ======================================== Script settings ========================================

@@ -98,23 +97,23 @@ def load_thread_state(_thread: str) -> None:

# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
def split_last_end_of_line(tokens):
def split_last_end_of_line(tokens: List[int]) -> List[int]:
    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]

    return tokens
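
# A quick worked illustration of the function above (per the comment: token 535
# encodes '\n\n', token 187 encodes '\n'):
#   split_last_end_of_line([3, 535]) -> [3, 187, 187]
#   split_last_end_of_line([3, 187]) -> [3, 187]  (unchanged)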

# =================================================================================================

processing_start = time.time()
processing_start: float = time.time()

prompt_tokens = tokenizer_encode(init_prompt)
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

process_tokens(split_last_end_of_line(prompt_tokens))

processing_duration = time.time() - processing_start
processing_duration: float = time.time() - processing_start

print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token')

@@ -125,11 +124,11 @@ def split_last_end_of_line(tokens):

while True:
    # Read user input
    user_input = input(f'> {user}{separator} ')
    msg = user_input.replace('\\n', '\n').strip()
    user_input: str = input(f'> {user}{separator} ')
    msg: str = user_input.replace('\\n', '\n').strip()

    temperature = TEMPERATURE
    top_p = TOP_P
    temperature: float = TEMPERATURE
    top_p: float = TOP_P

    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])
@@ -16,7 +16,7 @@ def parse_args():
    return parser.parse_args()

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0
    n_layer: int = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1
@@ -28,9 +28,9 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
    emb_weight: torch.Tensor = state_dict['emb.weight']

    n_layer = get_layer_count(state_dict)
    n_vocab = emb_weight.shape[0]
    n_embed = emb_weight.shape[1]
    n_layer: int = get_layer_count(state_dict)
    n_vocab: int = emb_weight.shape[0]
    n_embed: int = emb_weight.shape[1]

    with open(dest_path, 'wb') as out_file:
        is_FP16: bool = data_type == 'FP16' or data_type == 'float16'
@@ -48,7 +48,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
        ))

        for k in state_dict.keys():
            tensor = state_dict[k].float()
            tensor: torch.Tensor = state_dict[k].float()

            # Same processing as in "RWKV_in_150_lines.py"
            if '.time_' in k:
@@ -5,7 +5,7 @@
from typing import Dict

def test() -> None:
    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
    test_file_path: str = 'convert_pytorch_rwkv_to_ggml_test.tmp'

    try:
        state_dict: Dict[str, torch.Tensor] = {
@@ -15,8 +15,8 @@ def test() -> None:
        }

        convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32')

        with open(test_file_path, 'rb') as input:
            actual_bytes: bytes = input.read()
        with open(test_file_path, 'rb') as test_file:
            actual_bytes: bytes = test_file.read()

        expected_bytes: bytes = struct.pack(
            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',