feat(server): check cuda capability when importing flash models (#201)
close #198
parent e14ae3b5e9
commit a88c54bb4c
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )
 
-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = True
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning(
         "Could not import Flash Attention enabled models"
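For a quick sanity check outside the server, a minimal sketch of the same capability gate (assuming only a PyTorch install with CUDA; the flash_attention_supported helper name is illustrative and not part of this diff):

import torch


def flash_attention_supported() -> bool:
    # Hypothetical helper mirroring the gate added in this commit:
    # flash attention kernels require SM 7.5 (Turing), SM 8.x
    # (Ampere/Ada) or SM 9.0 (Hopper).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0
    return is_sm75 or is_sm8x or is_sm90


if __name__ == "__main__":
    # e.g. reports (8, 0) / True on an A100, (6, 1) / False on a P40
    if torch.cuda.is_available():
        print(torch.cuda.get_device_capability())
    print(flash_attention_supported())

With the change above, an unsupported GPU raises ImportError inside the try block, the warning "Could not import Flash Attention enabled models" is logged, and the server falls back to FLASH_ATTENTION = False instead of crashing later in the flash kernels.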