diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 0dfaf0ab..ce79009a 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -33,11 +33,13 @@ impl RadixAllocator {
         window_size: Option<u32>,
         prefix_caching: bool,
     ) -> Self {
-        assert_eq!(
-            block_size, 1,
-            "Radix tree allocator only works with block_size=1, was: {}",
-            block_size
-        );
+        if prefix_caching {
+            assert_eq!(
+                block_size, 1,
+                "Radix tree allocator only works with block_size=1, was: {}",
+                block_size
+            );
+        }
         // if window_size.is_some() {
         //     unimplemented!("Window size not supported in the prefix-caching block allocator yet");
         // }
diff --git a/flake.lock b/flake.lock
index 69f9ef13..1a6353f5 100644
--- a/flake.lock
+++ b/flake.lock
@@ -835,11 +835,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1724206841,
-      "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=",
+      "lastModified": 1724379657,
+      "narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61",
+      "rev": "a18034322c7703fcfe5d7352a77981ba4a936a61",
       "type": "github"
     },
     "original": {
@@ -944,11 +944,11 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1724218652,
-      "narHash": "sha256-Y7Kt+AZRIdo7tr/VhKGzdwYf7stiYQ4JD7flusEpXQw=",
+      "lastModified": 1724270760,
+      "narHash": "sha256-KX566x0+3HZcB20HPdvdwyMm7ZJg21M+iqVrs/HCimA=",
       "owner": "danieldk",
       "repo": "tgi-nix",
-      "rev": "ab2761aa7b970e737492b8cc41ca580dcb094808",
+      "rev": "12cbaa76ff258351741d3b5afb7161f617fe7b4c",
       "type": "github"
     },
     "original": {
diff --git a/flake.nix b/flake.nix
index 83feb26a..87c6b555 100644
--- a/flake.nix
+++ b/flake.nix
@@ -56,6 +56,7 @@
     in
     {
       devShells = with pkgs; rec {
+        default = pure;
         pure = mkShell {
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 35dde5d0..faa84db3 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -24,36 +24,38 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 mod env_runtime;
 
-fn resolve_attention(config: &Config, lora_adapters: &Option<String>) -> (String, String) {
+fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
-    match config.head_dim {
-        Some(h) if h == 64 || h == 128 || h == 256 => {
-            if lora_adapters.is_some() && prefix_caching.is_none() {
-                tracing::info!("Disabling prefix caching because of lora adapters");
-                prefix_caching = Some("0".to_string());
-            }
-            match config.model_type.as_deref() {
-                Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
-                    // Required because gemma2 needs bfloat16 which is not supported by
-                    // flashinfer ?
-                    if prefix_caching.is_none() {
-                        tracing::info!(
-                            "Forcing flash decoding because model {} requires it",
-                            config.model_type.as_ref().unwrap()
-                        );
-                        prefix_caching = Some("0".to_string());
-                        attention = Some("flashdecoding".to_string());
-                    }
+    if let Some(config) = config {
+        match config.head_dim {
+            Some(h) if h == 64 || h == 128 || h == 256 => {
+                if lora_adapters.is_some() && prefix_caching.is_none() {
+                    tracing::info!("Disabling prefix caching because of lora adapters");
+                    prefix_caching = Some("0".to_string());
+                }
+                match config.model_type.as_deref() {
+                    Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
+                        // Required because gemma2 needs bfloat16 which is not supported by
+                        // flashinfer ?
+                        if prefix_caching.is_none() {
+                            tracing::info!(
+                                "Forcing flash decoding because model {} requires it",
+                                config.model_type.as_ref().unwrap()
+                            );
+                            prefix_caching = Some("0".to_string());
+                            attention = Some("flashdecoding".to_string());
+                        }
+                    }
+                    _ => {}
                 }
-                _ => {}
             }
-        }
-        _ => {
-            if prefix_caching.is_none() {
-                tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                prefix_caching = Some("0".to_string());
-                attention = Some("flashdecoding".to_string());
+            _ => {
+                if prefix_caching.is_none() {
+                    tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
+                    prefix_caching = Some("0".to_string());
+                    attention = Some("flashdecoding".to_string());
+                }
             }
         }
     }
@@ -1502,68 +1504,68 @@ fn main() -> Result<(), LauncherError> {
 
     tracing::info!("{:#?}", args);
 
-    let get_max_positions_quantize =
-        || -> Result<(usize, Option<Quantize>), Box<dyn std::error::Error>> {
-            let model_id = args.model_id.clone();
-            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-            let filename = if !path.exists() {
-                // Assume it's a hub id
+    let get_config = || -> Result<Config, Box<dyn std::error::Error>> {
+        let model_id = args.model_id.clone();
+        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+        let filename = if !path.exists() {
+            // Assume it's a hub id
 
-                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                    // env variable has precedence over on file token.
-                    ApiBuilder::new().with_token(Some(token)).build()?
-                } else {
-                    Api::new()?
-                };
-                let repo = if let Some(ref revision) = args.revision {
-                    api.repo(Repo::with_revision(
-                        model_id,
-                        RepoType::Model,
-                        revision.to_string(),
-                    ))
-                } else {
-                    api.model(model_id)
-                };
-                repo.get("config.json")?
+            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                // env variable has precedence over on file token.
+                ApiBuilder::new().with_token(Some(token)).build()?
             } else {
-                path.push("config.json");
-                path
+                Api::new()?
             };
-
-            let content = std::fs::read_to_string(filename)?;
-            let config: RawConfig = serde_json::from_str(&content)?;
-
-            let config: Config = config.into();
-            let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
-            tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-            std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
-            std::env::set_var("ATTENTION", attention);
-            let quantize = config.quantize;
-
-            // Quantization usually means you're even more RAM constrained.
-            let max_default = 4096;
-
-            if let Some(max_position_embeddings) = config.max_position_embeddings {
-                if max_position_embeddings > max_default {
-                    let max = max_position_embeddings;
-                    if args.max_input_tokens.is_none()
-                        && args.max_total_tokens.is_none()
-                        && args.max_batch_prefill_tokens.is_none()
-                    {
-                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                    }
-                    Ok((max_default, quantize))
-                } else {
-                    Ok((max_position_embeddings, quantize))
-                }
+            let repo = if let Some(ref revision) = args.revision {
+                api.repo(Repo::with_revision(
+                    model_id,
+                    RepoType::Model,
+                    revision.to_string(),
+                ))
             } else {
-                Err(Box::new(LauncherError::ArgumentValidation(
-                    "no max defined".to_string(),
-                )))
-            }
+                api.model(model_id)
+            };
+            repo.get("config.json")?
+        } else {
+            path.push("config.json");
+            path
         };
-    let (max_position_embeddings, quantize): (usize, Option<Quantize>) =
-        get_max_positions_quantize().unwrap_or((4096, None));
+
+        let content = std::fs::read_to_string(filename)?;
+        let config: RawConfig = serde_json::from_str(&content)?;
+
+        let config: Config = config.into();
+        Ok(config)
+    };
+    let config: Option<Config> = get_config().ok();
+    let quantize = config.as_ref().and_then(|c| c.quantize);
+    // Quantization usually means you're even more RAM constrained.
+    let max_default = 4096;
+
+    let max_position_embeddings = if let Some(config) = &config {
+        if let Some(max_position_embeddings) = config.max_position_embeddings {
+            if max_position_embeddings > max_default {
+                let max = max_position_embeddings;
+                if args.max_input_tokens.is_none()
+                    && args.max_total_tokens.is_none()
+                    && args.max_batch_prefill_tokens.is_none()
+                {
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                }
+                max_default
+            } else {
+                max_position_embeddings
+            }
+        } else {
+            max_default
+        }
+    } else {
+        max_default
+    };
+    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
+    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
+    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("ATTENTION", attention);
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
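
Outside the patch itself, a minimal sketch of the behavioral change in `RadixAllocator::new` above: the `block_size == 1` assertion is now only enforced when prefix caching is enabled, so a coarser block size no longer panics once prefix caching is switched off. The `guard_block_size` helper is hypothetical, for illustration only.

```rust
// Hypothetical stand-in for the guard inside RadixAllocator::new.
fn guard_block_size(block_size: u32, prefix_caching: bool) {
    if prefix_caching {
        assert_eq!(
            block_size, 1,
            "Radix tree allocator only works with block_size=1, was: {}",
            block_size
        );
    }
}

fn main() {
    guard_block_size(16, false); // accepted after this patch; panicked before
    guard_block_size(1, true); // still accepted
    // guard_block_size(16, true); // would still panic
}
```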
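Likewise, a self-contained sketch of the launcher change: `get_config` is fallible, `.ok()` collapses any failure (missing `config.json`, bad JSON, no network) into `None`, and each consumer falls back to a conservative default instead of aborting startup. `MiniConfig` and `load_config` below are illustrative stand-ins, not TGI types.

```rust
use std::error::Error;

// Illustrative stand-in for the launcher's parsed `Config`.
struct MiniConfig {
    max_position_embeddings: Option<usize>,
}

// A loader that can fail, e.g. for remote-only models or malformed files.
fn load_config() -> Result<MiniConfig, Box<dyn Error>> {
    let content = std::fs::read_to_string("config.json")?;
    let max = content.trim().parse::<usize>()?;
    Ok(MiniConfig {
        max_position_embeddings: Some(max),
    })
}

fn main() {
    // `.ok()` turns Err into None: the launcher keeps going without a config.
    let config: Option<MiniConfig> = load_config().ok();
    let max_default = 4096;
    // Same shape as the patched launcher: a missing config or missing field
    // yields the default; oversized values are capped at the default.
    let max_position_embeddings = config
        .as_ref()
        .and_then(|c| c.max_position_embeddings)
        .map(|m| m.min(max_default))
        .unwrap_or(max_default);
    println!("max_position_embeddings = {max_position_embeddings}");
}
```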