feat(server): improve flash attention import errors (#465)
@lewtun, is this enough? Closes #458 Closes #456
parent f59fb8b630
commit ece7ffa40a
@@ -18,11 +18,43 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+__all__ = [
+    "Model",
+    "BLOOMSharded",
+    "CausalLM",
+    "FlashCausalLM",
+    "GalacticaSharded",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "OPTSharded",
+    "T5Sharded",
+    "get_model",
+]
+
+FLASH_ATT_ERROR_MESSAGE = (
+    "{} requires CUDA and Flash Attention kernels to be installed.\n"
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+    "or install flash attention with `cd server && make install install-flash-attention`"
+)
+
 try:
-    if (
-        torch.cuda.is_available()
-        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
-    ):
+    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+        if not torch.cuda.is_available():
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires CUDA. No compatible CUDA devices found."
+            )
+            raise ImportError("CUDA is not available")
+
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0
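The `{}` placeholder in `FLASH_ATT_ERROR_MESSAGE` is meant to be filled with the name of whatever needed the kernels, so the same template now carries increasingly specific reasons (kernels not installed, no CUDA device, unsupported device). The diff does not show the call sites; the snippet below is only a sketch of how such a template is typically consumed, and `load_flash_model` plus the model name are illustrative, not part of this commit.

# Sketch only: formatting the error template at a hypothetical call site.
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires CUDA and Flash Attention kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

FLASH_ATTENTION = False  # assume the import-time probe above failed


def load_flash_model(model_cls_name: str):
    # Hypothetical call site: refuse to build a flash-only model and surface
    # the reason that was recorded at import time.
    if not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_cls_name))


try:
    load_flash_model("FlashLlama")
except NotImplementedError as err:
    print(err)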
@@ -30,6 +62,10 @@ try:
 
         supported = is_sm75 or is_sm8x or is_sm90
         if not supported:
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+                "No compatible CUDA device found."
+            )
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported"
             )
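`torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so the predicate above accepts, for example, T4 (7, 5), A100 (8, 0), A10 (8, 6) and H100 (9, 0), while an older card such as V100 (7, 0) now gets the explicit capability message instead of a bare ImportError. Below is a small sketch of the same predicate on hard-coded tuples so it runs without a GPU; `is_supported` is an illustrative helper name, and `is_sm90` is reconstructed from the pattern of the lines above.

# Sketch: the capability predicate from the hunk above, applied to
# hard-coded (major, minor) tuples instead of a live CUDA device.
def is_supported(major: int, minor: int) -> bool:
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    return is_sm75 or is_sm8x or is_sm90


for capability in [(7, 0), (7, 5), (8, 0), (8, 6), (9, 0)]:
    print(capability, is_supported(*capability))
# (7, 0) False  e.g. V100: rejected with the capability error message
# (7, 5) True   e.g. T4
# (8, 0) True   e.g. A100
# (8, 6) True   e.g. A10
# (9, 0) True   e.g. H100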
@@ -52,41 +88,12 @@ except ImportError:
     )
     FLASH_ATTENTION = False
 
-__all__ = [
-    "Model",
-    "BLOOMSharded",
-    "CausalLM",
-    "FlashCausalLM",
-    "GalacticaSharded",
-    "Seq2SeqLM",
-    "SantaCoder",
-    "OPTSharded",
-    "T5Sharded",
-    "get_model",
-]
-
 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires Flash Attention CUDA kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
-
-# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-# in PyTorch 1.12 and later.
-torch.backends.cuda.matmul.allow_tf32 = True
-
-# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-torch.backends.cudnn.allow_tf32 = True
-
-# Disable gradients
-torch.set_grad_enabled(False)
-
 
 def get_model(
     model_id: str,
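Apart from the moved lines, the structure stays the same: the unconditional `__all__` is now defined before the probe, and the flash classes are still appended only when the probe succeeds. A minimal, self-contained sketch of that import-time pattern follows; `optional_feature` and the exported names are made up for illustration.

# Sketch of the feature-flag pattern above: define the unconditional exports,
# probe the optional dependency, then extend __all__ only if the probe worked.
__all__ = ["always_available"]

try:
    import optional_feature  # hypothetical optional dependency
    HAS_OPTIONAL = True
except ImportError:
    HAS_OPTIONAL = False

if HAS_OPTIONAL:
    __all__.append("optional_helper")


def always_available():
    return "works without the optional dependency"


print(__all__)  # ['always_available'] when optional_feature is absent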
@@ -16,9 +16,9 @@ def check_file_size(source_file: Path, target_file: Path):
     source_file_size = source_file.stat().st_size
     target_file_size = target_file.stat().st_size
 
-    if (source_file_size - target_file_size) / source_file_size > 0.01:
+    if (source_file_size - target_file_size) / source_file_size > 0.05:
         raise RuntimeError(
-            f"""The file size different is more than 1%:
+            f"""The file size different is more than 5%:
             - {source_file}: {source_file_size}
             - {target_file}: {target_file_size}
         """
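The only change in `check_file_size` is the tolerance, which goes from 1% to 5% of the source size. Below is a standalone sketch of the same relative-size check with the new bound, using hard-coded byte counts so it runs without real files; `check_size_delta` is an illustrative name, not the function from the diff.

# Sketch: the relative-size check with the 5% tolerance, applied to
# hard-coded byte counts instead of Path.stat() results.
def check_size_delta(source_file_size: int, target_file_size: int) -> None:
    if (source_file_size - target_file_size) / source_file_size > 0.05:
        raise RuntimeError(
            f"The file size difference is more than 5%: "
            f"{source_file_size} vs {target_file_size}"
        )


check_size_delta(10_000, 9_700)  # 3% smaller: passes now, would have failed at 1%
try:
    check_size_delta(10_000, 9_300)  # 7% smaller: still rejected
except RuntimeError as err:
    print(err)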
@@ -26,7 +26,10 @@ def weight_hub_files(
     filenames = [
         s.rfilename
         for s in info.siblings
-        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+        if s.rfilename.endswith(extension)
+        and len(s.rfilename.split("/")) == 1
+        and "arguments" not in s.rfilename
+        and "args" not in s.rfilename
     ]
 
     if not filenames:
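The extra conditions skip auxiliary checkpoint files whose names contain `args` or `arguments` (for example `training_args.bin`), so they are no longer treated as model weights. A standalone sketch of the same filter over a plain list of filenames follows; the sample repo listing is illustrative, not taken from a real model.

# Sketch: the filename filter applied to a plain list of strings.
extension = ".bin"
siblings = [
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
    "training_args.bin",         # filtered: contains "args"
    "deepspeed_arguments.bin",   # filtered: contains "arguments" (hypothetical name)
    "subdir/pytorch_model.bin",  # filtered: nested path
    "config.json",               # filtered: wrong extension
]

filenames = [
    name
    for name in siblings
    if name.endswith(extension)
    and len(name.split("/")) == 1
    and "arguments" not in name
    and "args" not in name
]

print(filenames)  # ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']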