diff --git a/Dockerfile b/Dockerfile
index 74e7d990..b2d274d7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -250,6 +250,9 @@ RUN cd server && \
     pip install nvidia-nccl-cu12==2.22.3
 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
+ENV EXLLAMA_NO_FLASH_ATTN=1
 
 # Deps before the binaries
 # The binaries change on every build given we burn the SHA into them
diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py
index 7169c999..18319f60 100644
--- a/integration-tests/models/test_flash_llama_exl2.py
+++ b/integration-tests/models/test_flash_llama_exl2.py
@@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
     return flash_llama_exl2_handle.client
 
 
-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot
 
 
-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot
 
 
-@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a64b1d71..9a90a673 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -30,11 +30,18 @@ struct RawConfig {
     n_positions: Option<usize>,
     model_type: Option<String>,
     max_seq_len: Option<usize>,
+    quantization_config: Option<QuantizationConfig>,
+}
+
+#[derive(Deserialize)]
+struct QuantizationConfig {
+    quant_method: Option<Quantization>,
 }
 
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
+    quantize: Option<Quantization>,
 }
 
 impl From<RawConfig> for Config {
@@ -43,13 +50,16 @@ impl From<RawConfig> for Config {
             .max_position_embeddings
             .or(other.max_seq_len)
             .or(other.n_positions);
+        let quantize = other.quantization_config.and_then(|q| q.quant_method);
         Config {
             max_position_embeddings,
+            quantize,
         }
     }
 }
 
-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
+#[serde(rename_all = "kebab-case")]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
     ///   <https://hf.co/models?search=awq>.
@@ -72,17 +82,17 @@ enum Quantization {
     Marlin,
     /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
     /// but it is known that the model will be much slower to run than the native f16.
-    #[deprecated(
-        since = "1.1.0",
-        note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
-    )]
+    // #[deprecated(
+    //     since = "1.1.0",
+    //     note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
+    // )]
     Bitsandbytes,
     /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
     /// but it is known that the model will be much slower to run than the native f16.
-    BitsandbytesNF4,
+    BitsandbytesNf4,
     /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
     /// perplexity performance for you model
-    BitsandbytesFP4,
+    BitsandbytesFp4,
     /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
     /// This dtype has native ops should be the fastest if available.
     /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
@@ -99,10 +109,10 @@ impl std::fmt::Display for Quantization {
             Quantization::Bitsandbytes => {
                 write!(f, "bitsandbytes")
             }
-            Quantization::BitsandbytesNF4 => {
+            Quantization::BitsandbytesNf4 => {
                 write!(f, "bitsandbytes-nf4")
             }
-            Quantization::BitsandbytesFP4 => {
+            Quantization::BitsandbytesFp4 => {
                 write!(f, "bitsandbytes-fp4")
             }
             Quantization::Exl2 => {
@@ -1085,6 +1095,7 @@ fn spawn_shards(
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
     max_input_tokens: usize,
+    quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1106,7 +1117,6 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let otlp_service_name = args.otlp_service_name.clone();
-        let quantize = args.quantize;
         let speculate = args.speculate;
         let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
@@ -1429,65 +1439,68 @@ fn main() -> Result<(), LauncherError> {
 
     tracing::info!("{:#?}", args);
 
-    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
-        let model_id = args.model_id.clone();
-        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
-        let filename = if !path.exists() {
-            // Assume it's a hub id
+    let get_max_positions_quantize =
+        || -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
+            let model_id = args.model_id.clone();
+            let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+            let filename = if !path.exists() {
+                // Assume it's a hub id
 
-            let api = if let Ok(token) = std::env::var("HF_TOKEN") {
-                // env variable has precedence over on file token.
-                ApiBuilder::new().with_token(Some(token)).build()?
+                let api = if let Ok(token) = std::env::var("HF_TOKEN") {
+                    // env variable has precedence over on file token.
+                    ApiBuilder::new().with_token(Some(token)).build()?
+                } else {
+                    Api::new()?
+                };
+                let repo = if let Some(ref revision) = args.revision {
+                    api.repo(Repo::with_revision(
+                        model_id,
+                        RepoType::Model,
+                        revision.to_string(),
+                    ))
+                } else {
+                    api.model(model_id)
+                };
+                repo.get("config.json")?
             } else {
-                Api::new()?
+                path.push("config.json");
+                path
             };
-            let repo = if let Some(ref revision) = args.revision {
-                api.repo(Repo::with_revision(
-                    model_id,
-                    RepoType::Model,
-                    revision.to_string(),
-                ))
-            } else {
-                api.model(model_id)
-            };
-            repo.get("config.json")?
-        } else {
-            path.push("config.json");
-            path
-        };
-        let content = std::fs::read_to_string(filename)?;
-        let config: RawConfig = serde_json::from_str(&content)?;
+            let content = std::fs::read_to_string(filename)?;
+            let config: RawConfig = serde_json::from_str(&content)?;
 
-        if config.model_type == Some("gemma2".to_string()) {
-            tracing::info!("Forcing flash decoding because of softcap usage");
-            std::env::set_var("ATTENTION", "flashdecoding");
-        }
-        let config: Config = config.into();
-
-        // Quantization usually means you're even more RAM constrained.
-        let max_default = 4096;
-
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
-                Ok(max_default)
-            } else {
-                Ok(max_position_embeddings)
+            if config.model_type == Some("gemma2".to_string()) {
+                tracing::info!("Forcing flash decoding because of softcap usage");
+                std::env::set_var("ATTENTION", "flashdecoding");
             }
-        } else {
-            Err(Box::new(LauncherError::ArgumentValidation(
-                "no max defined".to_string(),
-            )))
-        }
-    };
-    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+            let config: Config = config.into();
+            let quantize = config.quantize;
+
+            // Quantization usually means you're even more RAM constrained.
+            let max_default = 4096;
+
+            if let Some(max_position_embeddings) = config.max_position_embeddings {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    if args.max_input_tokens.is_none()
+                        && args.max_total_tokens.is_none()
+                        && args.max_batch_prefill_tokens.is_none()
+                    {
+                        tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
                    }
+                    Ok((max_default, quantize))
+                } else {
+                    Ok((max_position_embeddings, quantize))
+                }
+            } else {
+                Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )))
+            }
+        };
+    let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
+        get_max_positions_quantize().unwrap_or((4096, None));
 
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
@@ -1544,18 +1557,26 @@ fn main() -> Result<(), LauncherError> {
         )));
     }
 
-    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
+        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
+    }
+    let quantize = args.quantize.or(quantize);
+    let cuda_graphs = match (&args.cuda_graphs, &quantize) {
         (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
         #[allow(deprecated)]
         (
             None,
             Some(
                 Quantization::Bitsandbytes
-                | Quantization::BitsandbytesNF4
-                | Quantization::BitsandbytesFP4,
+                | Quantization::BitsandbytesNf4
+                | Quantization::BitsandbytesFp4,
             ),
         ) => {
-            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+            tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
+        (None, Some(Quantization::Exl2)) => {
+            tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
             vec![]
         }
         _ => {
@@ -1672,6 +1693,7 @@ fn main() -> Result<(), LauncherError> {
         cuda_graphs,
         max_total_tokens,
         max_input_tokens,
+        quantize,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
diff --git a/server/poetry.lock b/server/poetry.lock
index 5072aa0b..fc1a54a3 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
 [package.extras]
 dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"]
 
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+    {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
 [[package]]
 name = "markupsafe"
 version = "2.1.5"
@@ -1207,6 +1231,17 @@ torch = "*"
 type = "url"
 url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
 
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+    {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -2277,6 +2312,20 @@ files = [
 [package.dependencies]
 typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
 
+[[package]]
+name = "pygments"
+version = "2.18.0"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
+    {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
 [[package]]
 name = "pytest"
 version = "7.4.4"
@@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "rich"
+version = "13.7.1"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
+    {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
 [[package]]
 name = "rpds-py"
 version = "0.19.0"
@@ -3584,4 +3651,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
+content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 15da4a8f..57deb1b8 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -46,6 +46,7 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+rich = "^13.7.1"
 
 [tool.poetry.extras]
 torch = ["torch"]
diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt
index 828b6fca..eb521bd6 100644
--- a/server/requirements_cuda.txt
+++ b/server/requirements_cuda.txt
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt
index 828b6fca..eb521bd6 100644
--- a/server/requirements_intel.txt
+++ b/server/requirements_intel.txt
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt
index 828b6fca..eb521bd6 100644
--- a/server/requirements_rocm.txt
+++ b/server/requirements_rocm.txt
@@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
+mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
+pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ba168b13..28534d0f 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -652,6 +652,7 @@ class CausalLM(Model):
             dtype=dtype,
             device=device,
         )
+        self.quantize = quantize
         return self
 
     @property
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 5d6ce3c7..f6dcde68 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -412,6 +412,7 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.quantize = quantize
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
             raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 3c92128a..04d4c28b 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -676,6 +676,7 @@ class Seq2SeqLM(Model):
            dtype=dtype,
             device=device,
         )
+        self.quantize = quantize
         return self
 
     @property