diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index 570aa75c..1d131e0b 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -2,8 +2,6 @@ from typing import Optional
 import torch
 from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 
 if SYSTEM == "rocm":
     try:
@@ -155,6 +153,8 @@ def get_linear(weight, bias, quantize):
             quant_type="nf4",
         )
     elif quantize == "exl2":
+        from text_generation_server.layers.exl2 import Exl2Weight
+
         if not isinstance(weight, Exl2Weight):
             raise NotImplementedError(
                 f"The passed weight is not `exl2` compatible, loader needs to be updated."
@@ -165,6 +165,8 @@ def get_linear(weight, bias, quantize):
         linear = ExllamaQuantLinear(weight, bias)
 
     elif quantize == "gptq":
+        from text_generation_server.layers.gptq import GPTQWeight
+
         if not isinstance(weight, GPTQWeight):
             raise NotImplementedError(
                 f"The passed weight is not `gptq` compatible, loader needs to be updated."
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 497956e3..7967e420 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -21,7 +21,6 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
 from loguru import logger
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM != "xpu":
@@ -198,6 +197,8 @@ def _load_gqa(config, prefix: str, weights):
     v_stop = v_offset + (rank + 1) * kv_block_size
 
     if config.quantize in ["gptq", "awq"]:
+        from text_generation_server.layers.gptq import GPTQWeight
+
         try:
             qweight_slice = weights._get_slice(f"{prefix}.qweight")
             q_qweight = qweight_slice[:, q_start:q_stop]
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index c8397000..1f47550e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,7 +5,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -39,6 +38,8 @@ def load_multi_mqa(
 def _load_multi_mqa_gptq(
     config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
 ):
+    from text_generation_server.layers.gptq import GPTQWeight
+
     if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
         world_size = weights.process_group.size()
         rank = weights.process_group.rank()
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 710ea680..5782de8a 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -7,8 +7,6 @@ import torch
 from loguru import logger
 from huggingface_hub import hf_hub_download
 import json
-from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.layers.gptq import GPTQWeight
 from text_generation_server.utils.log import log_once
 
 
@@ -221,6 +219,8 @@ class Weights:
 
     def get_weights_col(self, prefix: str, quantize: str):
         if quantize == "exl2":
+            from text_generation_server.layers.exl2 import Exl2Weight
+
             try:
                 q_weight = self.get_tensor(f"{prefix}.q_weight")
             except RuntimeError:
@@ -247,6 +247,8 @@ class Weights:
         if quantize == "exl2":
             raise ValueError("get_multi_weights_col is not supported for exl2")
         elif quantize in ["gptq", "awq"]:
+            from text_generation_server.layers.gptq import GPTQWeight
+
             try:
                 qweight = torch.cat(
                     [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
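The hunks above all apply the same change: the `Exl2Weight` and `GPTQWeight` imports move from module scope into the branches that actually use them, which defers loading those quantization backends until a code path selects them. Below is a minimal sketch of that deferred-import pattern; `exotic_quant`, `ExoticWeight`, and `ExoticLinear` are hypothetical stand-ins for the real exl2/GPTQ layers, not part of text-generation-server.

```python
# Sketch of the deferred-import pattern used in the diff above.
# NOTE: `exotic_quant`, `ExoticWeight`, and `ExoticLinear` are hypothetical names.
from typing import Optional

import torch
from torch import nn


def get_linear(weight: torch.Tensor, bias: Optional[torch.Tensor], quantize: Optional[str]):
    if quantize == "exotic":
        # Deferred import: only evaluated when this quantization mode is requested,
        # so the backend does not have to be installed for any other mode.
        from exotic_quant import ExoticLinear, ExoticWeight  # hypothetical backend

        if not isinstance(weight, ExoticWeight):
            raise NotImplementedError(
                "The passed weight is not `exotic` compatible, loader needs to be updated."
            )
        return ExoticLinear(weight, bias)

    # The unquantized path depends only on torch.
    linear = nn.Linear(weight.shape[1], weight.shape[0], bias=bias is not None)
    with torch.no_grad():
        linear.weight.copy_(weight)
        if bias is not None:
            linear.bias.copy_(bias)
    return linear


# Importing and calling this module works even without `exotic_quant` installed,
# because the backend import only runs inside the "exotic" branch.
layer = get_linear(torch.randn(4, 8), None, quantize=None)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```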