feat(server): improve flash attention import errors (#465)

@lewtun, is this enough?

Closes #458
Closes #456
Author: OlivierDehaene (committed by GitHub)
Date: 2023-06-19 09:53:45 +02:00
Parent: f59fb8b630
Commit: ece7ffa40a
3 changed files with 46 additions and 36 deletions


@@ -18,11 +18,43 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+__all__ = [
+    "Model",
+    "BLOOMSharded",
+    "CausalLM",
+    "FlashCausalLM",
+    "GalacticaSharded",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "OPTSharded",
+    "T5Sharded",
+    "get_model",
+]
+
+FLASH_ATT_ERROR_MESSAGE = (
+    "{} requires CUDA and Flash Attention kernels to be installed.\n"
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+    "or install flash attention with `cd server && make install install-flash-attention`"
+)
+
 try:
-    if (
-        torch.cuda.is_available()
-        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
-    ):
+    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+        if not torch.cuda.is_available():
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires CUDA. No compatible CUDA devices found."
+            )
+            raise ImportError("CUDA is not available")
+
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0
@@ -30,6 +62,10 @@ try:
 
         supported = is_sm75 or is_sm8x or is_sm90
         if not supported:
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+                "No compatible CUDA device found."
+            )
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported"
             )
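
Because each failure path now rebinds FLASH_ATT_ERROR_MESSAGE before raising, whatever later formats the template with a model name reports the actual cause instead of the generic install hint. A rough sketch of the user-facing result for the unsupported-GPU case (the FlashLlama name and the print call are illustrative, not part of this diff):

# Illustrative only: the template rebound in the hunk above, formatted the way
# a caller would fill it in with the requested model's name.
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
    "No compatible CUDA device found."
)
print(FLASH_ATT_ERROR_MESSAGE.format("FlashLlama"))
# -> FlashLlama requires a CUDA device with capability 7.5, > 8.0 or 9.0. No compatible CUDA device found.
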
@@ -52,41 +88,12 @@ except ImportError:
     )
     FLASH_ATTENTION = False
 
-__all__ = [
-    "Model",
-    "BLOOMSharded",
-    "CausalLM",
-    "FlashCausalLM",
-    "GalacticaSharded",
-    "Seq2SeqLM",
-    "SantaCoder",
-    "OPTSharded",
-    "T5Sharded",
-    "get_model",
-]
-
 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires Flash Attention CUDA kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
-
-# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-# in PyTorch 1.12 and later.
-torch.backends.cuda.matmul.allow_tf32 = True
-
-# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-torch.backends.cudnn.allow_tf32 = True
-
-# Disable gradients
-torch.set_grad_enabled(False)
-
 
 def get_model(
     model_id: str,
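
The whole check can also be skipped with the USE_FLASH_ATTENTION variable read at the top of the try block; only the exact string "false", case-insensitively, opts out, and since the check runs at module import the variable has to be set in the environment before the server imports its models package. A small standalone sketch of that comparison (the loop and the sample values are illustrative):

import os

# Only the exact string "false" (any capitalisation) opts out of the check
# above; unset, empty, "0" and "no" all leave flash attention enabled.
for value in ("false", "False", "FALSE"):
    os.environ["USE_FLASH_ATTENTION"] = value
    assert os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"

for value in ("", "0", "no", "true"):
    os.environ["USE_FLASH_ATTENTION"] = value
    assert not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
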


@@ -16,9 +16,9 @@ def check_file_size(source_file: Path, target_file: Path):
     source_file_size = source_file.stat().st_size
     target_file_size = target_file.stat().st_size
 
-    if (source_file_size - target_file_size) / source_file_size > 0.01:
+    if (source_file_size - target_file_size) / source_file_size > 0.05:
         raise RuntimeError(
-            f"""The file size different is more than 1%:
+            f"""The file size different is more than 5%:
             - {source_file}: {source_file_size}
             - {target_file}: {target_file_size}
         """

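The conversion check above now tolerates up to a 5% shrink between the source file and the converted file instead of 1%. A standalone sketch of the same arithmetic with made-up sizes:

# Made-up sizes: a 4% shrink now passes, where the previous 1% threshold would
# have raised a RuntimeError.
source_file_size = 1_000_000
target_file_size = 960_000

shrink = (source_file_size - target_file_size) / source_file_size
assert shrink == 0.04       # 40,000 bytes smaller out of 1,000,000
assert not shrink > 0.05    # within the new 5% tolerance
assert shrink > 0.01        # would have tripped the old 1% check
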

@@ -26,7 +26,10 @@ def weight_hub_files(
     filenames = [
         s.rfilename
         for s in info.siblings
-        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+        if s.rfilename.endswith(extension)
+        and len(s.rfilename.split("/")) == 1
+        and "arguments" not in s.rfilename
+        and "args" not in s.rfilename
     ]
 
     if not filenames:
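
The extra conditions exclude hub files whose names contain "arguments" or "args" (for example a training_args.bin that shares the weights' extension) while still keeping only top-level files. A standalone sketch of the same filter over a hypothetical file list:

# Hypothetical repository file list; only model.bin should survive the filter.
filenames = ["model.bin", "training_args.bin", "tokenizer/merges.txt"]
extension = ".bin"

kept = [
    f
    for f in filenames
    if f.endswith(extension)
    and len(f.split("/")) == 1
    and "arguments" not in f
    and "args" not in f
]
assert kept == ["model.bin"]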