Add ability to tokenize a string and return the decoded tokens using the correct BPE model #17

Merged
merged 3 commits into from Apr 16, 2023
2 changes: 1 addition & 1 deletion tiktoken-rs/src/api.rs
@@ -46,7 +46,7 @@ pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
Ok(context_size.saturating_sub(prompt_tokens))
}

#[derive(Debug, Default, Clone, PartialEq)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
/// The role of the author of this message.
pub role: String,
63 changes: 63 additions & 0 deletions tiktoken-rs/src/vendor_tiktoken.rs
@@ -235,6 +235,20 @@ impl CoreBPE {
ret
}

#[allow(clippy::needless_lifetimes)] // the iterator captures a lifetime outside of the function
fn _decode_native_and_split<'a>(
&'a self,
tokens: Vec<usize>,
) -> impl Iterator<Item = Vec<u8>> + '_ {
tokens.into_iter().map(move |token| {
let token_bytes = self
.decoder
.get(&token)
.unwrap_or_else(|| &self.special_tokens_decoder[&token]);
token_bytes.clone()
})
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
@@ -541,6 +555,55 @@ impl CoreBPE {
Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
}
}

    /// Tokenize a string and return the decoded tokens using the correct BPE model.
    ///
    /// This method encodes the input with the BPE model, then decodes each token id
    /// individually, yielding one string per token rather than a single concatenated string.
///
/// # Examples
///
/// ```
/// use tiktoken_rs::cl100k_base;
/// let bpe = cl100k_base().unwrap();
/// let tokenized: Result<Vec<_>, _> = bpe
/// .split_by_token_with_special_tokens("This is a test with a lot of spaces")
/// .collect();
/// let tokenized = tokenized.unwrap();
/// assert_eq!(
/// tokenized,
/// vec!["This", " is", " a", " test", " ", " with", " a", " lot", " of", " spaces"]
/// );
/// ```
///
/// # Arguments
///
/// * text: A string slice containing the text to be tokenized.
///
/// # Returns
///
    /// * An iterator of `Result<String>`: each item is a decoded token as a string, or an error
    /// if the token's bytes cannot be converted into a valid UTF-8 string.
///
/// # Errors
///
/// This function will return an error if:
///
/// * The input text cannot be converted into a valid UTF-8 string during the decoding process.
///
pub fn split_by_token_with_special_tokens<'a>(
&'a self,
text: &'a str,
) -> impl Iterator<Item = Result<String>> + 'a {
// First, encode the text using the BPE model
let encoded = self.encode_with_special_tokens(text);

        self._decode_native_and_split(encoded).map(|token| {
            // Map each decoded token's bytes to a Result<String>
            String::from_utf8(token).map_err(|e| anyhow!(e.to_string()))
        })
}
}

#[cfg(feature = "python")]
13 changes: 13 additions & 0 deletions tiktoken-rs/tests/tiktoken.rs
@@ -82,6 +82,19 @@ fn cl100k_base_test() {
);
}

#[test]
fn cl100k_split_test() {
let bpe = cl100k_base().unwrap();
let tokenized: Result<Vec<_>, _> = bpe
.split_by_token_with_special_tokens("This is a test with a lot of spaces")
.collect();
let tokenized = tokenized.unwrap();
assert_eq!(
tokenized,
vec!["This", " is", " a", " test", " ", " with", " a", " lot", " of", " spaces"]
);
}

#[test]
fn p50k_base_singleton_test() {
// let now = std::time::Instant::now();