diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 745c1d2e..6be54048 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -18,13 +18,20 @@
 from accelerate import init_empty_weights
 from text_generation_server.utils.gptq.quant_linear import QuantLinear
 
-HAS_EXLLAMA = True
+try:
+    major, _minor = torch.cuda.get_device_capability()
+except Exception:
+    major = 1
+HAS_EXLLAMA = False
+CAN_EXLLAMA = major >= 8
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
-try:
-    from text_generation_server.utils.gptq.exllama import Ex4bitLinear
-except ImportError:
-    HAS_EXLLAMA = False
+elif CAN_EXLLAMA:
+    try:
+        from text_generation_server.utils.gptq.exllama import Ex4bitLinear
+
+        HAS_EXLLAMA = True
+    except ImportError:
+        pass
 
 from typing import Optional
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index ef662ce1..261456bd 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -170,10 +170,10 @@ class Weights:
                     "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
                 )
 
-            from text_generation_server.utils.layers import HAS_EXLLAMA
+            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
 
             if use_exllama:
-                if not HAS_EXLLAMA:
+                if not HAS_EXLLAMA and CAN_EXLLAMA:
                     logger.warning(
                         "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
                     )