diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 9d6cd4dd..51131f42 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Dtype {
+    Float16,
+    BFloat16,
+}
+
+impl std::fmt::Display for Dtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep these strings in sync with the `server` CLI.
+        match self {
+            Dtype::Float16 => {
+                write!(f, "float16")
+            }
+            Dtype::BFloat16 => {
+                write!(f, "bfloat16")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -71,6 +91,10 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
+    #[clap(long, env, value_enum)]
+    dtype: Option<Dtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -258,6 +282,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
    quantize: Option<Quantization>,
+    dtype: Option<Dtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -307,6 +332,11 @@ fn shard_manager(
         shard_argv.push(quantize.to_string())
     }
 
+    if let Some(dtype) = dtype {
+        shard_argv.push("--dtype".to_string());
+        shard_argv.push(dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_argv.push("--revision".to_string());
@@ -743,6 +773,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let dtype = args.dtype;
         let trust_remote_code = args.trust_remote_code;
         let master_port = args.master_port;
         let disable_custom_kernels = args.disable_custom_kernels;
@@ -753,6 +784,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                dtype,
                 trust_remote_code,
                 uds_path,
                 rank,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index aeb1f13b..3463049a 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -16,12 +16,18 @@ class Quantization(str, Enum):
     gptq = "gptq"
 
 
+class Dtype(str, Enum):
+    float16 = "float16"
+    bfloat16 = "bfloat16"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -64,7 +70,14 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
+    dtype = None if dtype is None else dtype.value
+    if dtype is not None and quantize is not None:
+        raise RuntimeError(
+            "Only one of `dtype` and `quantize` can be set, as they both determine how the final model is loaded."
+        )
+    server.serve(
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+    )
 
 
 @app.command()
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 2abde685..e45e198a 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -100,11 +100,25 @@ def get_model(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if dtype is None:
+        dtype = torch.float16
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     if model_id.startswith("bigcode/"):
@@ -113,6 +127,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
@@ -124,6 +139,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
 
@@ -138,6 +154,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
@@ -149,12 +166,17 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
 
     if model_type == "bloom":
         return BLOOMSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "gpt_neox":
@@ -163,6 +185,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
@@ -170,6 +193,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         else:
@@ -177,6 +201,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
 
@@ -186,6 +211,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
@@ -195,6 +221,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
 
@@ -210,6 +237,7 @@ def get_model(
                     model_id,
                     revision,
                     quantize=quantize,
+                    dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
             raise NotImplementedError(
@@ -221,6 +249,7 @@ def get_model(
                     model_id,
                     revision,
                     quantize=quantize,
+                    dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
             else:
@@ -228,12 +257,17 @@ def get_model(
                     model_id,
                     revision,
                     quantize=quantize,
+                    dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
 
     elif model_type == "opt":
         return OPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "t5":
@@ -241,6 +275,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -253,11 +288,19 @@ def get_model(
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     auto_map = config_dict.get("auto_map", None)
@@ -267,6 +310,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
@@ -274,6 +318,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
 
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 50b3b76a..101da207 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 6d47c6eb..cbdf4808 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -454,11 +454,12 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 2c59f01e..417ccabb 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index e64af0c6..61004d8e 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index a55f9118..12b862d7 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index ef202785..415ec2df 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 01e1c773..01e58bad 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 0abf0239..91877fa0 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 16cb48b7..d407b44a 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 2b1e4959..92bb135b 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -12,11 +12,12 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index d0fd3070..a2b38737 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 999b6637..9e5c21d1 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index c89462fc..1b7073af 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 6cc5beeb..c375330a 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -106,6 +106,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
 ):
@@ -114,6 +115,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -128,7 +130,9 @@ def serve(
         server_urls = [local_url]
 
         try:
-            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
+            model = get_model(
+                model_id, revision, sharded, quantize, dtype, trust_remote_code
+            )
         except Exception:
             logger.exception("Error when initializing model")
             raise
@@ -159,4 +163,6 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)
 
-    asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
+    asyncio.run(
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+    )
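
For reference, a minimal sketch of how the new `dtype` option is meant to flow through the server side. `resolve_dtype` is a hypothetical helper used only for illustration, not part of the patch: it mirrors the enum downgrade and the `--quantize` exclusivity check from `cli.serve`, and the string-to-`torch.dtype` mapping from `get_model`.

# Illustrative only: a hypothetical helper mirroring the behaviour added in the diff above.
from enum import Enum
from typing import Optional

import torch


class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"


def resolve_dtype(dtype: Optional[Dtype], quantize: Optional[str]) -> torch.dtype:
    # Downgrade the enum into a plain string, as cli.serve does.
    value = None if dtype is None else dtype.value
    # `--dtype` and `--quantize` are mutually exclusive.
    if value is not None and quantize is not None:
        raise RuntimeError("Only one of `dtype` and `quantize` can be set.")
    # Same mapping as get_model; float16 is the default.
    if value is None or value == "float16":
        return torch.float16
    if value == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {value}")


# Example: resolve_dtype(Dtype.bfloat16, None) returns torch.bfloat16,
# while resolve_dtype(Dtype.bfloat16, "bitsandbytes") raises RuntimeError.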