diff --git a/flake.lock b/flake.lock index 1706385a..69ce6cd5 100644 --- a/flake.lock +++ b/flake.lock @@ -853,11 +853,11 @@ ] }, "locked": { - "lastModified": 1727836133, - "narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=", + "lastModified": 1729045942, + "narHash": "sha256-HjmK0x5Zm2TK2vFpC7XBM2e3EDNVnAIuEoU2FkeN8xw=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "02321540b0c8000b36889b1b974d1fec585b25a4", + "rev": "9de3cea452d2401d6f93c06ad985178a4e11d1fc", "type": "github" }, "original": { diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py index 85ed8fd1..baa19643 100644 --- a/integration-tests/models/test_mamba.py +++ b/integration-tests/models/test_mamba.py @@ -3,7 +3,7 @@ import pytest @pytest.fixture(scope="module") def fused_kernel_mamba_handle(launcher): - with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + with launcher("state-spaces/mamba-130m-hf", num_shard=1) as handle: yield handle diff --git a/router/src/config.rs b/router/src/config.rs index 7139b923..ce066ad0 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -145,6 +145,7 @@ pub enum Config { LlavaNext(LlavaNext), ClipVisionModel(ClipVisionModel), Mistral, + Mamba, Idefics, Mllama, Idefics2(Idefics2), diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 896f4f43..557e03cb 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -135,7 +135,7 @@ impl Infer { pub(crate) async fn tokenize( &self, request: GenerateRequest, - ) -> Result, InferError> { + ) -> Result { // Tokenize request let inputs = request.inputs; let add_special_tokens = request.add_special_tokens; @@ -150,7 +150,7 @@ impl Infer { })?; // Return Encoding - Ok(encoding.map(|(encoding, _)| encoding)) + Ok(encoding.0) } /// Apply the chat template to the chat request diff --git a/router/src/lib.rs b/router/src/lib.rs index 7c40c7e3..a5613f89 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -14,11 +14,92 @@ mod vertex; use crate::infer::{Infer, InferError}; use crate::server::prepare_chat_input; +use pyo3::prelude::*; +use pyo3::types::IntoPyDict; use serde::{Deserialize, Serialize}; +use tokenizers::Encoding; use tracing::warn; use utoipa::ToSchema; use validation::Validation; +#[derive(Clone)] +pub enum Tokenizer { + Python { + tokenizer_name: String, + revision: Option, + }, + Rust(tokenizers::Tokenizer), +} + +pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>); + +impl<'a> PyTokenizer<'a> { + fn from_py( + py: Python<'a>, + tokenizer_name: String, + revision: Option, + ) -> PyResult> { + let transformers = py.import_bound("transformers")?; + let auto = transformers.getattr("AutoTokenizer")?; + let from_pretrained = auto.getattr("from_pretrained")?; + let args = (tokenizer_name,); + let kwargs = if let Some(rev) = &revision { + [("revision", rev.to_string())].into_py_dict_bound(py) + } else { + pyo3::types::PyDict::new_bound(py) + }; + let tokenizer = from_pretrained.call(args, Some(&kwargs))?; + tracing::info!("Loaded a python tokenizer"); + Ok(PyTokenizer(tokenizer)) + } +} + +trait TokenizerTrait { + fn encode_trait( + &self, + query: String, + add_special_tokens: bool, + ) -> Result>; +} + +impl TokenizerTrait for tokenizers::Tokenizer { + fn encode_trait( + &self, + query: String, + add_special_tokens: bool, + ) -> Result> { + self.encode(query, add_special_tokens) + } +} + +impl<'a> TokenizerTrait for PyTokenizer<'a> { + fn encode_trait( + &self, + query: String, + add_special_tokens: bool, + ) -> Result> { + let py = self.0.py(); + let kwargs = [ + ("text", query.into_py(py)), + ("add_special_tokens", add_special_tokens.into_py(py)), + ] + .into_py_dict_bound(py); + let encode = self.0.getattr("encode")?; + let input_ids: Vec = encode.call((), Some(&kwargs))?.extract()?; + Ok(Encoding::new( + input_ids, + vec![], // type ids + vec![], // tokens (strings) + vec![], // words + vec![], // offsets + vec![], // special_tokens_mask + vec![], // attention_mask + vec![], // overflowing + std::collections::HashMap::new(), //sequence_ranges + )) + } +} + /// Hub type #[derive(Clone, Debug, Deserialize)] pub struct HubModelInfo { @@ -1341,13 +1422,12 @@ impl Default for ModelsInfo { mod tests { use super::*; use serde_json::json; - use tokenizers::Tokenizer; - pub(crate) async fn get_tokenizer() -> Tokenizer { + pub(crate) fn get_tokenizer() -> Tokenizer { let api = hf_hub::api::sync::Api::new().unwrap(); let repo = api.model("gpt2".to_string()); let filename = repo.get("tokenizer.json").unwrap(); - Tokenizer::from_file(filename).unwrap() + Tokenizer::Rust(tokenizers::Tokenizer::from_file(filename).unwrap()) } #[test] diff --git a/router/src/server.rs b/router/src/server.rs index eb1d2544..863607b1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -19,7 +19,8 @@ use crate::{ GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse, - TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation, + TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage, + Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -45,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; use hf_hub::{Cache, Repo, RepoType}; use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; +use pyo3::prelude::*; use pyo3::types::IntoPyDict; use regex::Regex; use serde_json::Value; @@ -54,7 +56,6 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use thiserror::Error; -use tokenizers::Tokenizer; use tokio::select; use tokio::signal; use tokio::sync::oneshot; @@ -64,6 +65,41 @@ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec { + let offsets = encoding.get_offsets(); + let input_ids = encoding.get_ids(); + if offsets.len() == input_ids.len() { + input_ids + .iter() + .zip(offsets) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect() + } else { + encoding + .get_ids() + .iter() + .map(|&id| SimpleToken { + id, + text: "".to_string(), + start: 0, + stop: 0, + }) + .collect() + } +} + /// Generate tokens if `stream == false` or a stream of token if `stream == true` #[utoipa::path( post, @@ -161,40 +197,14 @@ async fn get_chat_tokenize( let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0; let input = generate_request.inputs.clone(); let encoding = infer.tokenize(generate_request).await?; - if let Some(encoding) = encoding { - let tokens: Vec = encoding - .get_ids() - .iter() - .zip(encoding.get_offsets()) - .map(|(&id, &(start, stop))| { - let text = input - .chars() - .skip(start) - .take(stop - start) - .collect::(); - SimpleToken { - id, - text, - start, - stop, - } - }) - .collect(); - let resp = ChatTokenizeResponse { - tokenize_response: TokenizeResponse(tokens), - templated_text: input, - }; - Ok((HeaderMap::new(), Json(resp))) - } else { - Err(( - StatusCode::NOT_FOUND, - Json(ErrorResponse { - error: "No fast tokenizer or tokenizer.json for this model".to_string(), - error_type: "no fast tokenizer".to_string(), - }), - )) - } + let tokens = encoding_to_tokens(&encoding, &input); + + let resp = ChatTokenizeResponse { + tokenize_response: TokenizeResponse(tokens), + templated_text: input, + }; + Ok((HeaderMap::new(), Json(resp))) } #[utoipa::path( @@ -1458,35 +1468,8 @@ async fn tokenize( ) -> Result, (StatusCode, Json)> { let input = req.inputs.clone(); let encoding = infer.tokenize(req).await?; - if let Some(encoding) = encoding { - let tokens: Vec = encoding - .get_ids() - .iter() - .zip(encoding.get_offsets()) - .map(|(&id, &(start, stop))| { - let text = input - .chars() - .skip(start) - .take(stop - start) - .collect::(); - SimpleToken { - id, - text, - start, - stop, - } - }) - .collect(); - Ok(Json(TokenizeResponse(tokens))) - } else { - Err(( - StatusCode::NOT_FOUND, - Json(ErrorResponse { - error: "No fast tokenizer or tokenizer.json for this model".to_string(), - error_type: "no fast tokenizer".to_string(), - }), - )) - } + let tokens = encoding_to_tokens(&encoding, &input); + Ok(Json(TokenizeResponse(tokens))) } /// Prometheus metrics scrape endpoint @@ -1594,6 +1577,71 @@ pub fn schema() -> ApiDoc { ApiDoc } +fn py_resolve_tokenizer( + py: pyo3::Python, + tokenizer_name: &str, + revision: Option<&str>, + trust_remote_code: bool, +) -> pyo3::PyResult<()> { + let transformers = py.import_bound("transformers")?; + let auto = transformers.getattr("AutoTokenizer")?; + let from_pretrained = auto.getattr("from_pretrained")?; + let args = (tokenizer_name,); + let kwargs = if let Some(rev) = &revision { + [ + ("revision", rev.to_string().into_py(py)), + ("trust_remote_code", trust_remote_code.into_py(py)), + ] + .into_py_dict_bound(py) + } else { + [("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py) + }; + let tokenizer = from_pretrained.call(args, Some(&kwargs))?; + let save = tokenizer.getattr("save_pretrained")?; + let args = ("out".to_string(),); + save.call1(args)?; + Ok(()) +} + +fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> { + // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3 + // and state-spaces/mamba-130m + tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization"); + + #[derive(serde::Deserialize)] + struct FallbackConfig { + base_model_name_or_path: Option, + model_type: Option, + ssm_config: Option, + } + config_filename.and_then(|filename| { + std::fs::read_to_string(filename) + .ok() + .as_ref() + .and_then(|c| { + let config: Result = serde_json::from_str(c); + if let Ok(config) = config { + if config.model_type.is_none() { + if let Some(base) = config.base_model_name_or_path { + pyo3::Python::with_gil(|py| -> PyResult<()> { + py_resolve_tokenizer(py, &base, Some("main"), false) + }) + .ok()?; + } + } + if config.ssm_config.is_some() { + // XXX Legacy mamba + pyo3::Python::with_gil(|py| -> PyResult<()> { + py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false) + }) + .ok()?; + } + } + Some(()) + }) + }) +} + /// Serving method #[allow(clippy::too_many_arguments)] pub async fn run( @@ -1687,7 +1735,6 @@ pub async fn run( // Load tokenizer and model info let ( - tokenizer_filename, config_filename, tokenizer_config_filename, preprocessor_config_filename, @@ -1695,7 +1742,6 @@ pub async fn run( model_info, ) = match api { Type::None => ( - Some(local_path.join("tokenizer.json")), Some(local_path.join("config.json")), Some(local_path.join("tokenizer_config.json")), Some(local_path.join("preprocessor_config.json")), @@ -1709,10 +1755,6 @@ pub async fn run( revision.clone().unwrap_or_else(|| "main".to_string()), )); - let tokenizer_filename = match api_repo.get("tokenizer.json").await { - Ok(tokenizer_filename) => Some(tokenizer_filename), - Err(_) => get_base_tokenizer(&api, &api_repo).await, - }; let config_filename = api_repo.get("config.json").await.ok(); let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok(); let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok(); @@ -1725,7 +1767,6 @@ pub async fn run( None }; ( - tokenizer_filename, config_filename, tokenizer_config_filename, preprocessor_config_filename, @@ -1740,7 +1781,6 @@ pub async fn run( revision.clone().unwrap_or_else(|| "main".to_string()), )); ( - repo.get("tokenizer.json"), repo.get("config.json"), repo.get("tokenizer_config.json"), repo.get("preprocessor_config.json"), @@ -1762,39 +1802,30 @@ pub async fn run( HubTokenizerConfig::default() }); - let tokenizer: Option = tokenizer_filename.and_then(|filename| { + let tokenizer: Tokenizer = { use pyo3::prelude::*; - let convert = pyo3::Python::with_gil(|py| -> PyResult<()> { - let transformers = py.import_bound("transformers")?; - let auto = transformers.getattr("AutoTokenizer")?; - let from_pretrained = auto.getattr("from_pretrained")?; - let args = (tokenizer_name.to_string(),); - let kwargs = [ - ( - "revision", - (revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py), - ), - ("trust_remote_code", trust_remote_code.into_py(py)), - ] - .into_py_dict_bound(py); - let tokenizer = from_pretrained.call(args, Some(&kwargs))?; - let save = tokenizer.getattr("save_pretrained")?; - let args = ("out".to_string(),); - save.call1(args)?; + pyo3::Python::with_gil(|py| -> PyResult<()> { + py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?; Ok(()) }) .inspect_err(|err| { tracing::error!("Failed to import python tokenizer {err}"); - }); - let filename = if convert.is_ok() { - // If we have correctly loaded and resaved with transformers - // We might have modified the tokenizer.json according to transformers - "out/tokenizer.json".into() + }) + .or_else(|err| { + let out = legacy_tokenizer_handle(config_filename.as_ref()); + out.ok_or(err) + }) + .expect("We cannot load a tokenizer"); + let filename = "out/tokenizer.json"; + if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) { + Tokenizer::Rust(tok) } else { - filename - }; - Tokenizer::from_file(filename).ok() - }); + Tokenizer::Python { + tokenizer_name: tokenizer_name.clone(), + revision: revision.clone(), + } + } + }; let config: Option = config_filename.and_then(|filename| { std::fs::read_to_string(filename) @@ -1822,10 +1853,6 @@ pub async fn run( preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file); tracing::info!("Using config {config:?}"); - if tokenizer.is_none() { - tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}"); - tracing::warn!("Rust input length validation and truncation is disabled"); - } // Only send usage stats when TGI is run in container and the function returns Some let is_container = matches!(usage_stats::is_container(), Ok(true)); @@ -1940,7 +1967,7 @@ async fn start( validation_workers: usize, api_key: Option, config: Option, - (tokenizer, tokenizer_config): (Option, HubTokenizerConfig), + (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig), (preprocessor_config, processor_config): (Option, HubProcessorConfig), hostname: String, port: u16, @@ -2400,30 +2427,6 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option { } } -/// get base tokenizer -pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option { - let config_filename = api_repo.get("config.json").await.ok()?; - - // Open the file in read-only mode with buffer. - let file = File::open(config_filename).ok()?; - let reader = BufReader::new(file); - - // Read the JSON contents of the file as an instance of `User`. - let config: serde_json::Value = serde_json::from_reader(reader).ok()?; - - if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") { - let api_base_repo = api.repo(Repo::with_revision( - base_model_id.to_string(), - RepoType::Model, - "main".to_string(), - )); - - api_base_repo.get("tokenizer.json").await.ok() - } else { - None - } -} - /// get tokenizer_config from the Huggingface Hub pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option { let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?; @@ -2566,10 +2569,11 @@ mod tests { use crate::TokenizerConfigToken; use crate::Tool; + use crate::tests::get_tokenizer; use serde_json::json; - #[test] - fn test_prepare_chat_input() { + #[tokio::test] + async fn test_prepare_chat_input() { // Mock Backend to avoid network requests struct MockBackend; @@ -2610,9 +2614,11 @@ mod tests { ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string()) ); + let tokenizer = get_tokenizer(); + let infer = Infer::new( backend, - Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false), + Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false), 1, tokenizer_config, HubProcessorConfig::default(), diff --git a/router/src/validation.rs b/router/src/validation.rs index 85b4220b..8159ede4 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -3,7 +3,9 @@ use crate::config::Config; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::{ GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor, + TokenizerTrait, }; +use crate::{PyTokenizer, Tokenizer}; use base64::{engine::general_purpose::STANDARD, Engine}; use image::{ImageFormat, ImageReader}; use jsonschema::{Draft, JSONSchema}; @@ -13,7 +15,6 @@ use std::io::Cursor; use std::iter; use std::sync::Arc; use thiserror::Error; -use tokenizers::tokenizer::Tokenizer; use tokio::sync::mpsc; use tokio::sync::oneshot; use tracing::{instrument, Span}; @@ -30,14 +31,14 @@ pub struct Validation { max_total_tokens: usize, disable_grammar_support: bool, /// Channel to communicate with the background tokenization task - sender: Option>, + sender: mpsc::UnboundedSender, } impl Validation { #[allow(clippy::too_many_arguments)] pub(crate) fn new( workers: usize, - tokenizer: Option, + tokenizer: Tokenizer, config: Option, preprocessor_config: Option, max_best_of: usize, @@ -47,8 +48,13 @@ impl Validation { max_total_tokens: usize, disable_grammar_support: bool, ) -> Self { + let workers = if let Tokenizer::Python { .. } = &tokenizer { + 1 + } else { + workers + }; // If we have a fast tokenizer - let sender = if let Some(tokenizer) = tokenizer { + let sender = { // Create round robin channel let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel(); let mut senders = Vec::with_capacity(workers); @@ -75,9 +81,7 @@ impl Validation { // Create tokenization round robin task tokio::spawn(round_robin_task(validation_round_robin_receiver, senders)); - Some(validation_sender) - } else { - None + validation_sender }; Self { @@ -97,28 +101,25 @@ impl Validation { inputs: String, add_special_tokens: bool, truncate: Option, - ) -> Result)>, ValidationError> { + ) -> Result<(tokenizers::Encoding, Vec), ValidationError> { // If we have a fast tokenizer - if let Some(sender) = &self.sender { - // Create response channel - let (response_sender, response_receiver) = oneshot::channel(); - // Send request to the background validation task - // Unwrap is safe here - sender - .send(( - (inputs, add_special_tokens, truncate), - response_sender, - Span::current(), - )) - .unwrap(); + // Create response channel + let (response_sender, response_receiver) = oneshot::channel(); + // Send request to the background validation task + // Unwrap is safe here + let _ = &self + .sender + .send(( + (inputs, add_special_tokens, truncate), + response_sender, + Span::current(), + )) + .unwrap(); - // Await on response channel - // Unwrap is safe here - let encoding = response_receiver.await.unwrap()?; - Ok(Some(encoding)) - } else { - Ok(None) - } + // Await on response channel + // Unwrap is safe here + let encoding = response_receiver.await.unwrap()?; + Ok(encoding) } #[allow(clippy::type_complexity)] @@ -131,76 +132,46 @@ impl Validation { max_new_tokens: Option, ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer - if let Some((encoding, inputs)) = self + let (encoding, inputs) = self .tokenize(inputs.clone(), add_special_tokens, truncate) - .await? - { - // Create response channel - let input_length = if let Some(truncate) = truncate { - std::cmp::min(encoding.len(), truncate) - } else { - encoding.len() - }; + .await?; + // Create response channel + let input_length = if let Some(truncate) = truncate { + std::cmp::min(encoding.len(), truncate) + } else { + encoding.len() + }; - // Get total tokens - let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { - max_new_tokens - } else { - self.max_total_tokens.saturating_sub(input_length) as u32 - }; - let total_tokens = input_length + max_new_tokens as usize; + // Get total tokens + let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { + max_new_tokens + } else { + self.max_total_tokens.saturating_sub(input_length) as u32 + }; + let total_tokens = input_length + max_new_tokens as usize; - // Validate MaxTotalTokens - if total_tokens > self.max_total_tokens { - return Err(ValidationError::MaxTotalTokens( - self.max_total_tokens, - input_length, - max_new_tokens, - )); - } - - // Validate InputLength - if input_length > self.max_input_length { - return Err(ValidationError::InputLength( - self.max_input_length, - input_length, - )); - } - - let ids = encoding.get_ids(); - let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned(); - - metrics::histogram!("tgi_request_input_length").record(input_length as f64); - Ok((inputs, Some(input_ids), input_length, max_new_tokens)) - } - // Return inputs without validation - else { - // In this case, we don't know the real length in tokens of the inputs - // However, the inputs will be truncated by the python servers - // We make sure that truncate + max_new_tokens <= self.max_total_tokens - let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens { - max_new_tokens - } else if let Some(truncate) = truncate { - self.max_total_tokens.saturating_sub(truncate) as u32 - } else { - return Err(ValidationError::UnsetMaxNewTokens); - }; - let mut input_length = truncate.unwrap_or(self.max_input_length); - - // We don't have a tokenizer, therefore we have no idea how long is the query, let - // them through and hope for the best. - // Validate MaxNewTokens - if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 { - input_length = input_length.saturating_sub(max_new_tokens as usize); - } - - Ok(( - vec![Chunk::Text(inputs)], - None, + // Validate MaxTotalTokens + if total_tokens > self.max_total_tokens { + return Err(ValidationError::MaxTotalTokens( + self.max_total_tokens, input_length, max_new_tokens, - )) + )); } + + // Validate InputLength + if input_length > self.max_input_length { + return Err(ValidationError::InputLength( + self.max_input_length, + input_length, + )); + } + + let ids = encoding.get_ids(); + let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned(); + + metrics::histogram!("tgi_request_input_length").record(input_length as f64); + Ok((inputs, Some(input_ids), input_length, max_new_tokens)) } /// Validate a payload and get the number of tokens in the input @@ -464,22 +435,52 @@ fn tokenizer_worker( preprocessor_config: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Loop over requests - while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = - receiver.blocking_recv() - { - parent_span.in_scope(|| { - response_tx - .send(prepare_input( - inputs, - truncate, - add_special_tokens, - &tokenizer, - config.as_ref(), - preprocessor_config.as_ref(), - )) - .unwrap_or(()) - }) + match tokenizer { + Tokenizer::Python { + tokenizer_name, + revision, + } => { + pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> { + let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?; + // Loop over requests + while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = + receiver.blocking_recv() + { + parent_span.in_scope(|| { + response_tx + .send(prepare_input( + inputs, + truncate, + add_special_tokens, + &tokenizer, + config.as_ref(), + preprocessor_config.as_ref(), + )) + .unwrap_or(()) + }) + } + Ok(()) + }) + .expect("Failure in python tokenizer worker"); + } + Tokenizer::Rust(tokenizer) => { + while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = + receiver.blocking_recv() + { + parent_span.in_scope(|| { + response_tx + .send(prepare_input( + inputs, + truncate, + add_special_tokens, + &tokenizer, + config.as_ref(), + preprocessor_config.as_ref(), + )) + .unwrap_or(()) + }) + } + } } } @@ -608,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String { } /// Get input length and optionally truncate it -fn prepare_input( +fn prepare_input( inputs: String, _truncate: Option, add_special_tokens: bool, - tokenizer: &Tokenizer, + tokenizer: &T, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, ) -> Result<(tokenizers::Encoding, Vec), ValidationError> { @@ -649,7 +650,7 @@ fn prepare_input( // Get the number of tokens in the input let encoding = tokenizer - .encode(tokenizer_query, add_special_tokens) + .encode_trait(tokenizer_query, add_special_tokens) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; Ok((encoding, input_chunks)) @@ -824,7 +825,7 @@ mod tests { #[tokio::test] async fn test_validation_max_new_tokens() { - let tokenizer = None; + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; let max_top_n_tokens = 4; @@ -851,15 +852,15 @@ mod tests { .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) .await { - // Err(ValidationError::MaxNewTokens(1, 10)) => (), - Ok((_s, _, 0, 10)) => (), + Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), + // Ok((_s, _, 0, 10)) => (), r => panic!("Unexpected not max new tokens: {r:?}"), } } #[tokio::test] async fn test_validation_input_length() { - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; let max_top_n_tokens = 4; @@ -893,7 +894,7 @@ mod tests { #[tokio::test] async fn test_validation_best_of_sampling() { - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; let max_top_n_tokens = 4; @@ -933,7 +934,7 @@ mod tests { #[tokio::test] async fn test_validation_top_p() { - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; let max_top_n_tokens = 4; @@ -1004,7 +1005,7 @@ mod tests { #[tokio::test] async fn test_validation_top_n_tokens() { - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequences = 3; let max_top_n_tokens = 4; @@ -1089,7 +1090,7 @@ mod tests { async fn test_prepare_input_chunks() { let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap(); - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; @@ -1124,7 +1125,7 @@ mod tests { ) .await { - Ok(Some((_encoding, chunks))) => chunks, + Ok((_encoding, chunks)) => chunks, _ => panic!("Unexpected tokenization failure"), }; @@ -1146,7 +1147,7 @@ mod tests { async fn test_idefics2_correct_n_fake_tokens() { let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap(); - let tokenizer = Some(get_tokenizer().await); + let tokenizer = get_tokenizer(); let max_best_of = 2; let max_stop_sequence = 3; @@ -1184,7 +1185,7 @@ mod tests { ) .await { - Ok(Some((encoding, chunks))) => (encoding, chunks), + Ok((encoding, chunks)) => (encoding, chunks), _ => panic!("Unexpected tokenization failure"), }; diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f4fa431c..99e3d343 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -226,7 +226,7 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/databricks/dbrx-instruct", } MAMBA = { - "type": "ssm", + "type": "mamba", "name": "Mamba", "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", } @@ -618,6 +618,10 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif model_type == "ssm": + raise RuntimeError( + "`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. Check out a list here: https://huggingface.co/models?search=mamba%20-hf" + ) if model_id.startswith("facebook/galactica"): return CausalLM( diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 293051c2..07284e6a 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -196,7 +196,10 @@ class MambaModel(nn.Module): def __init__(self, config, weights): super().__init__() prefix = "backbone" - self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + try: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights) + except RuntimeError: + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.blocks = nn.ModuleList( [ ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) @@ -206,7 +209,10 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) + try: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) + except RuntimeError: + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights) self.config = config def forward(