From cece8635f8ec9b89cabef2f056c07ec8de3b00d1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 25 Oct 2024 09:17:57 +0200 Subject: [PATCH] Fixing rocm gptq by using triton code too (renamed cuda into triton). (#2691) --- server/text_generation_server/layers/gptq/__init__.py | 4 ++-- .../text_generation_server/layers/gptq/{cuda.py => triton.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename server/text_generation_server/layers/gptq/{cuda.py => triton.py} (100%) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 63131dee..7e838035 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -10,8 +10,8 @@ from text_generation_server.utils.weights import Weight, Weights, WeightsLoader if SYSTEM == "ipex": from .ipex import QuantLinear -elif SYSTEM == "cuda": - from .cuda import QuantLinear +elif SYSTEM in {"cuda", "rocm"}: + from .triton import QuantLinear @dataclass diff --git a/server/text_generation_server/layers/gptq/cuda.py b/server/text_generation_server/layers/gptq/triton.py similarity index 100% rename from server/text_generation_server/layers/gptq/cuda.py rename to server/text_generation_server/layers/gptq/triton.py