feat(server): improve flash attention import errors (#465)
@lewtun, is this enough? Closes #458 Closes #456
parent f59fb8b630
commit ece7ffa40a
@@ -18,11 +18,43 @@ from text_generation_server.models.santacoder import SantaCoder
 from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+__all__ = [
+    "Model",
+    "BLOOMSharded",
+    "CausalLM",
+    "FlashCausalLM",
+    "GalacticaSharded",
+    "Seq2SeqLM",
+    "SantaCoder",
+    "OPTSharded",
+    "T5Sharded",
+    "get_model",
+]
+
+FLASH_ATT_ERROR_MESSAGE = (
+    "{} requires CUDA and Flash Attention kernels to be installed.\n"
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+    "or install flash attention with `cd server && make install install-flash-attention`"
+)
+
 try:
-    if (
-        torch.cuda.is_available()
-        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
-    ):
+    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+        if not torch.cuda.is_available():
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires CUDA. No compatible CUDA devices found."
+            )
+            raise ImportError("CUDA is not available")
+
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0

@@ -30,6 +62,10 @@ try:
 
         supported = is_sm75 or is_sm8x or is_sm90
         if not supported:
+            FLASH_ATT_ERROR_MESSAGE = (
+                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
+                "No compatible CUDA device found."
+            )
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported"
             )
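Taken together, the two hunks above (in the models package `__init__`, judging by the imports) make the Flash Attention probe rewrite `FLASH_ATT_ERROR_MESSAGE` with the specific reason it failed instead of always blaming a missing kernel install. A condensed, standalone sketch of that selection logic, for reference only; the real module also imports the flash model classes inside the `try` and logs a warning in the `except` block, both elided here:

import os
import torch

# Default message, overwritten with a more specific one when the probe below
# identifies the concrete reason Flash Attention cannot be used.
FLASH_ATT_ERROR_MESSAGE = (
    "{} requires CUDA and Flash Attention kernels to be installed.\n"
    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
    "or install flash attention with `cd server && make install install-flash-attention`"
)

FLASH_ATTENTION = False

try:
    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
        if not torch.cuda.is_available():
            FLASH_ATT_ERROR_MESSAGE = "{} requires CUDA. No compatible CUDA devices found."
            raise ImportError("CUDA is not available")

        major, minor = torch.cuda.get_device_capability()
        is_sm75 = major == 7 and minor == 5
        is_sm8x = major == 8 and minor >= 0
        is_sm90 = major == 9 and minor == 0

        if not (is_sm75 or is_sm8x or is_sm90):
            FLASH_ATT_ERROR_MESSAGE = (
                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
                "No compatible CUDA device found."
            )
            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")

        # In the real module the flash-attention model classes are imported here;
        # an ImportError from those imports keeps the generic kernel-install message.
        FLASH_ATTENTION = True
except ImportError:
    # FLASH_ATTENTION stays False; FLASH_ATT_ERROR_MESSAGE now explains why.
    # (The real module logs a warning at this point.)
    pass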
@@ -52,41 +88,12 @@ except ImportError:
     )
     FLASH_ATTENTION = False
 
-__all__ = [
-    "Model",
-    "BLOOMSharded",
-    "CausalLM",
-    "FlashCausalLM",
-    "GalacticaSharded",
-    "Seq2SeqLM",
-    "SantaCoder",
-    "OPTSharded",
-    "T5Sharded",
-    "get_model",
-]
-
 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires Flash Attention CUDA kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
-
-# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-# in PyTorch 1.12 and later.
-torch.backends.cuda.matmul.allow_tf32 = True
-
-# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-torch.backends.cudnn.allow_tf32 = True
-
-# Disable gradients
-torch.set_grad_enabled(False)
-
 
 def get_model(
     model_id: str,
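This hunk moves `__all__`, the TF32/gradient settings, and `FLASH_ATT_ERROR_MESSAGE` ahead of the `try` block rather than after it, so the message is defined (and already specialized) even when the probe bails out early. The diff does not show the call sites, but since the message is a format template, a consumer would use it roughly like this; the helper below is hypothetical and only illustrates the intended usage:

def require_flash_attention(model_name: str) -> None:
    # Hypothetical helper, not part of this diff: raise a readable error when a
    # flash-attention-only model is requested but the kernels are unusable.
    # FLASH_ATTENTION and FLASH_ATT_ERROR_MESSAGE are the module globals set by
    # the probe sketched earlier.
    if not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_name))

# Depending on which check failed, the user now sees the "requires CUDA", the
# "capability 7.5 / 8.x / 9.0", or the "install the kernels" variant, e.g.:
#   require_flash_attention("Sharded Llama")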
@@ -16,9 +16,9 @@ def check_file_size(source_file: Path, target_file: Path):
     source_file_size = source_file.stat().st_size
     target_file_size = target_file.stat().st_size
 
-    if (source_file_size - target_file_size) / source_file_size > 0.01:
+    if (source_file_size - target_file_size) / source_file_size > 0.05:
         raise RuntimeError(
-            f"""The file size different is more than 1%:
+            f"""The file size different is more than 5%:
             - {source_file}: {source_file_size}
             - {target_file}: {target_file_size}
             """
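The only change in `check_file_size` is the tolerance: a converted file may now be up to 5% smaller than its source (previously 1%) before the conversion is rejected. A quick illustration of the check with made-up sizes:

# Relative shrinkage check used by check_file_size (sizes are hypothetical).
source_file_size = 10_000_000   # original .bin shard
target_file_size = 9_700_000    # converted shard, 3% smaller

shrinkage = (source_file_size - target_file_size) / source_file_size
print(shrinkage)          # 0.03
print(shrinkage > 0.01)   # True  -> rejected under the old 1% threshold
print(shrinkage > 0.05)   # False -> accepted under the new 5% threshold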
@@ -26,7 +26,10 @@ def weight_hub_files(
     filenames = [
         s.rfilename
         for s in info.siblings
-        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+        if s.rfilename.endswith(extension)
+        and len(s.rfilename.split("/")) == 1
+        and "arguments" not in s.rfilename
+        and "args" not in s.rfilename
     ]
 
     if not filenames:
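The extra clauses in `weight_hub_files` drop hub files whose names contain `arguments` or `args` (for example a `training_args.bin` left over from fine-tuning), so they are no longer treated as weight shards. A standalone illustration of the comprehension with made-up filenames:

# Made-up repo listing to illustrate the new filter in weight_hub_files.
extension = ".bin"
rfilenames = [
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
    "training_args.bin",          # now excluded: contains "args"
    "logs/pytorch_model.bin",     # still excluded: nested path
]

filenames = [
    name
    for name in rfilenames
    if name.endswith(extension)
    and len(name.split("/")) == 1
    and "arguments" not in name
    and "args" not in name
]

print(filenames)  # ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']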