feat(server): support trust_remote_code (#363)
parent e9669a4085
commit e3e487dc71
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests

   stop-runner:
     name: Stop self-hosted EC2 runner

@@ -53,7 +53,7 @@ struct Args {
     #[clap(long, env)]
     revision: Option<String>,

-    /// Wether to shard or not the model across multiple GPUs
+    /// Whether to shard the model across multiple GPUs
     /// By default text-generation-inference will use all available GPUs to run
     /// the model. Setting it to `false` deactivates `num_shard`.
     #[clap(long, env)]

@@ -66,11 +66,17 @@ struct Args {
     #[clap(long, env)]
     num_shard: Option<usize>,

-    /// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
     /// quantization on the fly, or `gptq`.
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,

+    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
+    /// encouraged when loading a model with custom code to ensure no malicious code has been
+    /// contributed in a newer revision.
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+
     /// The maximum amount of concurrent requests for this particular deployment.
     /// Having a low limit will refuse clients requests instead of having them
     /// wait for too long and is usually good to handle backpressure correctly.

@@ -239,6 +245,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    trust_remote_code: bool,
     uds_path: String,
     rank: usize,
     world_size: usize,

@@ -272,6 +279,11 @@ fn shard_manager(
         "--json-output".to_string(),
     ];

+    // Activate trust remote code
+    if trust_remote_code {
+        shard_argv.push("--trust-remote-code".to_string());
+    }
+
     // Activate tensor parallelism
     if world_size > 1 {
         shard_argv.push("--sharded".to_string());

@@ -692,6 +704,16 @@ fn spawn_shards(
     status_sender: mpsc::Sender<ShardStatus>,
     running: Arc<AtomicBool>,
 ) -> Result<(), LauncherError> {
+    if args.trust_remote_code {
+        tracing::warn!(
+            "`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
+            args.model_id
+        );
+        if args.revision.is_none() {
+            tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
+        }
+    }
+
     // Start shard processes
     for rank in 0..num_shard {
         let model_id = args.model_id.clone();

@@ -705,6 +727,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let trust_remote_code = args.trust_remote_code;
         let master_port = args.master_port;
         let disable_custom_kernels = args.disable_custom_kernels;
         let watermark_gamma = args.watermark_gamma;

@@ -714,6 +737,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                trust_remote_code,
                 uds_path,
                 rank,
                 num_shard,

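On the launcher side, the new `trust_remote_code: bool` field surfaces as a `--trust-remote-code` switch (and, through clap's `env` attribute, most likely a `TRUST_REMOTE_CODE` environment variable), and is forwarded to every shard as `--trust-remote-code`. A minimal sketch of driving it from Python; the binary name is the one this repository ships, while the model id, revision and the use of `subprocess` are illustrative only:

    import subprocess

    subprocess.run(
        [
            "text-generation-launcher",
            "--model-id", "some-org/model-with-custom-code",  # placeholder repo id
            "--revision", "main",  # explicit revision, as the new warning recommends
            "--trust-remote-code",
        ],
        check=True,
    )
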
@@ -22,6 +22,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,

@@ -63,7 +64,7 @@ def serve(

     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)


 @app.command()

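`serve` hands the flag to `server.serve` positionally, so the argument order has to match the server-side signature shown further down. An equivalent keyword-style call, as a sketch (the import path and the `model_id` parameter name are assumed from the surrounding code, not shown in this hunk):

    from text_generation_server import server

    server.serve(
        model_id=model_id,
        revision=revision,
        sharded=sharded,
        quantize=quantize,
        trust_remote_code=trust_remote_code,
        uds_path=uds_path,
    )
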
@@ -91,13 +91,27 @@ torch.set_grad_enabled(False)


 def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    trust_remote_code: bool,
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
+            return GalacticaSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Galactica(model_id, revision, quantize=quantize)
+            return Galactica(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_id.startswith("bigcode/"):
         if sharded:

@@ -105,12 +119,24 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

-    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
     model_type = config.model_type

     if model_type == "gpt_bigcode":

@@ -119,52 +145,133 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_id, revision, quantize=quantize)
+            return BLOOMSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return BLOOM(model_id, revision, quantize=quantize)
+            return BLOOM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "gpt_neox":
         if sharded:
             neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
             neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "llama":
         if sharded:
             if FLASH_ATTENTION:
-                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+                return FlashLlamaSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
         else:
             llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
-            return llama_cls(model_id, revision, quantize=quantize)
+            return llama_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if config.model_type == "opt":
         if sharded:
-            return OPTSharded(model_id, revision, quantize=quantize)
+            return OPTSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return OPT(model_id, revision, quantize=quantize)
+            return OPT(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "t5":
         if sharded:
-            return T5Sharded(model_id, revision, quantize=quantize)
+            return T5Sharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
         else:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

     if sharded:
         raise ValueError("sharded is not supported for AutoModel")

     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(model_id, revision, quantize=quantize)
+        return CausalLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(model_id, revision, quantize=quantize)
+        return Seq2SeqLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )

+    auto_map = getattr(config, "auto_map", None)
+    if trust_remote_code and auto_map is not None:
+        if "AutoModelForCausalLM" in auto_map.keys():
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+
     raise ValueError(f"Unsupported model type {model_type}")

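Custom architectures reach the new fallback through the `auto_map` field that Hub repos with remote code expose on their config. A rough standalone sketch of that detection (the repo id is a placeholder, and the dict shape shown in the comment is the usual transformers convention rather than something defined in this diff):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "some-org/model-with-custom-code",  # placeholder repo id
        trust_remote_code=True,
    )
    # Repos that ship their own modeling file typically expose something like
    # {"AutoModelForCausalLM": "modeling_custom.CustomForCausalLM"} here.
    auto_map = getattr(config, "auto_map", None)
    if auto_map is not None and "AutoModelForCausalLM" in auto_map:
        print("get_model would route this repo through CausalLM")
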
@@ -54,9 +54,13 @@ class BLOOM(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
         )

     @property

@@ -70,6 +74,7 @@ class BLOOMSharded(BLOOM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -80,11 +85,19 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id,
+            revision=revision,
+            slow_but_exact=False,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         config.pad_token_id = 3

@@ -92,7 +105,9 @@ class BLOOMSharded(BLOOM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

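The sharded classes only build a weightless skeleton from the config and then load the real weights from safetensors shards, which is why `from_config` (not `from_pretrained`) is the call that needs the flag here. A self-contained sketch of that pattern, with an example model id and without the sharded weight loading:

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # example id
    with init_empty_weights():
        # Parameters are created on the meta device; nothing is allocated yet.
        model = AutoModelForCausalLM.from_config(config)
    print(next(model.parameters()).device)  # meta
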
@@ -1,4 +1,5 @@
 import torch
+import inspect

 from dataclasses import dataclass
 from opentelemetry import trace

@@ -450,6 +451,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -462,22 +464,38 @@ class CausalLM(Model):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
         )

         super(CausalLM, self).__init__(

@@ -501,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")

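The `has_position_ids` probe exists because models pulled in through remote code do not all accept `position_ids` in `forward`. A self-contained sketch of the same check against two toy signatures (both function names are made up for illustration):

    import inspect

    def forward_with_positions(input_ids, attention_mask=None, position_ids=None):
        ...

    def forward_without_positions(input_ids, attention_mask=None, past_key_values=None):
        ...

    def has_position_ids(fn) -> bool:
        # Mirrors the check above: is there a `position_ids` parameter at all?
        return inspect.signature(fn).parameters.get("position_ids", None) is not None

    print(has_position_ids(forward_with_positions))     # True
    print(has_position_ids(forward_without_positions))  # False
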
@@ -394,6 +394,7 @@ class FlashCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -402,13 +403,18 @@ class FlashCausalLM(Model):
             raise NotImplementedError("FlashCausalLM is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         model = model_cls.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         ).to(device)

         super(FlashCausalLM, self).__init__(

@@ -33,6 +33,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -45,11 +46,11 @@ class FlashLlama(FlashCausalLM):
             revision=revision,
             padding_side="left",
             truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         # We do not use from_pretrained as we modified the model internal module layout

@@ -153,6 +154,7 @@ class FlashLlamaSharded(FlashLlama):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -166,11 +168,11 @@ class FlashLlamaSharded(FlashLlama):
             revision=revision,
             padding_side="left",
             truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         torch.distributed.barrier(group=self.process_group)

@@ -28,9 +28,14 @@ class FlashNeoX(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         super(FlashNeoX, self).__init__(
-            FlashGPTNeoXForCausalLM, model_id, revision, quantize
+            FlashGPTNeoXForCausalLM,
+            model_id,
+            revision,
+            quantize,
+            trust_remote_code=trust_remote_code,
         )


@@ -40,6 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -49,12 +55,15 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
         )

         torch.distributed.barrier(group=self.process_group)

@@ -32,6 +32,7 @@ class FlashSantacoder(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -40,7 +41,11 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = GPT2Config.from_pretrained(

@@ -178,6 +183,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -187,7 +193,11 @@ class FlashSantacoderSharded(FlashSantacoder):
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = GPT2Config.from_pretrained(

@@ -199,6 +199,7 @@ class GalacticaSharded(Galactica):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -209,11 +210,18 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token_id = config.pad_token_id

@@ -221,7 +229,9 @@ class GalacticaSharded(Galactica):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

@@ -36,6 +36,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -46,19 +47,28 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token = tokenizer.eos_token

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

@@ -52,6 +52,7 @@ class OPTSharded(OPT):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -62,11 +63,18 @@ class OPTSharded(OPT):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.pad_token_id = config.pad_token_id

@@ -74,7 +82,9 @@ class OPTSharded(OPT):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -31,7 +32,11 @@ class SantaCoder(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.add_special_tokens(
             {

@@ -51,7 +56,7 @@ class SantaCoder(CausalLM):
             revision=revision,
             torch_dtype=dtype,
             load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=True,  # required
+            trust_remote_code=trust_remote_code,
         ).to(device)

         super(CausalLM, self).__init__(

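SantaCoder previously hard-coded `trust_remote_code=True` because those checkpoints ship their own modeling code; after this change the caller has to opt in explicitly. For comparison, a hypothetical direct load outside the server (the checkpoint name is only an example, and whether it still needs the flag depends on the checkpoint):

    from transformers import AutoModelForCausalLM

    # Checkpoints that rely on custom modeling files are expected to fail with a
    # message asking for trust_remote_code=True when the flag is left off.
    model = AutoModelForCausalLM.from_pretrained(
        "bigcode/santacoder",  # example checkpoint
        trust_remote_code=True,
    )
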
@@ -503,6 +503,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -518,14 +519,21 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.bos_token_id = model.config.decoder_start_token_id

@@ -36,6 +36,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -46,11 +47,18 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
         )

         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
         )
         tokenizer.bos_token_id = config.decoder_start_token_id

@@ -58,7 +66,9 @@ class T5Sharded(Seq2SeqLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

         with init_empty_weights():
-            model = AutoModelForSeq2SeqLM.from_config(config)
+            model = AutoModelForSeq2SeqLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

         torch.distributed.barrier(group=self.process_group)
         self.load_weights(

@@ -101,6 +101,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    trust_remote_code: bool,
     uds_path: Path,
 ):
     async def serve_inner(

@@ -108,6 +109,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:

@@ -121,7 +123,7 @@ def serve(
             server_urls = [local_url]

         try:
-            model = get_model(model_id, revision, sharded, quantize)
+            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
         except Exception:
             logger.exception("Error when initializing model")
             raise

@@ -152,4 +154,4 @@ def serve(
             logger.info("Signal received. Shutting down")
             await server.stop(0)

-    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
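Because `trust_remote_code` has no default in `get_model`, the signature change and its call sites have to move together, and a missed call site fails immediately. A toy reproduction of that failure mode (the stub only mirrors the shape of the real function):

    from typing import Optional

    def get_model_stub(
        model_id: str,
        revision: Optional[str],
        sharded: bool,
        quantize: Optional[str],
        trust_remote_code: bool,
    ):
        return model_id, trust_remote_code

    get_model_stub("some/model", None, False, None, False)  # ok
    # get_model_stub("some/model", None, False, None)       # TypeError: missing 'trust_remote_code'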