feat: Add the option to force a dtype other than `f16`. (#513)

Nicolas Patry 2023-06-30 20:30:09 +02:00 committed by GitHub
parent 3b0c979efc
commit ecf6dc3a5a
17 changed files with 130 additions and 21 deletions
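
In short: the launcher gains a `--dtype` flag (`float16` or `bfloat16`) that is forwarded to each shard as `--dtype`, downgraded to a plain string in the Python CLI, mapped to a `torch.dtype` in `get_model`, and then passed to every model constructor; `--dtype` and `--quantize` are mutually exclusive. A minimal Python sketch of that resolution logic follows — `resolve_dtype` is a hypothetical helper name, the real logic lives in `cli.py` and `get_model` below:

```python
# Sketch only: mirrors the dtype handling introduced by this commit.
from typing import Optional

import torch


def resolve_dtype(dtype: Optional[str], quantize: Optional[str]) -> torch.dtype:
    # `--dtype` and `--quantize` cannot be combined: both decide the final
    # dtype of the loaded weights.
    if dtype is not None and quantize is not None:
        raise RuntimeError("Only one of `dtype` and `quantize` can be set.")
    if dtype is None or dtype == "float16":
        return torch.float16  # previous hard-coded default
    if dtype == "bfloat16":
        return torch.bfloat16  # newly supported override
    raise RuntimeError(f"Unknown dtype {dtype}")
```
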

View File

@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
Float16,
BFloat16,
}
impl std::fmt::Display for Dtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Keep in sync with `server`.
match self {
Dtype::Float16 => {
write!(f, "float16")
}
Dtype::BFloat16 => {
write!(f, "bfloat16")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -71,6 +91,10 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@ -258,6 +282,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@ -307,6 +332,11 @@ fn shard_manager(
shard_argv.push(quantize.to_string())
}
if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string());
shard_argv.push(dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
@ -743,6 +773,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@ -753,6 +784,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
dtype,
trust_remote_code,
uds_path,
rank,

View File

@ -16,12 +16,18 @@ class Quantization(str, Enum):
gptq = "gptq"
class Dtype(str, Enum):
float16 = "float16"
bfloat16 = "bfloat16"
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@ -64,7 +70,14 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
dtype = None if dtype is None else dtype.value
if dtype is not None and quantize is not None:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
)
@app.command()

View File

@ -100,11 +100,25 @@ def get_model(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
if dtype is None:
dtype = torch.float16
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":
dtype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
@ -113,6 +127,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@ -124,6 +139,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -138,6 +154,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@ -149,12 +166,17 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "bloom":
return BLOOMSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox":
@ -163,6 +185,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@ -170,6 +193,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
@ -177,6 +201,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -186,6 +211,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@ -195,6 +221,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -210,6 +237,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(
@ -221,6 +249,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
@ -228,12 +257,17 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "opt":
return OPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "t5":
@ -241,6 +275,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@ -253,11 +288,19 @@ def get_model(
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
auto_map = config_dict.get("auto_map", None)
@ -267,6 +310,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
@ -274,6 +318,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

View File

@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32
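
Every model constructor changed below follows the same pattern as this BLOOM change: on GPU the forced dtype is honored and falls back to `float16` when unset, while the CPU path keeps `float32` (the flash/GPU-only models instead raise `NotImplementedError` on CPU). A small illustrative sketch of that pattern, with a hypothetical helper name not taken from the repository:

```python
# Illustrative helper; mirrors the device/dtype selection each model
# class performs in this commit.
from typing import Optional, Tuple

import torch


def pick_device_and_dtype(dtype: Optional[torch.dtype]) -> Tuple[torch.device, torch.dtype]:
    if torch.cuda.is_available():
        # Honor the forced dtype on GPU, defaulting to float16 as before.
        return torch.device("cuda"), torch.float16 if dtype is None else dtype
    # CPU path is unchanged: weights still load in float32.
    return torch.device("cpu"), torch.float32
```
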

View File

@ -454,11 +454,12 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")

View File

@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")

View File

@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")

View File

@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

View File

@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -12,11 +12,12 @@ class RW(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")

View File

@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@ -106,6 +106,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
@ -114,6 +115,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@ -128,7 +130,9 @@ def serve(
server_urls = [local_url]
try:
model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
raise
@ -159,4 +163,6 @@ def serve(
logger.info("Signal received. Shutting down")
await server.stop(0)
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
)
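
For completeness, a hedged sketch of exercising the new parameter through the Python entry point directly. The import path and argument order mirror this diff; the model id is an arbitrary example, and the call blocks while serving on the given UNIX socket path:

```python
# Hypothetical invocation; argument order follows the updated `serve`
# signature (model_id, revision, sharded, quantize, dtype, ...).
from text_generation_server import server

server.serve(
    model_id="bigscience/bloom-560m",  # example model id, not from the diff
    revision=None,
    sharded=False,
    quantize=None,
    dtype="bfloat16",  # forces torch.bfloat16 in get_model
    trust_remote_code=False,
    uds_path="/tmp/text-generation-server",
)
```
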