diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index f6616d3e..9c9b69d1 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -8,34 +8,6 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
 
-try:
-    major, _minor = torch.cuda.get_device_capability()
-except Exception:
-    major = 1
-
-HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
-V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-if os.getenv("DISABLE_EXLLAMA") == "True":
-    HAS_EXLLAMA = False
-elif CAN_EXLLAMA:
-    try:
-        if V2:
-            from text_generation_server.layers.gptq.exllamav2 import (
-                QuantLinear as ExllamaQuantLinear,  # noqa: F401
-            )
-
-            HAS_EXLLAMA = "2"
-        else:
-            from text_generation_server.layers.gptq.exllama import (
-                Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
-            )
-
-            HAS_EXLLAMA = "1"
-
-    except ImportError:
-        pass
-
 
 @dataclass
 class GPTQWeight(Weight):
@@ -432,3 +404,33 @@ class GPTQWeightsLoader(WeightsLoader):
             else False
         )
         self.quant_method = "gptq"
+
+
+# Needs to be at the end because circular import.
+try:
+    major, _minor = torch.cuda.get_device_capability()
+except Exception:
+    major = 1
+
+HAS_EXLLAMA = False
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if os.getenv("DISABLE_EXLLAMA") == "True":
+    HAS_EXLLAMA = False
+elif CAN_EXLLAMA:
+    try:
+        if V2:
+            from text_generation_server.layers.gptq.exllamav2 import (
+                QuantLinear as ExllamaQuantLinear,  # noqa: F401
+            )
+
+            HAS_EXLLAMA = "2"
+        else:
+            from text_generation_server.layers.gptq.exllama import (
+                Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
+            )
+
+            HAS_EXLLAMA = "1"
+
+    except ImportError:
+        pass
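
Note on the change (commentary, not part of the diff): the relocated block is guarded by a blanket `except ImportError: pass`, so any failure to import the exllama kernel modules silently disables exllama support. The new comment indicates those modules import back into this package (directly or transitively); while the block sat at the top of `gptq/__init__.py`, that back-import could hit the package before names like `GPTQWeight` and `GPTQWeightsLoader` were bound, raise an `ImportError`, and leave `HAS_EXLLAMA` stuck at `False`. Moving the block to the bottom means the package is fully populated by the time the cycle is entered. A minimal two-file sketch of the pattern, with the package and names invented for illustration (these are not TGI's actual modules):

# pkg/__init__.py
from dataclasses import dataclass

@dataclass
class GPTQWeight:
    qweight: object

# Kernel import at the *end* of the module: pkg.kernel reaches back into
# this package while it is still initializing, so everything it needs
# (here, GPTQWeight) must already be defined above this point.
try:
    from pkg.kernel import QuantLinear  # noqa: F401

    HAS_KERNEL = True
except ImportError:
    HAS_KERNEL = False

# pkg/kernel.py
from pkg import GPTQWeight  # resolves against the still-initializing package

class QuantLinear:
    def __init__(self, weight: GPTQWeight):
        self.weight = weight

If the `try` block in `pkg/__init__.py` were placed above the `GPTQWeight` definition instead, `from pkg import GPTQWeight` in `pkg/kernel.py` would raise an `ImportError` ("cannot import name ... from partially initialized module"), and the blanket `except ImportError` would swallow it, which is the failure mode this diff avoids.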