feat: Add the option to force a dtype other than `f16`. (#513)

Nicolas Patry 2023-06-30 20:30:09 +02:00 committed by GitHub
parent 3b0c979efc
commit ecf6dc3a5a
17 changed files with 130 additions and 21 deletions


@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
Float16,
BFloat16,
}
impl std::fmt::Display for Dtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Keep in sync with `server`.
match self {
Dtype::Float16 => {
write!(f, "float16")
}
Dtype::BFloat16 => {
write!(f, "bfloat16")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@@ -71,6 +91,10 @@ struct Args {
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
@@ -258,6 +282,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
@@ -307,6 +332,11 @@ fn shard_manager(
shard_argv.push(quantize.to_string())
}
if let Some(dtype) = dtype {
shard_argv.push("--dtype".to_string());
shard_argv.push(dtype.to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
@@ -743,6 +773,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
@@ -753,6 +784,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
dtype,
trust_remote_code,
uds_path,
rank,

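Usage sketch for the launcher change above (not part of the diff; the `text-generation-launcher` entry point and the model id are illustrative): the new `--dtype` flag accepts `float16` or `bfloat16`, can also be picked up from the environment via clap's `env` attribute, and is forwarded to each shard as `--dtype <value>`, e.g. `text-generation-launcher --model-id bigscience/bloom-560m --dtype bfloat16`. As the doc comment notes, it cannot be combined with `--quantize`.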

@@ -16,12 +16,18 @@ class Quantization(str, Enum):
gptq = "gptq"
class Dtype(str, Enum):
float16 = "float16"
bfloat16 = "bfloat16"
@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
@@ -64,7 +70,14 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
if dtype is not None and quantize is not None:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
)
@app.command()

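Matching usage sketch for the Python CLI above (again illustrative, assuming the standard `text-generation-server` entry point): `text-generation-server serve bigscience/bloom-560m --dtype bfloat16`. The typer enum is downgraded to a plain string before reaching `server.serve`, and supplying both `--dtype` and `--quantize` raises a RuntimeError, which keeps the launcher and server argument handling in sync.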

@@ -100,11 +100,25 @@ def get_model(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
if dtype is None:
dtype = torch.float16
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":
dtype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
return GalacticaSharded( return GalacticaSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code model_id,
revision,
quantize=quantize,
dtype=dtype,
dtypetrust_remote_code=trust_remote_code,
) )
if model_id.startswith("bigcode/"): if model_id.startswith("bigcode/"):
@@ -113,6 +127,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -124,6 +139,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -138,6 +154,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -149,12 +166,17 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "bloom":
return BLOOMSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox":
@@ -163,6 +185,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -170,6 +193,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
@@ -177,6 +201,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -186,6 +211,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
@@ -195,6 +221,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -210,6 +237,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(
@@ -221,6 +249,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
@@ -228,12 +257,17 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "opt":
return OPTSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "t5":
@@ -241,6 +275,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -253,11 +288,19 @@ def get_model(
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
auto_map = config_dict.get("auto_map", None)
@@ -267,6 +310,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
@@ -274,6 +318,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)


@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32


@@ -454,11 +454,12 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")


@@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")


@@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashNeoX is only available on GPU")


@@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashRW is only available on GPU")


@@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")


@@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32

View File

@@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32


@@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32


@@ -12,11 +12,12 @@ class RW(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")


@@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")


@@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")


@@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32


@@ -106,6 +106,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
@@ -114,6 +115,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
@@ -128,7 +130,9 @@ def serve(
server_urls = [local_url]
try:
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
)
except Exception:
logger.exception("Error when initializing model")
raise
@@ -159,4 +163,6 @@ def serve(
logger.info("Signal received. Shutting down")
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
)