diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e6fe1372..29b049cf 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -153,7 +153,7 @@ def get_model(
         )
     elif model_type == "mpt":
         return MPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
         )
 
     elif model_type == "gpt_neox":
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 79fb60c6..0151b017 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -51,7 +51,7 @@ class BLOOMSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 4e338263..cec9ae55 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -492,7 +492,7 @@ class CausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
index 047a1872..297d5c68 100644
--- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -40,7 +40,7 @@ from text_generation_server.utils.layers import (
 )
 
 CUSTOM_KERNELS_ENABLED = False
-if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
         from custom_kernels import fused_bloom_attention_cuda
 
diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py
index 1951b171..c5b0c7fd 100644
--- a/server/text_generation_server/models/custom_modeling/neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -49,7 +49,7 @@ from text_generation_server.utils.layers import (
 
 
 CUSTOM_KERNELS_ENABLED = False
-if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
         from custom_kernels import fused_attention_cuda
 
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index d4211734..cfd5e658 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index accedf14..d4c64dfe 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -33,7 +33,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index c54b539b..fa23d1f9 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -42,7 +42,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             dtype = torch.bfloat16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
         config = IdeficsConfig.from_pretrained(
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 2dac87bc..f4177145 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -560,7 +560,7 @@ class IdeficsCausalLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 909d9852..19de497c 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -43,14 +43,16 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("MPTSharded is only available on GPU")
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index f3a23d07..b2b87246 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -31,7 +31,7 @@ class OPTSharded(CausalLM):
             dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index d97c1c73..802a4aa6 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -23,7 +23,7 @@ class RW(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 81928c1d..7b269d8e 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -30,7 +30,7 @@ class SantaCoder(CausalLM):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 361453fb..1a7911ac 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -541,7 +541,7 @@ class Seq2SeqLM(Model):
                 raise ValueError("quantization is not available on CPU")
 
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 133aafd8..161e69ba 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -34,7 +34,7 @@ class T5Sharded(Seq2SeqLM):
            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
-            dtype = torch.float32
+            dtype = torch.float32 if dtype is None else dtype
 
         config = AutoConfig.from_pretrained(
             model_id,
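
Note: the following is an illustrative sketch, not part of the patch above. It condenses the two patterns this diff applies across the model constructors and modeling files: honor a caller-supplied dtype and only fall back to the per-device default (float16 on GPU, float32 on CPU), and attempt the custom CUDA kernels only when a GPU is actually available. The helper name resolve_device_and_dtype is hypothetical and used here only for demonstration.

import os
from typing import Optional, Tuple

import torch

# Custom kernels are attempted only on CUDA machines and only when the
# DISABLE_CUSTOM_KERNELS escape hatch is not set (mirrors the modeling files).
CUSTOM_KERNELS_ENABLED = False
if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
    try:
        from custom_kernels import fused_attention_cuda  # optional CUDA extension

        CUSTOM_KERNELS_ENABLED = True
    except ImportError:
        pass


def resolve_device_and_dtype(
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.device, torch.dtype]:
    """Prefer the caller-supplied dtype; otherwise pick a per-device default."""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        # Previously this branch forced float32 unconditionally.
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype


# resolve_device_and_dtype(torch.bfloat16) now keeps bfloat16 on CPU,
# while resolve_device_and_dtype() still falls back to float32 there.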