We can have a tokenizer anywhere. (#2527)

* We can have a tokenizer anywhere.

* Handling potential lack of offsets (python tokenizer)

* Remove redundancy.

* Fixing the tests.

* Flake.lock update ?

* Fixing the  GIL locking.

* Fixing mamba by using the transformers version.

* Adding the legacy handle.

* Ellide lifetime.

* Lint.

* Deprecation message.

* Fixing bad rebase.
This commit is contained in:
Nicolas Patry 2024-10-28 05:00:24 +01:00 committed by GitHub
parent 0c9b6cdd76
commit 90b226db29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 362 additions and 264 deletions

View File

@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1727836133,
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
"lastModified": 1729045942,
"narHash": "sha256-HjmK0x5Zm2TK2vFpC7XBM2e3EDNVnAIuEoU2FkeN8xw=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
"rev": "9de3cea452d2401d6f93c06ad985178a4e11d1fc",
"type": "github"
},
"original": {

View File

@ -3,7 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def fused_kernel_mamba_handle(launcher):
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
with launcher("state-spaces/mamba-130m-hf", num_shard=1) as handle:
yield handle

View File

@ -145,6 +145,7 @@ pub enum Config {
LlavaNext(LlavaNext),
ClipVisionModel(ClipVisionModel),
Mistral,
Mamba,
Idefics,
Mllama,
Idefics2(Idefics2),

View File

@ -135,7 +135,7 @@ impl Infer {
pub(crate) async fn tokenize(
&self,
request: GenerateRequest,
) -> Result<Option<tokenizers::Encoding>, InferError> {
) -> Result<tokenizers::Encoding, InferError> {
// Tokenize request
let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
@ -150,7 +150,7 @@ impl Infer {
})?;
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
Ok(encoding.0)
}
/// Apply the chat template to the chat request

View File

@ -14,11 +14,92 @@ mod vertex;
use crate::infer::{Infer, InferError};
use crate::server::prepare_chat_input;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use serde::{Deserialize, Serialize};
use tokenizers::Encoding;
use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(Clone)]
pub enum Tokenizer {
Python {
tokenizer_name: String,
revision: Option<String>,
},
Rust(tokenizers::Tokenizer),
}
pub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);
impl<'a> PyTokenizer<'a> {
fn from_py(
py: Python<'a>,
tokenizer_name: String,
revision: Option<String>,
) -> PyResult<PyTokenizer<'a>> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name,);
let kwargs = if let Some(rev) = &revision {
[("revision", rev.to_string())].into_py_dict_bound(py)
} else {
pyo3::types::PyDict::new_bound(py)
};
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
tracing::info!("Loaded a python tokenizer");
Ok(PyTokenizer(tokenizer))
}
}
trait TokenizerTrait {
fn encode_trait(
&self,
query: String,
add_special_tokens: bool,
) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;
}
impl TokenizerTrait for tokenizers::Tokenizer {
fn encode_trait(
&self,
query: String,
add_special_tokens: bool,
) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
self.encode(query, add_special_tokens)
}
}
impl<'a> TokenizerTrait for PyTokenizer<'a> {
fn encode_trait(
&self,
query: String,
add_special_tokens: bool,
) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {
let py = self.0.py();
let kwargs = [
("text", query.into_py(py)),
("add_special_tokens", add_special_tokens.into_py(py)),
]
.into_py_dict_bound(py);
let encode = self.0.getattr("encode")?;
let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;
Ok(Encoding::new(
input_ids,
vec![], // type ids
vec![], // tokens (strings)
vec![], // words
vec![], // offsets
vec![], // special_tokens_mask
vec![], // attention_mask
vec![], // overflowing
std::collections::HashMap::new(), //sequence_ranges
))
}
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
@ -1341,13 +1422,12 @@ impl Default for ModelsInfo {
mod tests {
use super::*;
use serde_json::json;
use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer {
pub(crate) fn get_tokenizer() -> Tokenizer {
let api = hf_hub::api::sync::Api::new().unwrap();
let repo = api.model("gpt2".to_string());
let filename = repo.get("tokenizer.json").unwrap();
Tokenizer::from_file(filename).unwrap()
Tokenizer::Rust(tokenizers::Tokenizer::from_file(filename).unwrap())
}
#[test]

View File

@ -19,7 +19,8 @@ use crate::{
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,
Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -45,6 +46,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use regex::Regex;
use serde_json::Value;
@ -54,7 +56,6 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
use tokio::sync::oneshot;
@ -64,6 +65,41 @@ use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
fn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
let offsets = encoding.get_offsets();
let input_ids = encoding.get_ids();
if offsets.len() == input_ids.len() {
input_ids
.iter()
.zip(offsets)
.map(|(&id, &(start, stop))| {
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect()
} else {
encoding
.get_ids()
.iter()
.map(|&id| SimpleToken {
id,
text: "".to_string(),
start: 0,
stop: 0,
})
.collect()
}
}
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
post,
@ -161,40 +197,14 @@ async fn get_chat_tokenize(
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
let input = generate_request.inputs.clone();
let encoding = infer.tokenize(generate_request).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens),
templated_text: input,
};
Ok((HeaderMap::new(), Json(resp)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
let tokens = encoding_to_tokens(&encoding, &input);
let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens),
templated_text: input,
};
Ok((HeaderMap::new(), Json(resp)))
}
#[utoipa::path(
@ -1458,35 +1468,8 @@ async fn tokenize(
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
let input = req.inputs.clone();
let encoding = infer.tokenize(req).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
Ok(Json(TokenizeResponse(tokens)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
let tokens = encoding_to_tokens(&encoding, &input);
Ok(Json(TokenizeResponse(tokens)))
}
/// Prometheus metrics scrape endpoint
@ -1594,6 +1577,71 @@ pub fn schema() -> ApiDoc {
ApiDoc
}
fn py_resolve_tokenizer(
py: pyo3::Python,
tokenizer_name: &str,
revision: Option<&str>,
trust_remote_code: bool,
) -> pyo3::PyResult<()> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name,);
let kwargs = if let Some(rev) = &revision {
[
("revision", rev.to_string().into_py(py)),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py)
} else {
[("trust_remote_code", trust_remote_code.into_py(py))].into_py_dict_bound(py)
};
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
}
fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {
// XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3
// and state-spaces/mamba-130m
tracing::warn!("Odd tokenizer detected, falling back on legacy tokenization");
#[derive(serde::Deserialize)]
struct FallbackConfig {
base_model_name_or_path: Option<String>,
model_type: Option<String>,
ssm_config: Option<serde_json::Value>,
}
config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<FallbackConfig, _> = serde_json::from_str(c);
if let Ok(config) = config {
if config.model_type.is_none() {
if let Some(base) = config.base_model_name_or_path {
pyo3::Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &base, Some("main"), false)
})
.ok()?;
}
}
if config.ssm_config.is_some() {
// XXX Legacy mamba
pyo3::Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, "EleutherAI/gpt-neox-20b", Some("main"), false)
})
.ok()?;
}
}
Some(())
})
})
}
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
@ -1687,7 +1735,6 @@ pub async fn run(
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
@ -1695,7 +1742,6 @@ pub async fn run(
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
@ -1709,10 +1755,6 @@ pub async fn run(
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
@ -1725,7 +1767,6 @@ pub async fn run(
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
@ -1740,7 +1781,6 @@ pub async fn run(
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
@ -1762,39 +1802,30 @@ pub async fn run(
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let tokenizer: Tokenizer = {
use pyo3::prelude::*;
let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [
(
"revision",
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
pyo3::Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// If we have correctly loaded and resaved with transformers
// We might have modified the tokenizer.json according to transformers
"out/tokenizer.json".into()
})
.or_else(|err| {
let out = legacy_tokenizer_handle(config_filename.as_ref());
out.ok_or(err)
})
.expect("We cannot load a tokenizer");
let filename = "out/tokenizer.json";
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
Tokenizer::Rust(tok)
} else {
filename
};
Tokenizer::from_file(filename).ok()
});
Tokenizer::Python {
tokenizer_name: tokenizer_name.clone(),
revision: revision.clone(),
}
}
};
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
@ -1822,10 +1853,6 @@ pub async fn run(
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
@ -1940,7 +1967,7 @@ async fn start(
validation_workers: usize,
api_key: Option<String>,
config: Option<Config>,
(tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
(tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig),
(preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
hostname: String,
port: u16,
@ -2400,30 +2427,6 @@ pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
@ -2566,10 +2569,11 @@ mod tests {
use crate::TokenizerConfigToken;
use crate::Tool;
use crate::tests::get_tokenizer;
use serde_json::json;
#[test]
fn test_prepare_chat_input() {
#[tokio::test]
async fn test_prepare_chat_input() {
// Mock Backend to avoid network requests
struct MockBackend;
@ -2610,9 +2614,11 @@ mod tests {
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
);
let tokenizer = get_tokenizer();
let infer = Infer::new(
backend,
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
Validation::new(1, tokenizer, None, None, 1, 1, 1, 1, 1, false),
1,
tokenizer_config,
HubProcessorConfig::default(),

View File

@ -3,7 +3,9 @@ use crate::config::Config;
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
TokenizerTrait,
};
use crate::{PyTokenizer, Tokenizer};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
@ -13,7 +15,6 @@ use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tracing::{instrument, Span};
@ -30,14 +31,14 @@ pub struct Validation {
max_total_tokens: usize,
disable_grammar_support: bool,
/// Channel to communicate with the background tokenization task
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
sender: mpsc::UnboundedSender<TokenizerRequest>,
}
impl Validation {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
workers: usize,
tokenizer: Option<Tokenizer>,
tokenizer: Tokenizer,
config: Option<Config>,
preprocessor_config: Option<HubPreprocessorConfig>,
max_best_of: usize,
@ -47,8 +48,13 @@ impl Validation {
max_total_tokens: usize,
disable_grammar_support: bool,
) -> Self {
let workers = if let Tokenizer::Python { .. } = &tokenizer {
1
} else {
workers
};
// If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer {
let sender = {
// Create round robin channel
let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
let mut senders = Vec::with_capacity(workers);
@ -75,9 +81,7 @@ impl Validation {
// Create tokenization round robin task
tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
Some(validation_sender)
} else {
None
validation_sender
};
Self {
@ -97,28 +101,25 @@ impl Validation {
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
sender
.send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap();
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
let _ = &self
.sender
.send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap();
// Await on response channel
// Unwrap is safe here
let encoding = response_receiver.await.unwrap()?;
Ok(Some(encoding))
} else {
Ok(None)
}
// Await on response channel
// Unwrap is safe here
let encoding = response_receiver.await.unwrap()?;
Ok(encoding)
}
#[allow(clippy::type_complexity)]
@ -131,76 +132,46 @@ impl Validation {
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self
let (encoding, inputs) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel
let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate)
} else {
encoding.len()
};
.await?;
// Create response channel
let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate)
} else {
encoding.len()
};
// Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
max_new_tokens
} else {
self.max_total_tokens.saturating_sub(input_length) as u32
};
let total_tokens = input_length + max_new_tokens as usize;
// Get total tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
max_new_tokens
} else {
self.max_total_tokens.saturating_sub(input_length) as u32
};
let total_tokens = input_length + max_new_tokens as usize;
// Validate MaxTotalTokens
if total_tokens > self.max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
self.max_total_tokens,
input_length,
max_new_tokens,
));
}
// Validate InputLength
if input_length > self.max_input_length {
return Err(ValidationError::InputLength(
self.max_input_length,
input_length,
));
}
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
// Return inputs without validation
else {
// In this case, we don't know the real length in tokens of the inputs
// However, the inputs will be truncated by the python servers
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
max_new_tokens
} else if let Some(truncate) = truncate {
self.max_total_tokens.saturating_sub(truncate) as u32
} else {
return Err(ValidationError::UnsetMaxNewTokens);
};
let mut input_length = truncate.unwrap_or(self.max_input_length);
// We don't have a tokenizer, therefore we have no idea how long is the query, let
// them through and hope for the best.
// Validate MaxNewTokens
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
Ok((
vec![Chunk::Text(inputs)],
None,
// Validate MaxTotalTokens
if total_tokens > self.max_total_tokens {
return Err(ValidationError::MaxTotalTokens(
self.max_total_tokens,
input_length,
max_new_tokens,
))
));
}
// Validate InputLength
if input_length > self.max_input_length {
return Err(ValidationError::InputLength(
self.max_input_length,
input_length,
));
}
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
/// Validate a payload and get the number of tokens in the input
@ -464,22 +435,52 @@ fn tokenizer_worker(
preprocessor_config: Option<HubPreprocessorConfig>,
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
))
.unwrap_or(())
})
match tokenizer {
Tokenizer::Python {
tokenizer_name,
revision,
} => {
pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {
let tokenizer = PyTokenizer::from_py(py, tokenizer_name, revision)?;
// Loop over requests
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
))
.unwrap_or(())
})
}
Ok(())
})
.expect("Failure in python tokenizer worker");
}
Tokenizer::Rust(tokenizer) => {
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
))
.unwrap_or(())
})
}
}
}
}
@ -608,11 +609,11 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
}
/// Get input length and optionally truncate it
fn prepare_input(
fn prepare_input<T: TokenizerTrait>(
inputs: String,
_truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer,
tokenizer: &T,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
@ -649,7 +650,7 @@ fn prepare_input(
// Get the number of tokens in the input
let encoding = tokenizer
.encode(tokenizer_query, add_special_tokens)
.encode_trait(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks))
@ -824,7 +825,7 @@ mod tests {
#[tokio::test]
async fn test_validation_max_new_tokens() {
let tokenizer = None;
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
@ -851,15 +852,15 @@ mod tests {
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, _, 0, 10)) => (),
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
// Ok((_s, _, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"),
}
}
#[tokio::test]
async fn test_validation_input_length() {
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
@ -893,7 +894,7 @@ mod tests {
#[tokio::test]
async fn test_validation_best_of_sampling() {
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
@ -933,7 +934,7 @@ mod tests {
#[tokio::test]
async fn test_validation_top_p() {
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
@ -1004,7 +1005,7 @@ mod tests {
#[tokio::test]
async fn test_validation_top_n_tokens() {
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
@ -1089,7 +1090,7 @@ mod tests {
async fn test_prepare_input_chunks() {
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
@ -1124,7 +1125,7 @@ mod tests {
)
.await
{
Ok(Some((_encoding, chunks))) => chunks,
Ok((_encoding, chunks)) => chunks,
_ => panic!("Unexpected tokenization failure"),
};
@ -1146,7 +1147,7 @@ mod tests {
async fn test_idefics2_correct_n_fake_tokens() {
let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();
let tokenizer = Some(get_tokenizer().await);
let tokenizer = get_tokenizer();
let max_best_of = 2;
let max_stop_sequence = 3;
@ -1184,7 +1185,7 @@ mod tests {
)
.await
{
Ok(Some((encoding, chunks))) => (encoding, chunks),
Ok((encoding, chunks)) => (encoding, chunks),
_ => panic!("Unexpected tokenization failure"),
};

View File

@ -226,7 +226,7 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "ssm",
"type": "mamba",
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
@ -618,6 +618,10 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "ssm":
raise RuntimeError(
"`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. Check out a list here: https://huggingface.co/models?search=mamba%20-hf"
)
if model_id.startswith("facebook/galactica"):
return CausalLM(

View File

@ -196,7 +196,10 @@ class MambaModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
try:
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
except RuntimeError:
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList(
[
ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
@ -206,7 +209,10 @@ class MambaModel(nn.Module):
self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
)
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
try:
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
except RuntimeError:
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
self.config = config
def forward(