Python API restructurization & code style improvements (#130)
* Replace tabs with 4 spaces

* Refactor tests

* Rename Python scripts directory to "python"

* Create a separate package for the official Python API

* Move Python inference example to a separate file

* Add missing const

* Refactor extras

* Split rwkv.cpp into smaller files

* Clean up cpp code

* Rename rwkv package to rwkv_cpp

* Add missing type hints

* Rewrite automatic library lookup

* Add compatibility warning

* Fix MacOS build

* Fix MacOS build
saharNooby committed Sep 20, 2023
1 parent 8db73b1 commit 6caa45e
Showing 48 changed files with 1,964 additions and 1,879 deletions.
67 changes: 23 additions & 44 deletions README.md
@@ -4,13 +4,15 @@ This is a port of [BlinkDL/RWKV-LM](https://github.com/BlinkDL/RWKV-LM) to [gger

Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported.

This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](rwkv%2Frwkv_cpp_model.py) for it.
This project provides [a C library rwkv.h](rwkv.h) and [a convenient Python wrapper](python%2Frwkv_cpp%2Frwkv_cpp_model.py) for it.

[RWKV](https://arxiv.org/abs/2305.13048) is a novel large language model architecture, [with the largest model in the family having 14B parameters](https://huggingface.co/BlinkDL/rwkv-4-pile-14b). In contrast to Transformers with `O(n^2)` attention, RWKV requires only the state from the previous step to calculate logits. This makes RWKV very CPU-friendly on large context lengths.

Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

### Quality and performance
⚠️ **The Python API was restructured on 2023-09-20**; you may need to change paths and package names in your code when updating `rwkv.cpp`.
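
For example, scripts that previously imported the wrapper modules directly now import them from the `rwkv_cpp` package under `python/`. Here is a minimal sketch (the `sys.path` line is only an example; point it at your own checkout):

```python
import sys

# The official wrapper now lives in python/rwkv_cpp (example path, adjust to your checkout).
sys.path.append('./python')

# Before this commit:
#   import rwkv_cpp_model
#   import rwkv_cpp_shared_library

# After this commit:
from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library
```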

## Quality and performance

If you use `rwkv.cpp` for anything serious, please [test all available formats for perplexity and latency](rwkv%2Fmeasure_pexplexity.py) on a representative dataset, and decide which trade-off is best for you.

@@ -26,7 +28,7 @@ Below table is for reference only. Measurements were made on 4C/8T x86 CPU with
| `FP16` | **15.623** | 117 | 2.82 |
| `FP32` | **15.623** | 198 | 5.64 |

#### With cuBLAS
### With cuBLAS

Measurements were made on an Intel i7 13700K & NVIDIA 3060 Ti 8 GB. Latency per token is shown in ms.

@@ -124,80 +126,64 @@ This option would require a little more manual work, but you can use it with any

```commandline
# Windows
python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16
python python\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16
# Linux / MacOS
python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16
python python/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16
```

**Optionally**, quantize the model into one of the quantized formats from the table above:

```commandline
# Windows
python rwkv\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
python python\quantize.py C:\rwkv.cpp-169M.bin C:\rwkv.cpp-169M-Q5_1.bin Q5_1
# Linux / MacOS
python rwkv/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
python python/quantize.py ~/Downloads/rwkv.cpp-169M.bin ~/Downloads/rwkv.cpp-169M-Q5_1.bin Q5_1
```

### 4. Run the model

**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).
#### Using the command line

**Note**: to run the full-weights model, change the model path to the non-quantized file.
**Requirements**: Python 3.x with [PyTorch](https://pytorch.org/get-started/locally/) and [tokenizers](https://pypi.org/project/tokenizers/).

To generate some text, run:

```commandline
# Windows
python rwkv\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin
python python\generate_completions.py C:\rwkv.cpp-169M-Q5_1.bin
# Linux / MacOS
python rwkv/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
python python/generate_completions.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

To chat with a bot, run:

```commandline
# Windows
python rwkv\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin
python python\chat_with_bot.py C:\rwkv.cpp-169M-Q5_1.bin
# Linux / MacOS
python rwkv/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
python python/chat_with_bot.py ~/Downloads/rwkv.cpp-169M-Q5_1.bin
```

Edit [generate_completions.py](rwkv%2Fgenerate_completions.py) or [chat_with_bot.py](rwkv%2Fchat_with_bot.py) to change prompts and sampling settings.
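
As a rough, self-contained illustration of what the temperature and top-p settings control (this is not the repository's `sampling.py`, just the general idea):

```python
import numpy as np

def sample_logits(logits: np.ndarray, temperature: float = 0.8, top_p: float = 0.5) -> int:
    # Convert logits to probabilities.
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()

    # Top-p (nucleus) filtering: keep the smallest set of tokens whose
    # cumulative probability reaches top_p, drop everything else.
    if top_p < 1.0:
        sorted_probs = np.sort(probs)[::-1]
        cutoff = sorted_probs[np.argmax(np.cumsum(sorted_probs) >= top_p)]
        probs[probs < cutoff] = 0.0

    # Temperature: values < 1.0 sharpen the distribution, > 1.0 flatten it.
    if temperature != 1.0:
        probs = probs ** (1.0 / temperature)

    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))
```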

---

Example of using `rwkv.cpp` in your custom Python script:

```python
import rwkv_cpp_model
import rwkv_cpp_shared_library
#### Using in your own code

# Change to model paths used above (quantized or full weights)
model_path = r'C:\rwkv.cpp-169M.bin'
The short and simple script [inference_example.py](python%2Finference_example.py) demonstrates the use of `rwkv.cpp` in Python (a minimal usage sketch also follows the Bindings list below).

To use `rwkv.cpp` in C/C++, include the header [rwkv.h](rwkv.h).

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    model_path,
    thread_count=4,      # need to adjust when using cuBLAS
    gpu_layers_count=5   # only enabled when using cuBLAS
)
To use `rwkv.cpp` in any other language, see the [Bindings](#Bindings) section below. If your language is missing, you can try to bind to the C API using the tooling provided by your language.

logits, state = None, None

for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

print(f'Output logits: {logits}')
## Bindings

# Don't forget to free the memory after you're done working with the model
model.free()
These projects wrap `rwkv.cpp` for easier use in other languages/frameworks.

```
* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv)
* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node)
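
For a quick look at what user code looks like after the restructuring, here is a minimal sketch adapted from the removed README example above to the new `rwkv_cpp` package (the model path is an example; see [inference_example.py](python%2Finference_example.py) for the canonical version):

```python
from rwkv_cpp import rwkv_cpp_model, rwkv_cpp_shared_library

# Example path -- use a model converted/quantized as described above.
model_path = r'C:\rwkv.cpp-169M-Q5_1.bin'

model = rwkv_cpp_model.RWKVModel(
    rwkv_cpp_shared_library.load_rwkv_shared_library(),
    model_path,
    thread_count=4,      # adjust when using cuBLAS
    gpu_layers_count=5   # only used with cuBLAS builds
)

logits, state = None, None

# Feed a few example token IDs; eval() returns updated logits and RWKV state.
for token in [1, 2, 3]:
    logits, state = model.eval(token, state)

print(f'Output logits: {logits}')

# Free the native memory once you are done with the model.
model.free()
```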

## Compatibility

@@ -214,13 +200,6 @@ For reference only, here is a list of latest versions of `rwkv.cpp` that have su

See also [docs/FILE_FORMAT.md](docs/FILE_FORMAT.md) for version numbers of `rwkv.cpp` model files and their changelog.
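
As a small illustration, you can peek at a model file's header from Python. This is only a sketch: it assumes the six-int32 header layout suggested by the `'=iiiiii'` struct format used in the conversion code, and an example file name; see docs/FILE_FORMAT.md for the authoritative layout.

```python
import struct

# Read the first six int32 fields of an rwkv.cpp model file (example path).
with open('rwkv.cpp-169M.bin', 'rb') as model_file:
    header = struct.unpack('=iiiiii', model_file.read(6 * 4))

print(header)  # raw header values; their meaning is documented in docs/FILE_FORMAT.md
```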

## Bindings

These projects wrap `rwkv.cpp` for easier use in other languages/frameworks.

* Golang: [seasonjs/rwkv](https://github.com/seasonjs/rwkv)
* Node.js: [Atome-FE/llama-node](https://github.com/Atome-FE/llama-node)

## Contributing

Please follow the code style described in [docs/CODE_STYLE.md](docs/CODE_STYLE.md).
4 changes: 4 additions & 0 deletions docs/CODE_STYLE.md
@@ -27,6 +27,10 @@ Overall, keep code in similar style as it was before.
- Place braces on the same line as the statement.
- Always add braces to `if`, `for`, `while`, `do` and other similar statements.
- Prefix top-level function and struct names with `rwkv_`.
- Mark all functions that are not members of public API as `static`.
- Public API is the set of functions defined in `rwkv.h`.
- Mark all immutable function arguments as `const`.
- This is not required for local variables.

## Python

Expand Down
6 changes: 2 additions & 4 deletions extras/CMakeLists.txt
@@ -9,7 +9,5 @@ function(rwkv_add_extra source)
endif()
endfunction()

file(GLOB extras *.c)
foreach (extra ${extras})
rwkv_add_extra(${extra})
endforeach()
rwkv_add_extra(cpu_info.c)
rwkv_add_extra(quantize.c)
8 changes: 5 additions & 3 deletions extras/cpu_info.c
@@ -1,7 +1,9 @@
#include "rwkv.h"

#include <stdio.h>

int main() {
#include <rwkv.h>

int main(void) {
    printf("%s", rwkv_get_system_info_string());

    return 0;
}
23 changes: 14 additions & 9 deletions extras/quantize.c
@@ -1,11 +1,11 @@
#include "ggml.h"
#include "rwkv.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#ifdef _WIN32
#include <ggml.h>
#include <rwkv.h>

#if defined(_WIN32)
bool QueryPerformanceFrequency(uint64_t* lpFrequency);
bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);

@@ -22,7 +22,7 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
#define TIME_DIFF(freq, start, end) (double) ((end.tv_nsec - start.tv_nsec) / 1000000) / 1000
#endif

enum ggml_type type_from_string(const char* string) {
static enum ggml_type type_from_string(const char * string) {
if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
@@ -31,9 +31,10 @@ enum ggml_type type_from_string(const char* string) {
    return GGML_TYPE_COUNT;
}

int main(int argc, char * argv[]) {
int main(const int argc, const char * argv[]) {
    if (argc != 4 || type_from_string(argv[3]) == GGML_TYPE_COUNT) {
        fprintf(stderr, "Usage: %s INPUT OUTPUT FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);
        fprintf(stderr, "Usage: %s INPUT_FILE OUTPUT_FILE FORMAT\n\nAvailable formats: Q4_0 Q4_1 Q5_0 Q5_1 Q8_0\n", argv[0]);

        return EXIT_FAILURE;
    }

@@ -48,11 +49,15 @@ int main(int argc, char * argv[]) {

    double diff = TIME_DIFF(freq, start, end);

    fprintf(stderr, "Took %.3f s\n", diff);

    if (success) {
        fprintf(stderr, "Succeeded in %.3fs\n", diff);
        fprintf(stderr, "Success\n");

        return EXIT_SUCCESS;
    } else {
        fprintf(stderr, "Error in %.3fs: 0x%.8X\n", diff, rwkv_get_last_error(NULL));
        fprintf(stderr, "Error: 0x%.8X\n", rwkv_get_last_error(NULL));

        return EXIT_FAILURE;
    }
}
File renamed without changes.
23 changes: 11 additions & 12 deletions rwkv/chat_with_bot.py → python/chat_with_bot.py
@@ -6,14 +6,13 @@
import argparse
import pathlib
import copy
import json
import time
import torch
import sampling
import rwkv_cpp_model
import rwkv_cpp_shared_library
from rwkv_tokenizer import get_tokenizer
import json
from rwkv_cpp import rwkv_cpp_shared_library, rwkv_cpp_model
from tokenizer_util import get_tokenizer
from typing import List, Dict, Optional
import time

# ======================================== Script settings ========================================

@@ -98,23 +97,23 @@ def load_thread_state(_thread: str) -> None:

# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end.
# See https://github.com/BlinkDL/ChatRWKV/pull/110/files
def split_last_end_of_line(tokens):
def split_last_end_of_line(tokens: List[int]) -> List[int]:
    if len(tokens) > 0 and tokens[-1] == DOUBLE_END_OF_LINE_TOKEN:
        tokens = tokens[:-1] + [END_OF_LINE_TOKEN, END_OF_LINE_TOKEN]

    return tokens

# =================================================================================================

processing_start = time.time()
processing_start: float = time.time()

prompt_tokens = tokenizer_encode(init_prompt)
prompt_token_count = len(prompt_tokens)
print(f'Processing {prompt_token_count} prompt tokens, may take a while')

process_tokens(split_last_end_of_line(prompt_tokens))

processing_duration = time.time() - processing_start
processing_duration: float = time.time() - processing_start

print(f'Processed in {int(processing_duration)} s, {int(processing_duration / prompt_token_count * 1000)} ms per token')

@@ -125,11 +124,11 @@ def split_last_end_of_line(tokens):

while True:
    # Read user input
    user_input = input(f'> {user}{separator} ')
    msg = user_input.replace('\\n', '\n').strip()
    user_input: str = input(f'> {user}{separator} ')
    msg: str = user_input.replace('\\n', '\n').strip()

    temperature = TEMPERATURE
    top_p = TOP_P
    temperature: float = TEMPERATURE
    top_p: float = TOP_P

    if '-temp=' in msg:
        temperature = float(msg.split('-temp=')[1].split(' ')[0])
@@ -16,7 +16,7 @@ def parse_args():
    return parser.parse_args()

def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
    n_layer = 0
    n_layer: int = 0

    while f'blocks.{n_layer}.ln1.weight' in state_dict:
        n_layer += 1
@@ -28,9 +28,9 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int:
def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None:
    emb_weight: torch.Tensor = state_dict['emb.weight']

    n_layer = get_layer_count(state_dict)
    n_vocab = emb_weight.shape[0]
    n_embed = emb_weight.shape[1]
    n_layer: int = get_layer_count(state_dict)
    n_vocab: int = emb_weight.shape[0]
    n_embed: int = emb_weight.shape[1]

    with open(dest_path, 'wb') as out_file:
        is_FP16: bool = data_type == 'FP16' or data_type == 'float16'
@@ -48,7 +48,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
        ))

        for k in state_dict.keys():
            tensor = state_dict[k].float()
            tensor: torch.Tensor = state_dict[k].float()

            # Same processing as in "RWKV_in_150_lines.py"
            if '.time_' in k:
@@ -5,7 +5,7 @@
from typing import Dict

def test() -> None:
    test_file_path = 'convert_pytorch_rwkv_to_ggml_test.tmp'
    test_file_path: str = 'convert_pytorch_rwkv_to_ggml_test.tmp'

    try:
        state_dict: Dict[str, torch.Tensor] = {
@@ -15,8 +15,8 @@ def test() -> None:

        convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32')

        with open(test_file_path, 'rb') as input:
            actual_bytes: bytes = input.read()
        with open(test_file_path, 'rb') as test_file:
            actual_bytes: bytes = test_file.read()

        expected_bytes: bytes = struct.pack(
            '=iiiiii' + 'iiiii10sffffff' + 'iiii19sf',
