GPTQ support on ROCm (#1489)

Tested with
```
CUDA_VISIBLE_DEVICES=0 text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
EXLLAMA_VERSION=1 CUDA_VISIBLE_DEVICES=0 text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
CUDA_VISIBLE_DEVICES="0,1" text-generation-launcher --model-id TheBloke/Llama-2-7B-Chat-GPTQ --quantize gptq
```

all with good and identical results on MI210.

---------

Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
This commit is contained in:
fxmarty 2024-01-26 16:27:44 +01:00 committed by GitHub
parent ebecc06161
commit 650fea1834
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 80 additions and 22 deletions

10
.gitignore vendored
View File

@ -2,3 +2,13 @@
target
router/tokenizer.json
*__pycache__*
# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/
server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp

View File

@ -75,8 +75,8 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh
# Install PyTorch nightly (2.2.0.dev2023) compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install --pre torch==2.2.0.dev20231106 --index-url https://download.pytorch.org/whl/nightly/rocm5.7
# Install PyTorch 2.2 RC compiled against RoCm 5.7, as VLLM can not be compiled with RoCm 5.6.
RUN pip install torch --index-url https://download.pytorch.org/whl/test/rocm5.7/
FROM base AS kernel-builder
@ -104,6 +104,20 @@ WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN PYTORCH_ROCM_ARCH=gfx90a python setup.py build
# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN PYTORCH_ROCM_ARCH="gfx90a" python setup.py build
FROM base as base-copy
# Text Generation Inference base env
@ -120,6 +134,12 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

View File

@ -43,8 +43,8 @@ text-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>
TGI optimized models are supported on NVIDIA [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it. For other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention and flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Quantization (GPTQ, AWQ, etc.)
TGI also has support of ROCm-enabled AMD Instinct MI210 and MI250 GPUs, with paged attention, GPTQ quantization, flash attention v2 support. The following features are currently not supported in the ROCm version of TGI, and the supported may be extended in the future:
* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.
* Flash [layer norm kernel](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm)
* Kernel for slinding window attention (Mistral)

View File

@ -43,12 +43,12 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
//
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ < 700
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

View File

@ -2,8 +2,11 @@
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cu_compat.cuh"
#include "../cuda_buffers.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
@ -128,7 +131,7 @@ __global__ void q4_matmul_kernel
if constexpr (use_half2)
{
half result = __hadd(acc.x, acc.y);
half result = __hadd(__low2half(acc), __high2half(acc));
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else

View File

@ -1,12 +1,23 @@
#ifndef _compat_gemm_cuh
#define _compat_gemm_cuh
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#if defined(USE_ROCM)
#ifndef _hip_compat_cuh
#define _hip_compat_cuh
// For some reason this include is not present anywhere in exllama_v2 codebase, but it is required
// for symbols as hipblasHalf.
#include <hipblas/hipblas.h>
// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
return __half_raw{
static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}
__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
}
#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp
// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
hipblasOperation_t transA,
hipblasOperation_t transB,
@ -31,8 +42,10 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm
#endif
#endif
#endif

View File

@ -8,7 +8,11 @@
#include <cstdint>
#include <cstdio>
#if defined(USE_ROCM)
#define cudaUnspecified hipErrorUnknown
#else
#define cudaUnspecified cudaErrorApiFailureBase
#endif
// React to failure on return code != cudaSuccess

View File

@ -1,5 +1,15 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
extra_compile_args = {
"nvcc": extra_cuda_cflags,
}
setup(
name="exllamav2_kernels",
@ -11,6 +21,7 @@ setup(
"exllamav2_kernels/cuda/q_matrix.cu",
"exllamav2_kernels/cuda/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
)
],
cmdclass={"build_ext": BuildExtension},

View File

@ -1,12 +1,9 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from logging import getLogger
import torch
import torch.nn as nn
import math
logger = getLogger(__name__)
from loguru import logger
try:
from exllamav2_kernels import make_q_matrix, gemm_half_q_half

View File

@ -33,7 +33,7 @@ except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
# V2 = False