diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 55b1aae9..cb0c1e85 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -179,7 +179,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize != "gptq": + if config.quantize not in ["gptq", "awq"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads