2024-05-28 03:51:31 -06:00
from typing import Optional
2024-05-13 04:44:30 -06:00
import torch
from torch.nn import functional as F
2024-06-14 01:45:42 -06:00
from text_generation_server.layers.marlin import GPTQMarlinLinear
2024-05-13 04:44:30 -06:00
from text_generation_server.utils.import_utils import SYSTEM
MI300 compatibility (#1764)
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
https://github.com/ROCm/vllm/commit/3489ce7936c5de588916ae3047c44c23c0b0c308
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that
the tuning does not have to be rerun at each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
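The tuning file shown above is plain CSV: `Validator` rows record the environment the tuning ran in, and every other row maps a GEMM problem signature to the kernel TunableOp selected. Below is a minimal sketch for inspecting such a file. It assumes exactly the four-column layout of that example. The field names are illustrative labels, not PyTorch's, and the timing unit is not stated here.

```python
import csv
import io
from typing import Dict, List, NamedTuple, Tuple


class TunedGemm(NamedTuple):
    op: str         # e.g. "GemmTunableOp_Half_TN"
    signature: str  # problem-shape key, e.g. "tn_8192_7_28672"
    kernel: str     # selected solution, or "Default"
    time: float     # timing value recorded for the selected kernel


def parse_tunableop_csv(text: str) -> Tuple[Dict[str, str], List[TunedGemm]]:
    """Split a TunableOp results file into validator metadata and tuned entries."""
    validators: Dict[str, str] = {}
    entries: List[TunedGemm] = []
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue  # tolerate blank lines
        if row[0] == "Validator":
            validators[row[1]] = row[2]
        else:
            op, signature, kernel, time = row
            entries.append(TunedGemm(op, signature, kernel, float(time)))
    return validators, entries
```

Comparing the `Validator` values against the running environment is a cheap way to spot a stale tuning file before reusing it.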
2024-05-17 07:30:47 -06:00
if SYSTEM == "rocm":
    try:
        from vllm import _custom_C
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
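The guard above makes `vllm._custom_C` a hard requirement on ROCm and leaves other platforms unaffected. Here is a sketch of the same pattern as a small reusable helper. The name `import_platform_extension` is hypothetical and is not part of TGI.

```python
import importlib
from types import ModuleType
from typing import Optional


def import_platform_extension(
    system: str, module_name: str, required_on: str = "rocm"
) -> Optional[ModuleType]:
    """Import `module_name` only when running on `required_on`; fail loudly there."""
    if system != required_on:
        # Other platforms never use the extension, so a missing package is fine.
        return None
    try:
        return importlib.import_module(module_name)
    except Exception as e:
        # Same policy as the module-level guard: surface the underlying error.
        raise ImportError(f"Could not load `{module_name}`. Full error: {e}") from e
```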
2024-05-13 04:44:30 -06:00
class FastLinear(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
Pali gemma modeling (#1895)
This PR adds PaliGemma modeling code.
Blog post: https://huggingface.co/blog/paligemma
Transformers PR: https://github.com/huggingface/transformers/pull/30814
Install the latest changes and run with:
```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf
# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```
A basic example sending various requests:
```python
from huggingface_hub import InferenceClient
client = InferenceClient("http://127.0.0.1:3000")
images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]
prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]
for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # The image is passed inline using markdown image syntax.
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print([f"{prompt}\n{generated_output}"])
```
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
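Each request above pairs an image-plus-prompt string with generation parameters. Here is a sketch of the equivalent JSON body. It assumes TGI's `/generate` schema: a top-level `inputs` string and a `parameters` object. The `build_generate_body` helper is hypothetical and is not part of `huggingface_hub`.

```python
import json


def build_generate_body(image_url: str, prompt: str, max_new_tokens: int = 30) -> str:
    """Serialize a greedy-decoding request with the image inlined as markdown."""
    payload = {
        # Markdown image syntax carries the image URL inside the text input.
        "inputs": f"![]({image_url}){prompt}\n",
        "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False},
    }
    return json.dumps(payload)
```

Such a body would be POSTed to the server the example client targets (`http://127.0.0.1:3000`). `InferenceClient.text_generation` performs this serialization for you.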
2024-05-15 22:58:47 -06:00
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
2024-05-13 04:44:30 -06:00
        if bias is not None:
2024-05-15 22:58:47 -06:00
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
2024-05-13 04:44:30 -06:00
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
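`FastLinear` is a thin wrapper with two halves. `load` fetches `{prefix}.weight` and, optionally, `{prefix}.bias` by name. `forward` delegates to `F.linear`, which computes `x @ weight.T + bias` with `weight` stored as `(out_features, in_features)`. Below is a dependency-free sketch of both halves. `DictWeights` is a hypothetical stand-in for TGI's weights loader, and plain lists replace tensors.

```python
from typing import Dict, List, Optional, Tuple

Matrix = List[List[float]]


class DictWeights:
    """Hypothetical stand-in for the `weights` object: looks tensors up by name."""

    def __init__(self, tensors: Dict[str, list]) -> None:
        self.tensors = tensors

    def get_tensor(self, name: str) -> list:
        return self.tensors[name]


def load_linear_params(
    weights, prefix: str, bias: bool
) -> Tuple[Matrix, Optional[List[float]]]:
    # Same naming convention as FastLinear.load: "<prefix>.weight" / "<prefix>.bias".
    weight = weights.get_tensor(f"{prefix}.weight")
    b = weights.get_tensor(f"{prefix}.bias") if bias else None
    return weight, b


def linear(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """y = x @ weight.T + bias; each row of `weight` produces one output feature."""
    out = []
    for row in x:
        y = [sum(a * w for a, w in zip(row, w_row)) for w_row in weight]
        if bias is not None:
            y = [v + b for v, b in zip(y, bias)]
        out.append(y)
    return out
```

For example, with `weight = [[1, 0], [0, 1], [1, 1]]` and `bias = [0, 0, 1]`, the input row `[1, 2]` maps to `[1, 2, 4]`.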
2024-05-17 07:30:47 -06:00
class FastLinearROCm(torch.nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        # Single input row (batch-size-1 decode): try AMD's custom LLMM1 kernel.
        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
            batched = False
            inp_shape = inp.shape

            if inp.dim() == 3:
                inp = inp.view(-1, inp_shape[-1])
                batched = True

            # m: output features, k: input features (the reduction dimension).
            m, k = weight.shape[0], inp_shape[-1]
            out = torch.empty(
                inp.shape[0], weight.shape[0], dtype=inp.dtype, device=inp.device
            )
            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
                _custom_C.LLMM1(weight, inp, out, 8)
            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
                _custom_C.LLMM1(weight, inp, out, 4)
            else:
                out = F.linear(inp, weight)

            if batched:
                # `view` returns a new tensor; assign it to restore the leading dims.
                out = out.view(*inp_shape[:-1], out.shape[-1])

            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)
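`FastLinearROCm.forward` takes the custom path only for a single input row (`numel // shape[-1] == 1`). It then picks an `LLMM1` call from the weight's shape. The sketch below mirrors those checks as pure functions so they can be inspected in isolation. The meaning of `LLMM1`'s last integer argument is not documented in this file, so the hypothetical `llmm1_variant` just returns it as-is, or `None` for the `F.linear` fallback.

```python
import math
from typing import Optional, Sequence


def is_single_row(shape: Sequence[int]) -> bool:
    """True when the leading dims multiply to 1, mirroring numel // shape[-1] == 1."""
    return math.prod(shape) // shape[-1] == 1


def llmm1_variant(m: int, k: int) -> Optional[int]:
    """Return the last LLMM1 argument chosen for an (m x k) weight, or None.

    m is the number of output features and k the number of input features,
    matching `m, k = weight.shape[0], inp_shape[-1]` in the forward pass.
    """
    if (k == 8192 and m in (1280, 7168)) or (k == 3584 and m == 8192):
        return 8
    if k <= 8192 and k % 8 == 0 and m % 4 == 0:
        return 4
    return None  # falls back to F.linear
```

Keeping shape heuristics in a pure function like this makes it easy to table-test which projection sizes hit which path.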
2024-05-13 04:44:30 -06:00
def get_linear(weight, bias, quantize):
    if quantize is None:
2024-05-17 07:30:47 -06:00
        if SYSTEM == "rocm":
            linear = FastLinearROCm(weight, bias)
        else:
            linear = FastLinear(weight, bias)
2024-05-13 04:44:30 -06:00
    elif quantize == "eetq":
        try:
            from text_generation_server.layers.eetq import EETQLinear

            linear = EETQLinear(weight, bias)
        except ImportError:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Linear

        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (
                warn_deprecate_bnb,
                Linear8bitLt,
            )
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing; install it with `pip install bitsandbytes`."
            )
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            linear.bias = torch.nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing; install it with `pip install bitsandbytes`."
            )
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError:
            raise NotImplementedError(
                "Bitsandbytes is missing; install it with `pip install bitsandbytes`."
            )
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
2024-05-28 03:51:31 -06:00
    elif quantize == "exl2":
2024-06-03 02:36:29 -06:00
        from text_generation_server.layers.exl2 import Exl2Weight
2024-05-28 03:51:31 -06:00
        if not isinstance(weight, Exl2Weight):
            raise NotImplementedError(
                "The passed weight is not `exl2` compatible; the loader needs to be updated."
            )

        from text_generation_server.layers.gptq import ExllamaQuantLinear

        linear = ExllamaQuantLinear(weight, bias)
2024-05-13 04:44:30 -06:00
    elif quantize == "gptq":
2024-06-03 02:36:29 -06:00
        from text_generation_server.layers.gptq import GPTQWeight
2024-05-28 03:51:31 -06:00
        if not isinstance(weight, GPTQWeight):
2024-05-13 04:44:30 -06:00
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible; the loader needs to be updated."
            )
2024-05-28 03:51:31 -06:00
        if weight.use_exllama:
2024-05-13 04:44:30 -06:00
            try:
                from text_generation_server.layers.gptq import (
                    ExllamaQuantLinear,
                )
            except ImportError:
                raise NotImplementedError(
                    "Exllama gptq kernels are not installed. Install them with `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                )
2024-05-28 03:51:31 -06:00
            linear = ExllamaQuantLinear(weight, bias)
2024-05-13 04:44:30 -06:00
        else:
            from text_generation_server.layers.gptq.quant_linear import QuantLinear

            linear = QuantLinear(
2024-05-28 03:51:31 -06:00
                weight.qweight,
                weight.qzeros,
                weight.scales,
                weight.g_idx,
2024-05-13 04:44:30 -06:00
                bias,
2024-05-28 03:51:31 -06:00
                weight.bits,
                weight.groupsize,
2024-05-13 04:44:30 -06:00
            )
    elif quantize == "awq":
2024-06-03 03:32:12 -06:00
        from text_generation_server.layers.gptq import GPTQWeight
2024-05-28 03:51:31 -06:00
        if not isinstance(weight, GPTQWeight):
2024-05-13 04:44:30 -06:00
            raise NotImplementedError(
                "The passed weight is not `awq` compatible; the loader needs to be updated."
            )
        if SYSTEM == "rocm":
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        try:
            from text_generation_server.layers.awq.quantize.qmodule import WQLinear

            linear = WQLinear(
2024-05-28 03:51:31 -06:00
                w_bit=weight.bits,
                group_size=weight.groupsize,
                qweight=weight.qweight,
                qzeros=weight.qzeros,
                scales=weight.scales,
2024-05-13 04:44:30 -06:00
                bias=bias is not None,
            )
        except ImportError:
            raise NotImplementedError(
                "You do not seem to have awq installed. Either install it (cd server && make install-awq) or use GPTQ with `--quantize gptq`; an AWQ->GPTQ conversion will happen on the fly."
            )
2024-06-05 02:14:40 -06:00
    elif quantize == "marlin":
2024-06-14 01:45:42 -06:00
        from text_generation_server.layers.marlin import (
            GPTQMarlinWeight,
            MarlinLinear,
            MarlinWeight,
        )
2024-06-05 02:14:40 -06:00
2024-06-14 01:45:42 -06:00
        if isinstance(weight, GPTQMarlinWeight):
            linear = GPTQMarlinLinear(
                weight=weight,
                bias=bias,
            )
        elif isinstance(weight, MarlinWeight):
            linear = MarlinLinear(weight=weight, bias=bias)
        else:
2024-06-05 02:14:40 -06:00
            raise NotImplementedError(
                "The passed weight is not `marlin` compatible; the loader needs to be updated."
            )
2024-05-13 04:44:30 -06:00
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
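`get_linear` is a string-keyed dispatch over quantization schemes. Each branch imports its kernel lazily, so optional dependencies load only when that scheme is selected. Below is a sketch of a registry-based variant of the same dispatch. It uses hypothetical builder callables rather than TGI's real layer classes, and it keeps the unknown-scheme error in one place.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical builder signature: (weight, bias) -> layer.
Builder = Callable[[Any, Optional[Any]], Any]


def make_get_linear(registry: Dict[Optional[str], Builder]) -> Callable[..., Any]:
    """Build a get_linear-style dispatcher from a scheme -> builder mapping."""

    def get_linear(weight: Any, bias: Optional[Any], quantize: Optional[str]) -> Any:
        try:
            builder = registry[quantize]
        except KeyError:
            raise NotImplementedError(
                f"Quantization `{quantize}` is not implemented yet."
            ) from None
        # Builders may import their kernels lazily inside the call, as the branches above do.
        return builder(weight, bias)

    return get_linear
```

There is a trade-off: a registry makes the branches uniform, but the real function also validates the weight type per scheme (e.g. `GPTQWeight`, `Exl2Weight`), and each builder would have to take that over.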