From e3e487dc711449c23826cfe1d74786f71309d6bd Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 23 May 2023 20:40:39 +0200
Subject: [PATCH] feat(server): support trust_remote_code (#363)

---
 .github/workflows/build.yaml                       |   3 +-
 launcher/src/main.rs                               |  28 +++-
 server/text_generation_server/cli.py               |   3 +-
 .../text_generation_server/models/__init__.py      | 147 +++++++++++++++---
 server/text_generation_server/models/bloom.py      |  23 ++-
 .../models/causal_lm.py                            |  49 ++++--
 .../models/flash_causal_lm.py                      |   8 +-
 .../models/flash_llama.py                          |  10 +-
 .../models/flash_neox.py                           |  17 +-
 .../models/flash_santacoder.py                     |  14 +-
 .../models/galactica.py                            |  16 +-
 .../text_generation_server/models/gpt_neox.py      |  16 +-
 server/text_generation_server/models/opt.py        |  16 +-
 .../models/santacoder.py                           |   9 +-
 .../models/seq2seq_lm.py                           |  12 +-
 server/text_generation_server/models/t5.py         |  16 +-
 server/text_generation_server/server.py            |   6 +-
 17 files changed, 321 insertions(+), 72 deletions(-)

diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 9992e0a2..124e6a33 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests

   stop-runner:
     name: Stop self-hosted EC2 runner
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 4a6cd621..dc12c90f 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -53,7 +53,7 @@ struct Args {
     #[clap(long, env)]
     revision: Option<String>,

-    /// Wether to shard or not the model across multiple GPUs
+    /// Whether to shard the model across multiple GPUs
     /// By default text-generation-inference will use all available GPUs to run
     /// the model. Setting it to `false` deactivates `num_shard`.
     #[clap(long, env)]
@@ -66,11 +66,17 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,

-    /// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
     /// quantization on the fly, or `gptq`.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,

+    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
+    /// encouraged when loading a model with custom code to ensure no malicious code has been
+    /// contributed in a newer revision.
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
     /// wait for too long and is usually good to handle backpressure correctly.
@@ -239,6 +245,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    trust_remote_code: bool,
     uds_path: String,
     rank: usize,
     world_size: usize,
@@ -272,6 +279,11 @@ fn shard_manager(
         "--json-output".to_string(),
     ];

+    // Activate trust remote code
+    if trust_remote_code {
+        shard_argv.push("--trust-remote-code".to_string());
+    }
+
     // Activate tensor parallelism
     if world_size > 1 {
         shard_argv.push("--sharded".to_string());
@@ -692,6 +704,16 @@ fn spawn_shards(
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
 ) -> Result<(), LauncherError> {
+    if args.trust_remote_code {
+        tracing::warn!(
+            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
+            args.model_id
+        );
+        if args.revision.is_none() {
+            tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
+        }
+    }
+
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();
@@ -705,6 +727,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let trust_remote_code = args.trust_remote_code;
         let master_port = args.master_port;
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;
@@ -714,6 +737,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                trust_remote_code,
                 uds_path,
                 rank,
                 num_shard,
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 35cad20f..c0e6c2dc 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -22,6 +22,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
@@ -63,7 +64,7 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value

-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)


 @app.command()
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index ec990fde..bf7a2849 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -91,13 +91,27 @@ torch.set_grad_enabled(False)


 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    trust_remote_code: bool,
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
+            return GalacticaSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Galactica(model_id, revision, quantize=quantize)
+            return Galactica(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_id.startswith("bigcode/"):
         if sharded:
@@ -105,12 +119,24 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

-    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
     model_type = config.model_type

     if model_type == "gpt_bigcode":
@@ -119,52 +145,133 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_id, revision, quantize=quantize)
+            return BLOOMSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return BLOOM(model_id, revision, quantize=quantize)
+            return BLOOM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "gpt_neox":
         if sharded:
             neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "llama":
         if sharded:
             if FLASH_ATTENTION:
-                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+                return FlashLlamaSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
-            return llama_cls(model_id, revision, quantize=quantize)
+            return llama_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if config.model_type == "opt":
         if sharded:
-            return OPTSharded(model_id, revision, quantize=quantize)
+            return OPTSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return OPT(model_id, revision, quantize=quantize)
+            return OPT(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "t5":
         if sharded:
-            return T5Sharded(model_id, revision, quantize=quantize)
+            return T5Sharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if sharded:
         raise ValueError("sharded is not supported for AutoModel")

     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(model_id, revision, quantize=quantize)
+        return CausalLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(model_id, revision, quantize=quantize)
+        return Seq2SeqLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
+
+    auto_map = getattr(config, "auto_map", None)
+    if trust_remote_code and auto_map is not None:
+        if "AutoModelForCausalLM" in auto_map.keys():
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     raise ValueError(f"Unsupported model type {model_type}")
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 9d609185..5eddc8cf 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -54,9 +54,13 @@ class BLOOM(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
         )

     @property
@@ -70,6 +74,7 @@ class BLOOMSharded(BLOOM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -80,11 +85,19 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id,
+            revision=revision,
+            slow_but_exact=False,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         config.pad_token_id = 3

@@ -92,7 +105,9 @@ class BLOOMSharded(BLOOM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ab92feed..09df70d2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,4 +1,5 @@
 import torch
+import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
@@ -450,6 +451,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -462,22 +464,38 @@ class CausalLM(Model):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
         )

         super(CausalLM, self).__init__(
@@ -501,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index aee0480d..d6e73ad8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -394,6 +394,7 @@ class FlashCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -402,13 +403,18 @@ class FlashCausalLM(Model):
             raise NotImplementedError("FlashCausalLM is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         model = model_cls.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         ).to(device)

         super(FlashCausalLM, self).__init__(
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index ebdbe206..fe28580d 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -33,6 +33,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -45,11 +46,11 @@ class FlashLlama(FlashCausalLM):
             revision=revision,
             padding_side="left",
             truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         # We do not use from_pretrained as we modified the model internal module layout
@@ -153,6 +154,7 @@ class FlashLlamaSharded(FlashLlama):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -166,11 +168,11 @@ class FlashLlamaSharded(FlashLlama):
             revision=revision,
             padding_side="left",
             truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 8c1c1a00..31ae7914 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -28,9 +28,14 @@ class FlashNeoX(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         super(FlashNeoX, self).__init__(
-            FlashGPTNeoXForCausalLM, model_id, revision, quantize
+            FlashGPTNeoXForCausalLM,
+            model_id,
+            revision,
+            quantize,
+            trust_remote_code=trust_remote_code,
         )


@@ -40,6 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -49,12 +55,15 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 1fbaf252..482e0f54 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -32,6 +32,7 @@ class FlashSantacoder(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -40,7 +41,11 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = GPT2Config.from_pretrained(
@@ -178,6 +183,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -187,7 +193,11 @@ class FlashSantacoderSharded(FlashSantacoder):
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = GPT2Config.from_pretrained(
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 24c37c19..bc3096c6 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -199,6 +199,7 @@ class GalacticaSharded(Galactica):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -209,11 +210,18 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token_id = config.pad_token_id

@@ -221,7 +229,9 @@ class GalacticaSharded(Galactica):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index a10dfcb8..e4a85082 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -36,6 +36,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -46,19 +47,28 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token = tokenizer.eos_token

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 8c676a51..bccce5b3 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -52,6 +52,7 @@ class OPTSharded(OPT):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -62,11 +63,18 @@ class OPTSharded(OPT):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token_id = config.pad_token_id

@@ -74,7 +82,9 @@ class OPTSharded(OPT):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 23f89f48..d0fd3070 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -31,7 +32,11 @@ class SantaCoder(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.add_special_tokens(
             {
@@ -51,7 +56,7 @@ class SantaCoder(CausalLM):
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=True,  # required
+            trust_remote_code=trust_remote_code,
         ).to(device)

         super(CausalLM, self).__init__(
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index f8b404a9..a1a39fd4 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -503,6 +503,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -518,14 +519,21 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id

diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 7ecf948b..17cc50e0 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -36,6 +36,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -46,11 +47,18 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.bos_token_id = config.decoder_start_token_id

@@ -58,7 +66,9 @@ class T5Sharded(Seq2SeqLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForSeq2SeqLM.from_config(config)
+            model = AutoModelForSeq2SeqLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index d715207b..7ca5054e 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -101,6 +101,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    trust_remote_code: bool,
     uds_path: Path,
 ):
     async def serve_inner(
@@ -108,6 +109,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -121,7 +123,7 @@ def serve(
             server_urls = [local_url]

         try:
-            model = get_model(model_id, revision, sharded, quantize)
+            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
         except Exception:
             logger.exception("Error when initializing model")
             raise
@@ -152,4 +154,4 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
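
The `--trust-remote-code` flag introduced in launcher/src/main.rs above is forwarded to each Python shard and ends up as the `trust_remote_code` argument of the `from_pretrained`/`from_config` calls added throughout the server. For reference, the sketch below shows what that argument means at the transformers level; the repository name and revision are placeholders, and pinning a `revision` is exactly what the new launcher warning recommends when custom code is enabled.

# Illustrative sketch, not part of the patch: loading a hub model that ships its
# own modeling code. The repository name and revision below are placeholders.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "some-org/model-with-custom-code"  # hypothetical repository
revision = "a1b2c3d"                          # pin a known-good commit

config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision, trust_remote_code=True)
# trust_remote_code=True lets transformers import and execute the modeling code
# stored in the repository instead of requiring a built-in architecture.
model = AutoModelForCausalLM.from_pretrained(
    model_id, revision=revision, trust_remote_code=True
)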
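
The new `auto_map` fallback in `get_model` is what makes such repositories servable at all: a custom-code model advertises its classes in the `auto_map` section of its `config.json`, and when `trust_remote_code` is set the server now routes it to the generic `CausalLM` or `Seq2SeqLM` wrapper instead of rejecting the unknown `model_type`. A small self-contained illustration of that dispatch rule, using a hypothetical `auto_map` of the usual shape:

# Illustrative sketch, not part of the patch. The auto_map contents are made up
# but follow the convention used by custom-code repositories on the Hub.
from typing import Optional

example_auto_map = {
    "AutoConfig": "configuration_custom.CustomConfig",
    "AutoModelForCausalLM": "modeling_custom.CustomForCausalLM",
}

def pick_wrapper(auto_map: Optional[dict], trust_remote_code: bool) -> Optional[str]:
    # Mirrors the fallback added to get_model(): only honour auto_map when the
    # operator explicitly opted into running remote code.
    if not trust_remote_code or not auto_map:
        return None
    if "AutoModelForCausalLM" in auto_map:
        return "CausalLM"
    if "AutoModelForSeq2SeqLM" in auto_map:
        return "Seq2SeqLM"
    return None

assert pick_wrapper(example_auto_map, trust_remote_code=True) == "CausalLM"
assert pick_wrapper(example_auto_map, trust_remote_code=False) is None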
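
Because remote-code models do not all share the same `forward` signature, `CausalLM.__init__` now records whether the underlying model accepts `position_ids` (via `inspect.signature`), and `CausalLM.forward` only passes that argument when it exists. A toy version of the check, with made-up classes standing in for real models:

# Illustrative sketch, not part of the patch: the capability probe performed in
# CausalLM, demonstrated on two dummy classes with different forward signatures.
import inspect

class WithPositionIds:
    def forward(self, input_ids, attention_mask, position_ids=None, past_key_values=None):
        return None

class WithoutPositionIds:
    def forward(self, input_ids, attention_mask, past_key_values=None):
        return None

def accepts_position_ids(model) -> bool:
    return inspect.signature(model.forward).parameters.get("position_ids", None) is not None

assert accepts_position_ids(WithPositionIds())
assert not accepts_position_ids(WithoutPositionIds())
# When the probe is False, the position_ids kwarg is simply omitted from the
# model call, which keeps custom architectures with leaner signatures working.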