feat(server): support trust_remote_code (#363)

Authored by OlivierDehaene on 2023-05-23 20:40:39 +02:00, committed by GitHub
parent e9669a4085
commit e3e487dc71
17 changed files with 321 additions and 72 deletions

View File

@@ -213,13 +213,12 @@ jobs:
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install
run: |
pip install pytest-xdist
make install-integration-tests
- name: Run tests
run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv -n 2 --dist loadfile integration-tests
pytest -s -vv integration-tests
stop-runner:
name: Stop self-hosted EC2 runner

View File

@@ -53,7 +53,7 @@ struct Args {
#[clap(long, env)]
revision: Option<String>,
/// Wether to shard or not the model across multiple GPUs
/// Whether to shard the model across multiple GPUs
/// By default text-generation-inference will use all available GPUs to run
/// the model. Setting it to `false` deactivates `num_shard`.
#[clap(long, env)]
@@ -66,11 +66,17 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
#[clap(long, env, value_enum)]
trust_remote_code: bool,
/// The maximum amount of concurrent requests for this particular deployment.
/// Having a low limit will refuse clients requests instead of having them
/// wait for too long and is usually good to handle backpressure correctly.
@@ -239,6 +245,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
world_size: usize,
@@ -272,6 +279,11 @@ fn shard_manager(
"--json-output".to_string(),
];
// Activate trust remote code
if trust_remote_code {
shard_argv.push("--trust-remote-code".to_string());
}
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
@@ -692,6 +704,16 @@ fn spawn_shards(
status_sender: mpsc::Sender<ShardStatus>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
args.model_id
);
if args.revision.is_none() {
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
}
}
// Start shard processes
for rank in 0..num_shard {
let model_id = args.model_id.clone();
@@ -705,6 +727,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
@@ -714,6 +737,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
trust_remote_code,
uds_path,
rank,
num_shard,

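The launcher only forwards the new flag: when `trust_remote_code` is set it appends `--trust-remote-code` to every shard's argv and warns if no `revision` is pinned. A minimal sketch, assuming a `text-generation-launcher` binary on PATH, of invoking it the way the doc comment recommends (model id and revision hash are placeholders, not values from this commit):

import subprocess

subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "my-org/model-with-custom-code",  # placeholder model id
        "--revision", "0123456789abcdef",               # pin a reviewed revision
        "--trust-remote-code",                          # flag added in this change
    ],
    check=True,
)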
View File

@@ -22,6 +22,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
@@ -63,7 +64,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
server.serve(model_id, revision, sharded, quantize, uds_path)
server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
@app.command()

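The new `trust_remote_code: bool = False` parameter on the typer command surfaces as a `--trust-remote-code` CLI flag and is simply passed through to `server.serve`. A standalone sketch of that typer behaviour (toy command, not the project's cli.py):

import typer

app = typer.Typer()

@app.command()
def serve(model_id: str, trust_remote_code: bool = False):
    # typer renders the bool as --trust-remote-code / --no-trust-remote-code
    typer.echo(f"model={model_id} trust_remote_code={trust_remote_code}")

if __name__ == "__main__":
    app()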
View File

@@ -91,13 +91,27 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
trust_remote_code: bool,
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(model_id, revision, quantize=quantize)
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if sharded:
@@ -105,12 +119,24 @@ def get_model(
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
return santacoder_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(model_id, revision=revision)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config.model_type
if model_type == "gpt_bigcode":
@@ -119,52 +145,133 @@ def get_model(
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
return santacoder_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "bloom":
if sharded:
return BLOOMSharded(model_id, revision, quantize=quantize)
return BLOOMSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return BLOOM(model_id, revision, quantize=quantize)
return BLOOM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "gpt_neox":
if sharded:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
return neox_cls(model_id, revision, quantize=quantize)
return neox_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
return neox_cls(model_id, revision, quantize=quantize)
return neox_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "llama":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(model_id, revision, quantize=quantize)
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(model_id, revision, quantize=quantize)
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if config.model_type == "opt":
if sharded:
return OPTSharded(model_id, revision, quantize=quantize)
return OPTSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(model_id, revision, quantize=quantize)
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
return T5Sharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
return Seq2SeqLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if sharded:
raise ValueError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(model_id, revision, quantize=quantize)
return CausalLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(model_id, revision, quantize=quantize)
return Seq2SeqLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
auto_map = getattr(config, "auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")

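The final fallback is new: if none of the known architectures match but the config's `auto_map` declares a custom `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` class, the generic loaders are used, and only when the user opted into remote code. A minimal sketch of that dispatch idea (the `pick_loader` helper is illustrative, not part of the codebase):

from typing import Optional
from transformers import AutoConfig

def pick_loader(model_id: str, revision: Optional[str] = None, trust_remote_code: bool = False) -> str:
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    auto_map = getattr(config, "auto_map", None) or {}
    if trust_remote_code and "AutoModelForCausalLM" in auto_map:
        return "CausalLM"      # generic decoder-only path
    if trust_remote_code and "AutoModelForSeq2SeqLM" in auto_map:
        return "Seq2SeqLM"     # generic encoder-decoder path
    raise ValueError(f"Unsupported model type {config.model_type}")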
View File

@@ -54,9 +54,13 @@ class BLOOM(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
@@ -70,6 +74,7 @@ class BLOOMSharded(BLOOM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -80,11 +85,19 @@ class BLOOMSharded(BLOOM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, slow_but_exact=False, tp_parallel=True
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
@@ -92,7 +105,9 @@ class BLOOMSharded(BLOOM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

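The sharded loaders keep the same pattern: build the architecture on the meta device with accelerate's `init_empty_weights`, then load safetensors shards separately; `trust_remote_code` now also reaches `from_config`. A minimal sketch of the skeleton step only (example model id; weight loading elided):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # example model id
with init_empty_weights():
    # parameters are created on the "meta" device, so no memory is allocated
    # until the real weights are loaded shard by shard
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=False)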
View File

@@ -1,4 +1,5 @@
import torch
import inspect
from dataclasses import dataclass
from opentelemetry import trace
@@ -450,6 +451,7 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -462,22 +464,38 @@ class CausalLM(Model):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
tokenizer.pad_token_id = (
model.config.pad_token_id
if model.config.pad_token_id is not None
else model.config.eos_token_id
)
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
super(CausalLM, self).__init__(
@@ -501,14 +519,17 @@ class CausalLM(Model):
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
)
kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
}
if self.has_position_ids:
kwargs["position_ids"] = position_ids
outputs = self.model.forward(**kwargs)
return outputs.logits, outputs.past_key_values
@tracer.start_as_current_span("generate_token")

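Two behavioural changes ride along with the flag here: a pad-token fallback chain (model pad token, then model EOS, then tokenizer EOS, then a new `[PAD]` token), and a signature check so `position_ids` is only passed to models whose `forward` accepts it, since some remote-code models do not. A self-contained sketch of that check (toy module, not the real model):

import inspect
import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None):  # no position_ids parameter
        return input_ids

model = Toy()
has_position_ids = (
    inspect.signature(model.forward).parameters.get("position_ids", None) is not None
)

kwargs = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": None}
if has_position_ids:
    kwargs["position_ids"] = torch.arange(3)[None, :]
output = model(**kwargs)  # safe for either kind of forward signature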
View File

@@ -394,6 +394,7 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -402,13 +403,18 @@ class FlashCausalLM(Model):
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
super(FlashCausalLM, self).__init__(

View File

@@ -33,6 +33,7 @@ class FlashLlama(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -45,11 +46,11 @@ class FlashLlama(FlashCausalLM):
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# We do not use from_pretrained as we modified the model internal module layout
@@ -153,6 +154,7 @@ class FlashLlamaSharded(FlashLlama):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -166,11 +168,11 @@ class FlashLlamaSharded(FlashLlama):
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)

View File

@@ -28,9 +28,14 @@ class FlashNeoX(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM, model_id, revision, quantize
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
@@ -40,6 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -49,12 +55,15 @@ class FlashNeoXSharded(FlashNeoX):
raise NotImplementedError("FlashNeoX is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)

View File

@@ -32,6 +32,7 @@ class FlashSantacoder(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -40,7 +41,11 @@ class FlashSantacoder(FlashCausalLM):
raise NotImplementedError("FlashSantacoder is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
@@ -178,6 +183,7 @@ class FlashSantacoderSharded(FlashSantacoder):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -187,7 +193,11 @@ class FlashSantacoderSharded(FlashSantacoder):
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(

View File

@@ -199,6 +199,7 @@ class GalacticaSharded(Galactica):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -209,11 +210,18 @@ class GalacticaSharded(Galactica):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token_id = config.pad_token_id
@@ -221,7 +229,9 @@ class GalacticaSharded(Galactica):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@@ -36,6 +36,7 @@ class GPTNeoxSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -46,19 +47,28 @@ class GPTNeoxSharded(CausalLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@@ -52,6 +52,7 @@ class OPTSharded(OPT):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -62,11 +63,18 @@ class OPTSharded(OPT):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token_id = config.pad_token_id
@@ -74,7 +82,9 @@ class OPTSharded(OPT):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -31,7 +32,11 @@ class SantaCoder(CausalLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.add_special_tokens(
{
@@ -51,7 +56,7 @@ class SantaCoder(CausalLM):
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
trust_remote_code=trust_remote_code,
).to(device)
super(CausalLM, self).__init__(

View File

@@ -503,6 +503,7 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -518,14 +519,21 @@ class Seq2SeqLM(Model):
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = model.config.decoder_start_token_id

View File

@@ -36,6 +36,7 @@ class T5Sharded(Seq2SeqLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@@ -46,11 +47,18 @@ class T5Sharded(Seq2SeqLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
@@ -58,7 +66,9 @@ class T5Sharded(Seq2SeqLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@@ -101,6 +101,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
@@ -108,6 +109,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@@ -121,7 +123,7 @@ def serve(
server_urls = [local_url]
try:
model = get_model(model_id, revision, sharded, quantize)
model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
except Exception:
logger.exception("Error when initializing model")
raise
@@ -152,4 +154,4 @@ def serve(
logger.info("Signal received. Shutting down")
await server.stop(0)
asyncio.run(serve_inner(model_id, revision, sharded, quantize))
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
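The flag ends up in `get_model` via `serve_inner`. This entry point is normally reached through the typer CLI, but a direct call would look roughly like this sketch (argument values are placeholders, not taken from the commit):

from pathlib import Path

from text_generation_server import server

server.serve(
    model_id="gpt2",                      # placeholder model id
    revision=None,
    sharded=False,
    quantize=None,
    trust_remote_code=False,              # argument added by this commit
    uds_path=Path("/tmp/text-generation-server"),
)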