diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 63131dee..7e838035 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -10,8 +10,8 @@ from text_generation_server.utils.weights import Weight, Weights, WeightsLoader if SYSTEM == "ipex": from .ipex import QuantLinear -elif SYSTEM == "cuda": - from .cuda import QuantLinear +elif SYSTEM in {"cuda", "rocm"}: + from .triton import QuantLinear @dataclass diff --git a/server/text_generation_server/layers/gptq/cuda.py b/server/text_generation_server/layers/gptq/triton.py similarity index 100% rename from server/text_generation_server/layers/gptq/cuda.py rename to server/text_generation_server/layers/gptq/triton.py