diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 13c74c91..0a29b3cc 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,7 +24,18 @@ try:
         FlashSantacoderSharded,
     )
 
-    FLASH_ATTENTION = torch.cuda.is_available()
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        is_sm75 = major == 7 and minor == 5
+        is_sm8x = major == 8 and minor >= 0
+        is_sm90 = major == 9 and minor == 0
+
+        supported = is_sm75 or is_sm8x or is_sm90
+        if not supported:
+            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported")
+        FLASH_ATTENTION = True
+    else:
+        FLASH_ATTENTION = False
 except ImportError:
     logger.opt(exception=True).warning(
         "Could not import Flash Attention enabled models"
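
For context, here is a minimal standalone sketch of the same compute-capability gate the diff introduces. The helper name has_flash_attention is hypothetical and not part of this change; the supported architectures (sm_75, sm_8x, sm_90) are taken directly from the check above.

    import torch

    def has_flash_attention() -> bool:
        """Hypothetical helper mirroring the gate in the diff above.

        Per the check introduced here, the Flash Attention kernels are only
        built for Turing (sm_75), Ampere (sm_8x), and Hopper (sm_90) GPUs.
        """
        if not torch.cuda.is_available():
            return False
        # get_device_capability() returns (major, minor), e.g. (8, 0) for A100.
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) == (7, 5) or major == 8 or (major, minor) == (9, 0)

    if __name__ == "__main__":
        print(f"Flash Attention supported: {has_flash_attention()}")

Unlike the diff, which raises ImportError on an unsupported GPU so the surrounding try/except falls back to non-flash models, this sketch just returns False; the raise-and-catch pattern is what lets the server degrade gracefully rather than crash at import time.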