diff --git a/tiktoken-rs/src/api.rs b/tiktoken-rs/src/api.rs
index 3544927..08aee24 100644
--- a/tiktoken-rs/src/api.rs
+++ b/tiktoken-rs/src/api.rs
@@ -46,7 +46,7 @@ pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
     Ok(context_size.saturating_sub(prompt_tokens))
 }
 
-#[derive(Debug, Default, Clone, PartialEq)]
+#[derive(Debug, Default, Clone, PartialEq, Eq)]
 pub struct ChatCompletionRequestMessage {
     /// The role of the author of this message.
     pub role: String,
diff --git a/tiktoken-rs/src/vendor_tiktoken.rs b/tiktoken-rs/src/vendor_tiktoken.rs
index d94b938..1da339e 100644
--- a/tiktoken-rs/src/vendor_tiktoken.rs
+++ b/tiktoken-rs/src/vendor_tiktoken.rs
@@ -235,6 +235,20 @@ impl CoreBPE {
         ret
     }
 
+    #[allow(clippy::needless_lifetimes)] // the iterator captures a lifetime outside of the function
+    fn _decode_native_and_split<'a>(
+        &'a self,
+        tokens: Vec<usize>,
+    ) -> impl Iterator<Item = Vec<u8>> + '_ {
+        tokens.into_iter().map(move |token| {
+            let token_bytes = self
+                .decoder
+                .get(&token)
+                .unwrap_or_else(|| &self.special_tokens_decoder[&token]);
+            token_bytes.clone()
+        })
+    }
+
     fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
         // This is the core of the encoding logic; the other functions in here
         // just make things complicated :-)
@@ -541,6 +555,55 @@ impl CoreBPE {
             Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
         }
     }
+
+    /// Tokenize a string and return the decoded tokens using the correct BPE model.
+    ///
+    /// This method encodes the given text with the BPE model, including special tokens, and
+    /// decodes each resulting token back into a string, yielding exactly the chunks that the
+    /// tokenizer splits the text into.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use tiktoken_rs::cl100k_base;
+    /// let bpe = cl100k_base().unwrap();
+    /// let tokenized: Result<Vec<_>, _> = bpe
+    ///     .split_by_token_with_special_tokens("This is a test         with a lot of spaces")
+    ///     .collect();
+    /// let tokenized = tokenized.unwrap();
+    /// assert_eq!(
+    ///     tokenized,
+    ///     vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
+    /// );
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * text: A string slice containing the text to be tokenized.
+    ///
+    /// # Returns
+    ///
+    /// * impl Iterator<Item = Result<String>>: An iterator over the decoded tokens as strings,
+    ///   where an item is an error if the token's bytes are not valid UTF-8 on their own.
+    ///
+    /// # Errors
+    ///
+    /// An item of the returned iterator will be an error if:
+    ///
+    /// * The token's bytes cannot be converted into a valid UTF-8 string during the decoding process.
+    ///
+    pub fn split_by_token_with_special_tokens<'a>(
+        &'a self,
+        text: &'a str,
+    ) -> impl Iterator<Item = Result<String>> + 'a {
+        // First, encode the text using the BPE model
+        let encoded = self.encode_with_special_tokens(text);
+
+        self._decode_native_and_split(encoded).map(|token|
+            // Map each token to a Result<String>
+            String::from_utf8(token)
+                .map_err(|e| anyhow!(e.to_string())))
+    }
 }
 
 #[cfg(feature = "python")]
diff --git a/tiktoken-rs/tests/tiktoken.rs b/tiktoken-rs/tests/tiktoken.rs
index 184920f..bfe633d 100644
--- a/tiktoken-rs/tests/tiktoken.rs
+++ b/tiktoken-rs/tests/tiktoken.rs
@@ -82,6 +82,19 @@ fn cl100k_base_test() {
     );
 }
 
+#[test]
+fn cl100k_split_test() {
+    let bpe = cl100k_base().unwrap();
+    let tokenized: Result<Vec<_>, _> = bpe
+        .split_by_token_with_special_tokens("This is a test         with a lot of spaces")
+        .collect();
+    let tokenized = tokenized.unwrap();
+    assert_eq!(
+        tokenized,
+        vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
+    );
+}
+
 #[test]
 fn p50k_base_singleton_test() {
     // let now = std::time::Instant::now();
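Reviewer note: the new iterator surfaces per-token UTF-8 failures as Err items rather than aborting the whole split, which matters when a single multi-byte character is encoded across several tokens. A minimal consumption sketch (the input string and the anyhow-based main signature are illustrative assumptions, not part of this diff):

use tiktoken_rs::cl100k_base;

fn main() -> anyhow::Result<()> {
    let bpe = cl100k_base()?;
    // Each item is an anyhow::Result<String>. Depending on the vocabulary,
    // an emoji may be encoded as multiple tokens whose individual byte
    // slices are not valid UTF-8; such tokens come back as Err items.
    for piece in bpe.split_by_token_with_special_tokens("🤖 spans multiple tokens") {
        match piece {
            Ok(s) => println!("token: {:?}", s),
            Err(e) => println!("partial UTF-8 sequence: {}", e),
        }
    }
    Ok(())
}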