feat(server): check cuda capability when importing flash models (#201)
close #198
parent e14ae3b5e9
commit a88c54bb4c
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )
 
-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = True
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning(
         "Could not import Flash Attention enabled models"
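For context, a minimal standalone sketch of the same capability gate, assuming only that torch is installed; the check_flash_support helper name is illustrative and not part of the server code:

import torch


def check_flash_support() -> bool:
    # Mirrors the gate added above: Flash Attention kernels require
    # compute capability 7.5 (Turing), 8.x (Ampere/Ada) or 9.0 (Hopper).
    if not torch.cuda.is_available():
        return False

    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    return is_sm75 or is_sm8x or is_sm90


if __name__ == "__main__":
    print(f"Flash Attention supported: {check_flash_support()}")

Raising ImportError for an unsupported capability lands in the surrounding except ImportError block, so the same warning path that covers a missing flash-attn install also covers an unsupported GPU.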