diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 867bfd3d..4a6cd621 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1,4 +1,4 @@
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
@@ -16,6 +16,26 @@ use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 mod env_runtime;
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Quantization {
+    Bitsandbytes,
+    Gptq,
+}
+
+impl std::fmt::Display for Quantization {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in line with `server`.
+        match self {
+            Quantization::Bitsandbytes => {
+                write!(f, "bitsandbytes")
+            }
+            Quantization::Gptq => {
+                write!(f, "gptq")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -46,10 +66,10 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,
 
-    /// Wether you want the model to be quantized or not. This will use bitsandbytes for
-    /// quantization on the fly.
-    #[clap(long, env)]
-    quantize: bool,
+    /// Whether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// quantization on the fly, or `gptq`.
+    #[clap(long, env, value_enum)]
+    quantize: Option<Quantization>,
 
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
@@ -218,7 +238,7 @@ enum ShardStatus {
 fn shard_manager(
     model_id: String,
     revision: Option<String>,
-    quantize: bool,
+    quantize: Option<Quantization>,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -257,8 +277,9 @@ fn shard_manager(
         shard_argv.push("--sharded".to_string());
     }
 
-    if quantize {
-        shard_argv.push("--quantize".to_string())
+    if let Some(quantize) = quantize {
+        shard_argv.push("--quantize".to_string());
+        shard_argv.push(quantize.to_string())
     }
 
     // Model optional revision
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 92482a94..35cad20f 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -5,17 +5,23 @@ import typer
 
 from pathlib import Path
 from loguru import logger
 from typing import Optional
+from enum import Enum
 
 app = typer.Typer()
 
 
+class Quantization(str, Enum):
+    bitsandbytes = "bitsandbytes"
+    gptq = "gptq"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
-    quantize: bool = False,
+    quantize: Optional[Quantization] = None,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -55,6 +61,8 @@
     if otlp_endpoint is not None:
         setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
+    # Downgrade enum into str for easier management later on
+    quantize = None if quantize is None else quantize.value
     server.serve(model_id, revision, sharded, quantize, uds_path)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 221c9139..e02be3de 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -91,7 +91,7 @@ torch.set_grad_enabled(False)
 
 
 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 7c50644a..877acb00 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -49,7 +49,12 @@ class BloomCausalLMBatch(CausalLMBatch):
 
 
 class BLOOM(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(BLOOM, self).__init__(
             model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
         )
@@ -61,7 +66,10 @@ class BLOOM(CausalLM):
 
 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         self.master = rank == 0
@@ -113,7 +121,7 @@ class BLOOMSharded(BLOOM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -167,7 +175,7 @@ class BLOOMSharded(BLOOM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -217,9 +225,14 @@ class BLOOMSharded(BLOOM):
                                 return linear
 
                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "word_embeddings.weight":
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f838fc5c..0d521ac4 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -447,7 +447,7 @@ class CausalLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -468,7 +468,7 @@ class CausalLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer.pad_token_id = (
             self.model.config.pad_token_id
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6ae869db..5b32b22e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -105,7 +105,7 @@ class FastLinear(nn.Linear):
         self.bnb_linear = None
 
     def prepare_weights(self, quantize: bool = False):
-        if quantize:
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -129,8 +129,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index d706df33..e7b878c0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -92,8 +92,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -117,8 +117,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 20ad8385..309ec19f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -67,8 +67,8 @@ class FastLinear(nn.Linear):
         self.quantized = False
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
-        if quantize:
+    def prepare_weights(self, quantize: Optional[str] = None):
+        if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
                     "bitsandbytes is not available on your machine either because it is not installed "
@@ -92,8 +92,12 @@ class FastLinear(nn.Linear):
             # Delete reference to data
             self.weight = None
             self.bias = None
-        else:
+        elif quantize == "gptq":
+            raise NotImplementedError("`gptq` is not implemented for now")
+        elif quantize is None:
             self.weight = nn.Parameter(self.weight.T)
+        else:
+            raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.quantized:
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e862cfeb..0a9fccca 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -393,7 +393,7 @@ class FlashCausalLM(Model):
         model_cls: Type[PreTrainedModel],
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -410,7 +410,7 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
             )
             .eval()
             .to(device)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index a3ba2084..5f47cf66 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -154,7 +154,10 @@ class FlashLlama(FlashCausalLM):
 
 class FlashLlamaSharded(FlashLlama):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.past_pad = None
         self.process_group, rank, world_size = initialize_torch_distributed()
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index f6098e6c..4f94b348 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -193,7 +193,10 @@ class Galactica(OPT):
 
 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         self.master = rank == 0
@@ -244,7 +247,7 @@ class GalacticaSharded(Galactica):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -299,7 +302,7 @@ class GalacticaSharded(Galactica):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -349,9 +352,14 @@ class GalacticaSharded(Galactica):
                                 return linear
 
                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
                     if name == "model.decoder.embed_tokens.weight":
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index cd36dba0..2d42e0b0 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -32,7 +32,10 @@ except Exception as e:
 
 class GPTNeoxSharded(CausalLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         self.master = rank == 0
@@ -83,7 +86,7 @@ class GPTNeoxSharded(CausalLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -148,7 +151,7 @@ class GPTNeoxSharded(CausalLM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -198,9 +201,14 @@ class GPTNeoxSharded(CausalLM):
                                 return linear
 
                             module.linear = replace_linear(state)
-
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 5a142676..4bd56de1 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -14,7 +14,12 @@ EOD = "<|endoftext|>"
 
 
 class SantaCoder(CausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16
@@ -46,7 +51,7 @@ class SantaCoder(CausalLM):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                load_in_8bit=quantize,
+                load_in_8bit=quantize == "bitsandbytes",
                 trust_remote_code=True,  # required
             )
             .to(device)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 77912cff..84854f5d 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -501,7 +501,7 @@ class Seq2SeqLM(Model):
         self,
         model_id: str,
         revision: Optional[str] = None,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
         decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
@@ -519,7 +519,7 @@ class Seq2SeqLM(Model):
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
-            load_in_8bit=quantize,
+            load_in_8bit=quantize == "bitsandbytes",
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 4cfdea9e..381617b7 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -32,7 +32,10 @@ except Exception as e:
 
 class T5Sharded(Seq2SeqLM):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         self.master = rank == 0
@@ -83,7 +86,7 @@ class T5Sharded(Seq2SeqLM):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -154,7 +157,7 @@ class T5Sharded(Seq2SeqLM):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -205,8 +208,14 @@
 
                             module.linear = replace_linear(state)
 
-                    else:
+                    elif quantize == "gptq":
+                        raise NotImplementedError(
+                            "`gptq` is not implemented for now"
+                        )
+                    elif quantize is None:
                         tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 70f08ed7..d715207b 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -100,14 +100,14 @@ def serve(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
-    quantize: bool,
+    quantize: Optional[str],
     uds_path: Path,
 ):
     async def serve_inner(
         model_id: str,
         revision: Optional[str],
         sharded: bool = False,
-        quantize: bool = False,
+        quantize: Optional[str] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
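
Note: the same three-way dispatch on the `quantize` string recurs in every modeling file touched above. A minimal sketch of that convention, for reference only (the `check_quantize` helper below is hypothetical and not part of this diff):

    from typing import Optional


    def check_quantize(quantize: Optional[str]) -> None:
        # Summarizes the branching added in this diff: None keeps full-precision
        # weights, "bitsandbytes" takes the load_in_8bit / bnb path, and "gptq"
        # is accepted by the CLI but not implemented yet.
        if quantize == "bitsandbytes":
            pass  # bitsandbytes int8 path
        elif quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is None:
            pass  # keep regular fp16/fp32 weights
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")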