feat: Add the option to force a dtype other than `f16`. (#513)

parent 3b0c979efc
commit ecf6dc3a5a
@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Dtype {
+    Float16,
+    BFloat16,
+}
+
+impl std::fmt::Display for Dtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in sync with `server`.
+        match self {
+            Dtype::Float16 => {
+                write!(f, "float16")
+            }
+            Dtype::BFloat16 => {
+                write!(f, "bfloat16")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
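(The `Display` strings above, "float16" and "bfloat16", are the wire format the launcher sends to each shard; they have to match the values of the server-side `Dtype` enum introduced further down in this diff, which is what the comment about keeping in sync with `server` refers to.)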
@@ -71,6 +91,10 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
+    #[clap(long, env, value_enum)]
+    dtype: Option<Dtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
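With the new flag, a launch forcing bfloat16 would look something like `text-generation-launcher --model-id <model> --dtype bfloat16` (hypothetical invocation for illustration); because of clap's `env` attribute, the same should also be achievable by setting a `DTYPE` environment variable.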
@@ -258,6 +282,7 @@ fn shard_manager(
     model_id: String,
     revision: Option<String>,
     quantize: Option<Quantization>,
+    dtype: Option<Dtype>,
     trust_remote_code: bool,
     uds_path: String,
     rank: usize,
@@ -307,6 +332,11 @@ fn shard_manager(
         shard_argv.push(quantize.to_string())
     }
 
+    if let Some(dtype) = dtype {
+        shard_argv.push("--dtype".to_string());
+        shard_argv.push(dtype.to_string())
+    }
+
     // Model optional revision
     if let Some(revision) = revision {
         shard_argv.push("--revision".to_string());
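Note that the launcher only forwards the option to each shard's argv here; it does not validate it. The mutual-exclusion check against `--quantize` happens server-side, in the `serve` command below.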
@@ -743,6 +773,7 @@ fn spawn_shards(
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
         let quantize = args.quantize;
+        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
         let master_port = args.master_port;
         let disable_custom_kernels = args.disable_custom_kernels;
@@ -753,6 +784,7 @@ fn spawn_shards(
                 model_id,
                 revision,
                 quantize,
+                dtype,
                 trust_remote_code,
                 uds_path,
                 rank,
@@ -16,12 +16,18 @@ class Quantization(str, Enum):
     gptq = "gptq"
 
 
+class Dtype(str, Enum):
+    float16 = "float16"
+    bfloat16 = "bfloat16"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -64,7 +70,14 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
+    dtype = None if dtype is None else dtype.value
+    if dtype is not None and quantize is not None:
+        raise RuntimeError(
+            "Only one of `dtype` and `quantize` can be set, as they both determine the dtype of the final model."
+        )
+    server.serve(
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+    )
 
 
 @app.command()
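A minimal standalone sketch of what the new `serve` logic above does with the option (enum inlined here for illustration): typer yields a `Dtype` member, which is "downgraded" to its plain string value, and combining it with `quantize` is rejected up front.

from enum import Enum
from typing import Optional

class Dtype(str, Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"

dtype: Optional[Dtype] = Dtype.bfloat16
quantize: Optional[str] = None

# Downgrade enum into str, as serve() does above
dtype = None if dtype is None else dtype.value
assert dtype == "bfloat16"

# The new guard: forcing a dtype and quantizing are mutually exclusive
if dtype is not None and quantize is not None:
    raise RuntimeError("Only one of `dtype` and `quantize` can be set")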
@@ -100,11 +100,25 @@ def get_model(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if dtype is None:
+        dtype = torch.float16
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     if model_id.startswith("bigcode/"):
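The `if`/`elif` chain above is the whole string-to-`torch.dtype` contract: `None` defaults to float16, only "float16" and "bfloat16" are accepted, and anything else raises. A behaviour-equivalent sketch (hypothetical helper, not part of the commit):

from typing import Optional

import torch

def resolve_dtype(dtype: Optional[str]) -> torch.dtype:
    # None defaults to float16; only "float16"/"bfloat16" are accepted
    if dtype is None or dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    raise RuntimeError(f"Unknown dtype {dtype}")

assert resolve_dtype(None) is torch.float16
assert resolve_dtype("bfloat16") is torch.bfloat16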
@@ -113,6 +127,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -124,6 +139,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -138,6 +154,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -149,12 +166,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     if model_type == "bloom":
         return BLOOMSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "gpt_neox":
@@ -163,6 +185,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -170,6 +193,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -177,6 +201,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -186,6 +211,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -195,6 +221,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -210,6 +237,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     raise NotImplementedError(
@@ -221,6 +249,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -228,12 +257,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "opt":
         return OPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "t5":
@@ -241,6 +275,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -253,11 +288,19 @@ def get_model(
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     auto_map = config_dict.get("auto_map", None)
@@ -267,6 +310,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     if "AutoModelForSeq2SeqLM" in auto_map.keys():
@@ -274,6 +318,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
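Every model constructor touched below follows the same pattern as `BLOOMSharded` here: the forced dtype only takes effect on the GPU path, where it replaces the hard-coded `torch.float16` default; on CPU the dtype is still overwritten with `torch.float32` (or construction fails outright for the Flash models), so `--dtype` is effectively ignored there. A sketch of the pattern (hypothetical helper for illustration):

from typing import Optional

import torch

def pick_dtype(forced: Optional[torch.dtype]) -> torch.dtype:
    if torch.cuda.is_available():
        # GPU: honour the forced dtype, defaulting to float16
        return torch.float16 if forced is None else forced
    # CPU: the constructors above fall back to float32 unconditionally
    return torch.float32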
@@ -454,11 +454,12 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")

@@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 

@@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
@@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 

@@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 

@@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32

@@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32

@@ -12,11 +12,12 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")

@@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")

@@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -106,6 +106,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
 ):
@@ -114,6 +115,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -128,7 +130,9 @@ def serve(
             server_urls = [local_url]
 
         try:
-            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
+            model = get_model(
+                model_id, revision, sharded, quantize, dtype, trust_remote_code
+            )
         except Exception:
             logger.exception("Error when initializing model")
             raise
@@ -159,4 +163,6 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)
 
-    asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
+    asyncio.run(
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+    )
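Taken together, the plumbing is: the launcher forwards `--dtype <value>` verbatim to each shard; the typer CLI parses it into the `Dtype` enum, rejects the combination with `--quantize`, and downgrades it to a plain string; `server.serve` threads that string through `serve_inner` into `get_model`, which converts it to a `torch.dtype` and hands it to every model constructor, where it overrides the GPU default of `torch.float16`.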