2023-07-21 02:59:00 -06:00
import os
2023-05-15 09:30:47 -06:00
import torch
2023-06-08 06:51:52 -06:00
import torch . distributed
2023-05-15 09:30:47 -06:00
from torch import nn
2023-05-15 15:36:30 -06:00
from torch . nn import functional as F
2024-02-26 11:49:28 -07:00
from typing import List , Tuple , Optional
2023-09-27 03:42:57 -06:00
from loguru import logger
from functools import lru_cache
2023-05-15 09:30:47 -06:00
HAS_BITS_AND_BYTES = True
try :
2023-06-08 06:51:52 -06:00
import bitsandbytes as bnb
2023-08-03 15:00:59 -06:00
from bitsandbytes . nn import Int8Params , Params4bit
2023-06-08 06:51:52 -06:00
except ImportError :
2023-05-15 09:30:47 -06:00
HAS_BITS_AND_BYTES = False
2023-06-08 06:51:52 -06:00
from accelerate import init_empty_weights
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently this means that every place where we use `get_{tensor|sharded}` has
to check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
from text_generation_server . utils . gptq . quant_linear import QuantLinear
2023-12-11 06:43:40 -07:00
from text_generation_server . utils . import_utils import IS_CUDA_SYSTEM , IS_ROCM_SYSTEM
Add AWQ quantization inference support (#1019) (#1054)
# Add AWQ quantization inference support
Fixes
https://github.com/huggingface/text-generation-inference/issues/781
This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.
This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).
A quick way to test this PR is to bring up TGI as follows:
```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq
text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).
Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).
Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.
## Who can review?
@OlivierDehaene OR @Narsil
---------
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Abhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
2023-09-25 07:31:27 -06:00
HAS_AWQ = True
2023-09-27 04:22:09 -06:00
try :
Add AWQ quantization inference support (#1019) (#1054)
# Add AWQ quantization inference support
Fixes
https://github.com/huggingface/text-generation-inference/issues/781
This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.
This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).
A quick way to test this PR is to bring up TGI as follows:
```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq
text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).
Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).
Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.
## Who can review?
@OlivierDehaene OR @Narsil
---------
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Abhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
2023-09-25 07:31:27 -06:00
from text_generation_server . utils . awq . quantize . qmodule import WQLinear
except ImportError :
HAS_AWQ = False
2023-07-21 02:59:00 -06:00
# Probe the CUDA compute capability. Any exception (presumably: no usable CUDA
# device) falls back to major = 1, which makes CAN_EXLLAMA below depend only on
# IS_ROCM_SYSTEM.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

# HAS_EXLLAMA ends up as False, "1" or "2" depending on which kernel import
# (if any) succeeds below.
HAS_EXLLAMA = False
# exllama kernels are only attempted on compute capability >= 8 or on ROCm.
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
# EXLLAMA_VERSION defaults to "2"; any other value selects the v1 kernels.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

# Only the exact string "True" disables exllama (e.g. "true" or "1" do not).
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.utils.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.utils.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"
    except ImportError:
        # Best effort: if the kernel module cannot be imported, HAS_EXLLAMA
        # stays False and ExllamaQuantLinear / create_exllama_buffers /
        # set_device are not bound.
        pass
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently this means that every place where we use `get_{tensor|sharded}` has
to check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
2023-09-27 03:42:57 -06:00
HAS_EETQ = False
try :
from EETQ import quant_weights , w8_a16_gemm
2023-09-27 04:22:09 -06:00
2023-09-27 03:42:57 -06:00
HAS_EETQ = True
except ImportError :
pass
2023-09-28 01:55:47 -06:00
2023-06-08 06:51:52 -06:00
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from the ``{prefix}.weight`` and ``{prefix}.bias`` tensors.

    The module is constructed inside accelerate's ``init_empty_weights()`` and the
    loaded tensors are installed as parameters afterwards, outside that context
    (presumably so the default-initialized parameters are never materialized).

    Args:
        prefix: checkpoint key prefix of the layer.
        weights: loader object exposing ``get_tensor(name)``.
        eps: epsilon passed to the ``cls`` constructor.

    Returns:
        The ``cls`` instance with ``weight`` and ``bias`` set from the checkpoint.
    """
    scale_tensor = weights.get_tensor(f"{prefix}.weight")
    shift_tensor = weights.get_tensor(f"{prefix}.bias")

    with init_empty_weights():
        layer = cls(scale_tensor.shape, eps=eps)

    layer.weight = nn.Parameter(scale_tensor)
    layer.bias = nn.Parameter(shift_tensor)
    return layer
2023-07-03 05:01:46 -06:00
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a bias-free ``cls`` layer norm from the ``{prefix}.weight`` tensor.

    Construction happens inside ``init_empty_weights()``. The checkpoint weight
    is assigned afterwards and ``bias`` is set to ``None``.

    Args:
        prefix: checkpoint key prefix of the layer.
        weights: loader object exposing ``get_tensor(name)``.
        eps: epsilon passed to the ``cls`` constructor.

    Returns:
        The ``cls`` instance with ``weight`` loaded and ``bias`` set to ``None``.
    """
    scale_tensor = weights.get_tensor(f"{prefix}.weight")

    with init_empty_weights():
        layer = cls(scale_tensor.shape, eps=eps)

    layer.weight = nn.Parameter(scale_tensor)
    layer.bias = None
    return layer
2023-09-27 04:22:09 -06:00
2023-08-17 06:38:49 -06:00
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Build a ``cls`` 2-D convolution from ``{prefix}.weight`` and ``{prefix}.bias``.

    The module is constructed inside ``init_empty_weights()`` with the given
    geometry. The checkpoint tensors are then installed as its parameters.

    Args:
        prefix: checkpoint key prefix of the layer.
        weights: loader object exposing ``get_tensor(name)``.
        in_channels, out_channels, kernel_size, stride: forwarded to ``cls``.

    Returns:
        The ``cls`` instance with ``weight`` and ``bias`` set from the checkpoint.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    offset = weights.get_tensor(f"{prefix}.bias")
    geometry = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )

    with init_empty_weights():
        conv2d = cls(**geometry)

    conv2d.weight = nn.Parameter(kernel)
    conv2d.bias = nn.Parameter(offset)
    return conv2d
@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Build a bias-free ``cls`` 2-D convolution from ``{prefix}.weight``.

    Same construction pattern as the biased loader, but ``bias`` is set to ``None``.

    Args:
        prefix: checkpoint key prefix of the layer.
        weights: loader object exposing ``get_tensor(name)``.
        in_channels, out_channels, kernel_size, stride: forwarded to ``cls``.

    Returns:
        The ``cls`` instance with ``weight`` loaded and ``bias`` set to ``None``.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    geometry = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )

    with init_empty_weights():
        conv2d = cls(**geometry)

    conv2d.weight = nn.Parameter(kernel)
    conv2d.bias = None
    return conv2d
2023-07-03 05:01:46 -06:00
2023-08-17 06:38:49 -06:00
# Attach the loader functions as alternate constructors on the torch classes,
# e.g. ``torch.nn.LayerNorm.load(prefix, weights, eps)``. Each loader was wrapped
# in ``classmethod`` at definition, so ``cls`` binds to the class it is called on.
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
2023-05-15 09:30:47 -06:00
2023-06-08 06:51:52 -06:00
class FastLinear(nn.Module):
    """Unquantized linear layer computing ``F.linear(input, weight, bias)``."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        # bias is optional; when absent it is stored as a plain ``None`` attribute.
        self.bias = None if bias is None else nn.Parameter(bias)

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from ``{prefix}.weight`` (plus ``{prefix}.bias`` if ``bias``).

        Args:
            config: accepted for signature compatibility; not used here.
            prefix: checkpoint key prefix of the layer.
            weights: loader object exposing ``get_tensor(name)``.
            bias: whether to also load ``{prefix}.bias``.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
2023-05-15 09:30:47 -06:00
2023-09-27 03:42:57 -06:00
class EETQLinear(nn.Module):
    """Linear layer using EETQ int8 weight quantization and the ``w8_a16_gemm`` kernel."""

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        target_device = weight.device

        # EETQ's quantizer is fed a float16, transposed, contiguous CPU copy.
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)
        host_weight = torch.t(weight).contiguous().cpu()
        int8_weight, weight_scale = quant_weights(host_weight, torch.int8, False)

        # Quantized data, scales and (optional) bias go back to the original device.
        self.weight = int8_weight.cuda(target_device)
        self.scale = weight_scale.cuda(target_device)
        self.bias = None if bias is None else bias.cuda(target_device)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = w8_a16_gemm(input, self.weight, self.scale)
        if self.bias is not None:
            output = output + self.bias
        return output
2024-04-12 00:13:30 -06:00
def fp8_quantize ( weight , qdtype = torch . float8_e4m3fn ) :
device = weight . device
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch . finfo ( qdtype )
# Calculate the scale as dtype max divided by absmax
scale = finfo . max / weight . abs ( ) . max ( ) . clamp ( min = 1e-12 )
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = ( weight * scale ) . clamp ( min = finfo . min , max = finfo . max )
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight . to ( qdtype )
scale = scale . float ( ) . reciprocal ( )
return qweight , scale
class Fp8Linear(nn.Module):
    """Linear layer with an FP8 weight and dynamic per-tensor FP8 activations.

    The weight is quantized once at construction with ``fp8_quantize``. The input
    is quantized on every call, and both are multiplied with ``torch._scaled_mm``.
    The output dtype is the dtype of the original (unquantized) weight.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)
        # Stored as a plain attribute (possibly None) and passed straight to the kernel.
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
        output = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        # Older torch releases return ``(output, amax)``; newer ones (2.5+) return
        # only the output tensor. Unconditionally unpacking would raise on newer
        # torch, or silently keep only row 0 for a 2-row batch.
        if isinstance(output, tuple):
            output = output[0]
        return output
2023-06-08 06:51:52 -06:00
class Linear8bitLt(nn.Module):
    """bitsandbytes 8-bit linear layer built from an already-loaded dense weight.

    Args:
        weight: dense weight tensor; its ``.data`` is wrapped in ``Int8Params``
            and ``.cuda(weight.device)`` is called on the result (presumably
            where bitsandbytes quantizes it — confirm against bnb).
        bias: optional bias tensor, stored as-is; cast in place to the input
            dtype in ``forward`` whenever the dtypes differ.
        has_fp16_weights: forwarded to ``Int8Params`` (also as ``requires_grad``)
            and stored on the matmul state.
        memory_efficient_backward: must be falsy; a truthy value fails the assert.
        threshold: outlier threshold stored on the matmul state.
        index: stored on the module; not used inside this class.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        # bnb's pool option is only enabled for a positive threshold with int8-only weights.
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move ``CB``/``SCB`` from the ``Int8Params`` weight onto ``self.state`` and clear them on the weight."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Run ``bnb.matmul`` on ``x``.

        On the first call (while ``self.weight.CB`` is still set) the quantized
        buffers are handed over to ``self.state``. After a matmul that produced
        ``state.CxB`` (int8-only weights), the row-major ``CB`` is dropped and
        the weight data is replaced by ``CxB``.
        """
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
2023-05-15 09:30:47 -06:00
2023-08-03 15:00:59 -06:00
class Linear4bit(nn.Module):
    """bitsandbytes 4-bit linear layer (``quant_type`` such as ``"fp4"`` or ``"nf4"``).

    Args:
        weight: dense weight tensor; its ``.data`` is wrapped in ``Params4bit``
            and ``.cuda(weight.device)`` is called on it (presumably where bnb
            performs the 4-bit quantization — confirm against bnb).
        bias: optional bias tensor, cast in place to the input dtype in forward.
        quant_type: forwarded to ``Params4bit``.
    """

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        packed = Params4bit(
            weight.data,
            quant_type=quant_type,
            compress_statistics=True,
            requires_grad=False,
        )
        self.weight = packed
        # Never changed inside this class, so inputs normally keep their dtype.
        self.compute_dtype = None
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # Params4bit handles the weight; only the bias needs a manual dtype cast.
        current_bias = self.bias
        if current_bias is not None and current_bias.dtype != x.dtype:
            current_bias.data = current_bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )

        original_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        matmul_bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
        result = bnb.matmul_4bit(
            x, self.weight.t(), bias=matmul_bias, quant_state=self.weight.quant_state
        )
        # Hand the result back in the caller's dtype.
        return result.to(original_dtype)
2023-09-27 03:42:57 -06:00
@lru_cache(1)
def warn_deprecate_bnb():
    """Log a warning that bitsandbytes 8-bit quantization is deprecated in favour of EETQ.

    ``lru_cache(1)`` on this zero-argument function means the body (and thus the
    warning) runs only on the first call; later calls return the cached ``None``.
    """
    logger.warning(
        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
    )
2023-09-27 03:42:57 -06:00
2023-06-08 06:51:52 -06:00
def get_linear ( weight , bias , quantize ) :
if quantize is None :
linear = FastLinear ( weight , bias )
2023-09-27 03:42:57 -06:00
elif quantize == " eetq " :
if HAS_EETQ :
linear = EETQLinear ( weight , bias )
else :
2023-09-27 04:22:09 -06:00
raise ImportError (
" Please install EETQ from https://github.com/NetEase-FuXi/EETQ "
)
2024-04-12 00:13:30 -06:00
elif quantize == " fp8 " :
linear = Fp8Linear ( weight , bias )
2023-06-08 06:51:52 -06:00
elif quantize == " bitsandbytes " :
2023-09-27 03:42:57 -06:00
warn_deprecate_bnb ( )
2023-06-08 06:51:52 -06:00
linear = Linear8bitLt (
weight ,
bias ,
has_fp16_weights = False ,
threshold = 6.0 ,
)
if bias is not None :
linear . bias = nn . Parameter ( bias )
2023-08-03 15:00:59 -06:00
elif quantize == " bitsandbytes-fp4 " :
linear = Linear4bit (
weight ,
bias ,
quant_type = " fp4 " ,
)
elif quantize == " bitsandbytes-nf4 " :
linear = Linear4bit (
weight ,
bias ,
quant_type = " nf4 " ,
)
2023-06-08 06:51:52 -06:00
elif quantize == " gptq " :
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
try :
2023-07-21 02:59:00 -06:00
qweight , qzeros , scales , g_idx , bits , groupsize , use_exllama = weight
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independant of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
except Exception :
raise NotImplementedError (
f " The passed weight is not `gptq` compatible, loader needs to be updated. "
)
2023-07-21 02:59:00 -06:00
if use_exllama :
2023-12-11 06:49:52 -07:00
linear = ExllamaQuantLinear (
qweight , qzeros , scales , g_idx , bias , bits , groupsize
)
2023-07-21 02:59:00 -06:00
else :
linear = QuantLinear (
qweight ,
qzeros ,
scales ,
g_idx ,
bias ,
bits ,
groupsize ,
)
Add AWQ quantization inference support (#1019) (#1054)
# Add AWQ quantization inference support
Fixes
https://github.com/huggingface/text-generation-inference/issues/781
This PR (partially) adds support for AWQ quantization for inference.
More information on AWQ [here](https://arxiv.org/abs/2306.00978). In
general, AWQ is faster and more accurate than GPTQ, which is currently
supported by TGI.
This PR installs 4-bit GEMM custom CUDA kernels released by AWQ authors
(in `requirements.txt`, just one line change).
Quick way to test this PR would be bring up TGI as follows:
```
text-generation-server download-weights abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq
text-generation-launcher \
--huggingface-hub-cache ~/.cache/huggingface/hub/ \
--model-id abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq \
--trust-remote-code --port 8080 \
--max-input-length 2048 --max-total-tokens 4096 --max-batch-prefill-tokens 4096 \
--quantize awq
```
Please note:
* This PR was tested with FlashAttention v2 and vLLM.
* This PR adds support for AWQ inference, not quantizing the models.
That needs to be done outside of TGI, instructions
[here](https://github.com/mit-han-lab/llm-awq/tree/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa).
* This PR only adds support for `FlashLlama` models for now.
* Multi-GPU setup has not been tested.
* No integration tests have been added so far, will add later if
maintainers are interested in this change.
* This PR can be tested on any of the models released
[here](https://huggingface.co/abhinavkulkarni?sort_models=downloads#models).
Please refer to the linked issue for benchmarks for
[abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq](https://huggingface.co/abhinavkulkarni/meta-llama-Llama-2-7b-chat-hf-w4-g128-awq)
vs
[TheBloke/Llama-2-7b-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GPTQ).
Please note, AWQ has released faster (and in case of Llama, fused)
kernels for 4-bit GEMM, currently at the top of the `main` branch at
https://github.com/mit-han-lab/llm-awq, but this PR uses an older commit
that has been tested to work. We can switch to latest commit later on.
## Who can review?
@OlivierDehaene OR @Narsil
---------
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Abhinav M Kulkarni <abhinavkulkarni@gmail.com>
Co-authored-by: Abhinav Kulkarni <abhinav@concentric.ai>
2023-09-25 07:31:27 -06:00
elif quantize == " awq " :
try :
qweight , qzeros , scales , _ , bits , groupsize , _ = weight
except Exception :
raise NotImplementedError (
f " The passed weight is not `awq` compatible, loader needs to be updated. "
)
2024-02-09 02:45:16 -07:00
if IS_ROCM_SYSTEM :
raise NotImplementedError (
" AWQ GEMM kernel can ' t be used on ROCm systems, please use `--quantize gptq` instead "
" to use Exllama/GPTQ kernels for AWQ inference. "
)
if not HAS_AWQ :
2024-02-16 03:58:58 -07:00
raise NotImplementedError (
" You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly "
)
2023-09-27 04:22:09 -06:00
linear = WQLinear (
w_bit = bits ,
group_size = groupsize ,
qweight = qweight ,
qzeros = qzeros ,
scales = scales ,
bias = bias is not None ,
)
2023-06-08 06:51:52 -06:00
else :
raise NotImplementedError ( f " Quantization ` { quantize } ` is not implemented yet. " )
return linear
class SuperLayer(nn.Module):
    """Thin module wrapper that delegates ``forward`` to a wrapped linear layer."""

    def __init__(self, linear):
        super().__init__()
        # The wrapped layer; it is registered as a submodule when it is an nn.Module.
        self.linear = linear

    def forward(self, x):
        # Call ``.forward`` directly, as the original did, rather than ``__call__``.
        wrapped = self.linear
        return wrapped.forward(x)
2024-02-26 11:49:28 -07:00
class ResBlock(torch.nn.Module):
    """Residual block computing ``x + SiLU(linear(x))``.

    The linear layer (with bias) is loaded from ``f"{prefix}.linear"``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        linear_prefix = f"{prefix}.linear"
        self.linear = FastLinear.load(
            config, prefix=linear_prefix, weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        skip = x
        update = self.act(self.linear(x))
        return skip + update
class MedusaModel(torch.nn.Module):
    """Collection of Medusa heads applied to the same input.

    ``config`` is indexed like a dict; ``config["medusa_num_heads"]`` heads are
    built, and head ``i`` is loaded from prefix ``f"{i}"``.
    """

    def __init__(self, config, weights):
        super().__init__()
        head_count = config["medusa_num_heads"]
        built = [
            MedusaHead(config, prefix=f"{idx}", weights=weights)
            for idx in range(head_count)
        ]
        self.heads = torch.nn.ModuleList(built)

    def forward(self, x):
        # Each head's output is stacked along a new dim 1 (one slot per head).
        outputs = [head(x) for head in self.heads]
        return torch.stack(outputs, dim=1)
class MedusaHead(torch.nn.Module):
    """A single Medusa head: a chain of ``ResBlock`` layers then a bias-free projection.

    Block ``i`` is loaded from ``f"{prefix}.{i}"``. The output projection is
    loaded from the next index, ``f"{prefix}.{n}"``, where ``n`` is the number
    of blocks.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList()
        for layer_idx in range(config["medusa_num_layers"]):
            self.blocks.append(
                ResBlock(config, prefix=f"{prefix}.{layer_idx}", weights=weights)
            )
        out_index = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{out_index}", weights=weights, bias=False
        )

    def forward(self, x):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.out(hidden)
class SpeculativeHead(nn.Module):
    """LM head optionally paired with Medusa speculative-decoding heads.

    ``forward`` returns ``(logits, speculative_logits)``. ``speculative_logits``
    is ``None`` when no Medusa model is attached.
    """

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Load the LM head and, when ``config.use_medusa`` is truthy, the Medusa heads.

        ``config.use_medusa`` is treated as a directory path. It must contain
        ``config.json`` (the Medusa config, passed to ``MedusaModel``) and
        ``medusa_lm_head.safetensors``. Every tensor name in that file is
        registered in ``weights.routing`` so the Medusa weights can be fetched
        through ``weights``.

        Raises:
            RuntimeError: if a tensor name in the Medusa file is already
                routed to another file.
        """
        lm_head = TensorParallelHead.load(config, prefix, weights)
        use_medusa = config.use_medusa
        if not use_medusa:
            return SpeculativeHead(lm_head, None)

        from pathlib import Path
        from safetensors import safe_open
        import json

        medusa_dir = Path(use_medusa)
        medusa_config_path = str(medusa_dir / "config.json")
        filename = str(medusa_dir / "medusa_lm_head.safetensors")

        # Use a dedicated name so the model ``config`` parameter is not shadowed.
        # JSON is UTF-8 by specification, so do not rely on the locale default.
        with open(medusa_config_path, "r", encoding="utf-8") as f:
            medusa_config = json.load(f)

        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                # Refuse ambiguous routing: a tensor name must map to exactly one file.
                if k in routing:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(medusa_config, weights)
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        speculative_logits = self.medusa(input) if self.medusa is not None else None
        return logits, speculative_logits
2023-06-08 06:51:52 -06:00
class TensorParallelHead(SuperLayer):
    """Output head whose weight may be sharded along dim 0 across a process group.

    When ``should_gather`` is True, each rank computes logits for its own
    slice of output rows. ``forward`` then all-gathers the slices so every
    rank returns the full output.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """Load ``f"{prefix}.weight"``, sharding it on dim 0 when possible.

        The full tensor is loaded, with no gather needed, in two cases: the
        process group has a single rank, or ``weights.get_sharded`` raises
        ``AssertionError``. In the second case the vocab size is not divisible
        by the number of shards. Heads are never quantized for ``gptq``,
        ``awq`` or ``eetq``.
        """
        weight_name = f"{prefix}.weight"
        should_gather = False
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(weight_name, dim=0)
                should_gather = True
            except AssertionError:
                # Vocab size not divisible by the number of shards: load it whole.
                weight = weights.get_tensor(weight_name)
        else:
            weight = weights.get_tensor(weight_name)

        # GPTQ, AWQ and EETQ don't quantize heads (nor embeddings).
        if config.quantize in ("gptq", "awq", "eetq"):
            quantize = None
        else:
            quantize = config.quantize

        linear = get_linear(weight, bias=None, quantize=quantize)
        return TensorParallelHead(
            linear,
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()

        if input.dim() == 2 and isinstance(self.linear, FastLinear):
            # Fast path: matmul straight into a buffer that can be gathered directly.
            # Note: this path uses only ``self.linear.weight``; no bias is added.
            local_rows = self.linear.weight.shape[0]
            n_tokens = input.shape[0]
            single_token = n_tokens == 1
            if single_token:
                world_out = input.new_empty(1, local_rows * world_size)
                local_out = input.new_empty(1, local_rows)
                gather_input = local_out
            else:
                # Gather into a (rows, tokens) layout. The matmul writes
                # (tokens, rows) results through a transposed view of the
                # same storage.
                world_out = input.new_empty(local_rows * world_size, n_tokens)
                gather_input = input.new_empty(local_rows, n_tokens)
                local_out = gather_input.T
            torch.mm(input, self.linear.weight.T, out=local_out)
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )
            return world_out if single_token else world_out.T

        # Generic path: gather each rank's output, then concatenate on the last dim.
        local_logits = super().forward(input)
        gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, local_logits, group=self.process_group)
        return torch.cat(gathered, dim=-1)
class TensorParallelColumnLinear ( SuperLayer ) :
@classmethod
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, these changes are mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
def load_qkv ( cls , config , prefix : str , weights , bias : bool ) :
""" Specific method when the QKV was joined after the fact """
2023-09-27 04:22:09 -06:00
weight = weights . get_weights_col_packed_qkv ( prefix , quantize = config . quantize )
fit for baichuan models (#981)
As more and more people begin to use Baichuan's open-source models, the
influence of Baichuan models is growing, especially in China. Many
community members are interested in adding support for Baichuan models
to TGI. Meanwhile, Baichuan is a very open company, and in the future,
it plans to open-source more and more models, taking all this into
consideration, we would like to add support for the Baichuan model to
TGI. To do this, we need to make some changes, which we hope can be
merged into the main branch of TGI. In the future, we would be happy to
help maintain support for Baichuan models in TGI. We sincerely hope that
our pull request can be accepted. Thank you.
By the way, these changes are mainly for supporting Baichuan-7B.
---------
Co-authored-by: xiaoyuze <xiaoyuze@baichuan.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2023-09-08 08:51:34 -06:00
if bias :
raise NotImplementedError ( " packed_qkv only implemented for baichuan " )
else :
bias = None
linear = get_linear ( weight , bias , config . quantize )
return cls ( linear )
@classmethod
2023-06-08 06:51:52 -06:00
def load ( cls , config , prefix : str , weights , bias : bool ) :
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
return cls . load_multi ( config , [ prefix ] , weights , bias , dim = 0 )
2023-05-15 09:30:47 -06:00
2023-06-08 06:51:52 -06:00
@classmethod
def load_multi ( cls , config , prefixes : List [ str ] , weights , bias : bool , dim : int ) :
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
weight = weights . get_multi_weights_col (
prefixes , quantize = config . quantize , dim = dim
)
2023-05-15 09:30:47 -06:00
2023-06-08 06:51:52 -06:00
if bias :
b = [ weights . get_sharded ( f " { p } .bias " , dim = 0 ) for p in prefixes ]
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently it means that every place we use `get_{tensor|sharded}` to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
bias = torch . cat ( b , dim = dim )
2023-06-08 06:51:52 -06:00
else :
bias = None
feat(server): Add inference support for GPTQ (llama + falcon tested) + Quantization script (#438)
Let's start discussing implementation.
- Need to expose the quantization scripts (either included here or add
doc on how to use https://github.com/qwopqwop200/GPTQ-for-LLaMa)
- Make sure GPTQ works for multiple models (priority to Falcon).
Currently this means that every place where we use `get_{tensor|sharded}` has to
check for quantization.
My idea is to reintegrate as much as possible into `utils/layer.py` by
expanding `load_multi` to be a bit more generic.
This might require some thinking, but ultimately the
`qweight,qzeros,scales,g_idx` should be in a single place, and
independent of bias presence.
# What does this PR do?
<!--
Congratulations! You've made it this far! You're not quite done yet
though.
Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.
Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.
Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
<!-- Remove if not applicable -->
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @
@OlivierDehaene OR @Narsil
-->
---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
2023-06-26 04:27:01 -06:00
linear = get_linear ( weight , bias , config . quantize )
return cls ( linear )
2023-05-15 09:30:47 -06:00
2023-06-08 06:51:52 -06:00
class TensorParallelRowLinear(SuperLayer):
    """Linear layer whose per-rank outputs are summed across a process group.

    This is the row-parallel counterpart of a column-parallel linear. The
    weight is loaded with ``weights.get_multi_weights_row``, so presumably
    each rank holds a slice of the input dimension. ``forward`` reduces the
    partial outputs with ``torch.distributed.all_reduce``.
    """

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from the tensors stored under ``prefix``.

        Args:
            config: model config; only ``config.quantize`` is read.
            prefix: weight name prefix; ``f"{prefix}.bias"`` is read when
                ``bias`` is true (on rank 0 only).
            weights: loader providing ``get_multi_weights_row``,
                ``get_tensor`` and ``process_group``.
            bias: whether the layer has a bias term.

        Returns:
            A ``TensorParallelRowLinear`` bound to ``weights.process_group``.
        """
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        # The bias is loaded on the first rank only. forward() sums the
        # per-rank outputs with all_reduce, so a bias present on every rank
        # would be added world_size times.
        if bias and weights.process_group.rank() == 0:
            bias_tensor = weights.get_tensor(f"{prefix}.bias")
        else:
            bias_tensor = None

        return cls(
            get_linear(weight, bias_tensor, config.quantize),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """Apply the local shard, then sum the outputs across ranks.

        Args:
            input: activations fed to this rank's shard.
            reduce: when False, skip the all_reduce and return this rank's
                partial output. The caller is then responsible for summing;
                note that the bias is only included on rank 0.

        Returns:
            The (reduced, unless ``reduce`` is False) output tensor.
        """
        out = super().forward(input)
        if self.process_group.size() > 1 and reduce:
            # In-place sum across the group; every rank ends with the total.
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
2023-05-15 09:30:47 -06:00
2023-06-08 06:51:52 -06:00
class TensorParallelEmbedding(nn.Module):
    """Embedding table split into contiguous row blocks across a process group.

    Each rank stores rows ``[min_id, max_id)`` of the vocabulary, plus one
    extra all-zero row. Ids owned by another rank are looked up in that zero
    row. When ``reduce`` is true, the per-rank results are summed with
    ``all_reduce``.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        shard = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        vocab_size = weights.get_shape(f"{prefix}.weight")[0]

        group = weights.process_group
        n_ranks = group.size()
        my_rank = group.rank()

        # Ceil division, so the last rank may own fewer rows than the others.
        rows_per_rank = (vocab_size + n_ranks - 1) // n_ranks
        self.min_id = my_rank * rows_per_rank
        self.max_id = min(vocab_size, (my_rank + 1) * rows_per_rank)

        # Position of the padding row appended below. This is usually
        # rows_per_rank, but can be smaller when the vocabulary does not
        # split evenly.
        self.null_idx = shard.shape[0]

        self.process_group = group
        self.reduce = reduce

        # Append one all-zero row that out-of-shard ids are redirected to.
        self.weight = nn.Parameter(F.pad(shard, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Embed ``input`` ids, returning zeros for ids outside this shard."""
        # Shift ids into local coordinates [0, max_id - min_id); send every
        # id owned by another rank to the zero row at ``null_idx``.
        out_of_shard = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(out_of_shard, self.null_idx, input - self.min_id)

        out = F.embedding(local_ids, self.weight)

        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out
try:
    if IS_CUDA_SYSTEM:
        import dropout_layer_norm
    elif IS_ROCM_SYSTEM:
        from vllm import layernorm_ops
    else:
        # Neither CUDA nor ROCm: no fused normalization kernel is available.
        dropout_layer_norm = None

    class FastLayerNorm(nn.LayerNorm):
        """``nn.LayerNorm`` that also folds in an optional residual add.

        ``forward(hidden_states, residual)`` returns ``(normed, residual)``.
        The returned ``residual`` is the pre-norm input: ``hidden_states``
        itself, or ``hidden_states + residual`` when a residual is given.
        """

        def forward(self, hidden_states, residual=None):
            # Take the plain PyTorch path when:
            #   - on ROCm, where the flash-attn fused kernel is not used;
            #   - no fused kernel was imported (``dropout_layer_norm`` is None),
            #     where calling into it would raise AttributeError;
            #   - the hidden size exceeds 8192, which the fused path is not
            #     used for.
            # IS_ROCM_SYSTEM is tested first because ``dropout_layer_norm`` is
            # never bound on ROCm.
            if (
                IS_ROCM_SYSTEM
                or dropout_layer_norm is None
                or hidden_states.shape[-1] > 8192
            ):
                if residual is not None:
                    # In-place add: the caller's ``hidden_states`` is modified.
                    hidden_states += residual
                residual = hidden_states

                return super().forward(hidden_states), residual

            # Fused residual-add + LayerNorm from flash-attn.
            # NOTE(review): the positional arguments mirror the kernel
            # signature; 0.0 is presumably the dropout probability — confirm.
            (
                normed_hidden_states,
                residual,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual

    class FastRMSNorm(nn.Module):
        """RMSNorm with an optional residual add, dispatched per platform.

        ``forward(hidden_states, residual)`` returns ``(normed, residual)``.
        The returned ``residual`` is the pre-norm input: ``hidden_states``
        itself, or ``hidden_states + residual`` when a residual is given.
        """

        def __init__(self, weight: torch.Tensor, eps: float):
            super().__init__()
            self.weight = nn.Parameter(weight)
            self.variance_epsilon = eps

        @classmethod
        def load(cls, prefix, weights, eps=1e-6):
            """Build the norm from the ``f"{prefix}.weight"`` tensor of ``weights``."""
            weight = weights.get_tensor(f"{prefix}.weight")
            return cls(weight, eps)

        def forward(self, hidden_states, residual=None):
            """Normalize ``hidden_states`` (plus ``residual`` if given).

            Raises:
                ValueError: for hidden sizes <= 8192 on a system that is
                    neither CUDA nor ROCm.
            """
            if hidden_states.shape[-1] > 8192:
                # Plain PyTorch path for large hidden sizes on every platform.
                if residual is not None:
                    # In-place add: the caller's ``hidden_states`` is modified.
                    hidden_states += residual
                residual = hidden_states

                # Compute the statistics in float32 for accuracy.
                hidden_states = hidden_states.to(torch.float32)
                variance = hidden_states.pow(2).mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(
                    variance + self.variance_epsilon
                )

                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)

                return self.weight * hidden_states, residual
            elif IS_CUDA_SYSTEM:
                # faster post attention rms norm
                (
                    normed_hidden_states,
                    res,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    None,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.variance_epsilon,
                    1.0,
                    0,
                    None,
                    False,
                    True,  # Activate RMSNorm
                )
                if res is None:
                    res = hidden_states

                return normed_hidden_states, res
            elif IS_ROCM_SYSTEM:
                # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                out = torch.empty_like(hidden_states)
                # The kernel writes its result into ``out``.
                layernorm_ops.rms_norm(
                    out,
                    hidden_states,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return out, residual
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )

except ImportError:
    pass
try:
    if IS_CUDA_SYSTEM:
        from flash_attn.layers.rotary import RotaryEmbedding
        import rotary_emb
    elif IS_ROCM_SYSTEM:
        from vllm import pos_encoding_ops

    import math

    def _create_inv_freq(dim, base, device):
        """Return the RoPE inverse frequencies ``base ** (-2i / dim)``.

        ``i`` runs over ``0 .. ceil(dim / 2) - 1``. The result is a 1-D
        float32 tensor on ``device``.
        """
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        return inv_freq

    def _get_rope_config(config):
        """Return the rope-scaling settings, preferring environment overrides.

        If ``ROPE_SCALING`` is set, returns
        ``{"type": $ROPE_SCALING, "factor": float($ROPE_FACTOR)}``.
        ``ROPE_FACTOR`` must then also be set, otherwise ``KeyError`` is
        raised. Otherwise returns ``config.rope_scaling``, or ``None`` when the
        config has no such attribute.
        """
        if os.getenv("ROPE_SCALING", None) is not None:
            rope_scaling = {
                "type": os.environ["ROPE_SCALING"],
                "factor": float(os.environ["ROPE_FACTOR"]),
            }
            return rope_scaling
        return getattr(config, "rope_scaling", None)

    class PositionRotaryEmbedding(nn.Module):
        """Rotary position embedding with a lazily grown cos/sin cache.

        ``scaling_factor`` (when not None) divides positions before the
        cos/sin tables are built; this is "linear" rope scaling.
        """

        def __init__(self, inv_freq, scaling_factor):
            super().__init__()

            self.inv_freq = inv_freq
            self._seq_len_cached = 0
            self._cos_cached = None
            self._sin_cached = None
            self._cos_k_cached = None
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None

        def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            cos: torch.Tensor,
            sin: torch.Tensor,
        ):
            """Rotate ``query`` and ``key`` in place; returns ``None``.

            ``cos``/``sin`` are expected to come from ``get_cos_sin``.

            Raises:
                ValueError: on a system that is neither CUDA nor ROCm.
            """
            # Such controlflows may add some overhead.
            if IS_CUDA_SYSTEM:
                # The two halves of the rotated channels are passed to the
                # kernel as views. They are passed again as the trailing tensor
                # arguments (presumably the outputs), so the results are
                # written back into ``query`` / ``key``.
                rotary_dim = cos.shape[-1]
                q1 = query[..., :rotary_dim]
                q2 = query[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

                k1 = key[..., :rotary_dim]
                k2 = key[..., rotary_dim : 2 * rotary_dim]

                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
            elif IS_ROCM_SYSTEM:
                # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
                # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773

                head_size = query.shape[-1]

                # Inplace operation, updating query and key.
                pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
            else:
                raise ValueError(
                    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
                )

        @classmethod
        def static(cls, config, dim, base, device):
            """Build the embedding from ``dim`` and ``base`` (no checkpoint tensor).

            Returns a ``DynamicPositionRotaryEmbedding`` or a
            ``YarnPositionRotaryEmbedding`` when the rope-scaling settings ask
            for it. ``"linear"`` scaling keeps this class.

            Raises:
                NotImplementedError: for an unknown rope scaling type.
            """
            inv_freq = _create_inv_freq(dim, base, device)
            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=dim,
                        max_position_embeddings=config.max_position_embeddings,
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        # Use the caller's base, as the "dynamic" branch does.
                        # A hard-coded 10000.0 here gave wrong frequencies for
                        # models with a different rope base.
                        base=base,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        @classmethod
        def load(cls, config, prefix, weights):
            """Build the embedding from the checkpoint tensor ``f"{prefix}.inv_freq"``.

            Raises:
                NotImplementedError: for an unknown rope scaling type.
            """
            # XXX: Always load this in float32 !
            dtype = weights.dtype
            weights.dtype = torch.float32
            try:
                inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
            finally:
                # Restore the loader's dtype even if the tensor cannot be read.
                weights.dtype = dtype

            scaling_factor = None
            rope_scaling = _get_rope_config(config)
            if rope_scaling is not None:
                scaling_factor = rope_scaling["factor"]
                # NOTE(review): only inv_freq is stored in the checkpoint, so
                # the scaled variants below assume a base of 10000.0 — confirm
                # for models with a different rope base.
                if rope_scaling["type"] == "linear":
                    pass
                elif rope_scaling["type"] == "dynamic":
                    return DynamicPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=config.max_position_embeddings,
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                    )
                elif rope_scaling["type"] == "yarn":
                    return YarnPositionRotaryEmbedding(
                        dim=2 * inv_freq.shape[0],
                        max_position_embeddings=rope_scaling[
                            "original_max_position_embeddings"
                        ],
                        base=10000.0,
                        device=inv_freq.device,
                        scaling_factor=scaling_factor,
                        extrapolation_factor=1,
                        attn_factor=1,
                        beta_fast=32,
                        beta_slow=1,
                    )
                else:
                    raise NotImplementedError(
                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                    )
            return cls(inv_freq, scaling_factor)

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            """(Re)build the cos/sin tables for positions ``0 .. seqlen - 1``."""
            # Reset the tables if nothing is cached yet, if the sequence length
            # has grown, or if we're on a new device/dtype (possibly due to
            # tracing for instance). The None check also covers a first call
            # with seqlen == 0, which previously dereferenced None.
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                if self.scaling_factor is not None:
                    t /= self.scaling_factor
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

        def get_cos_sin(
            self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
        ):
            """
            Return cos and sin for the asked position ids
            """
            if IS_ROCM_SYSTEM:
                # For RoCm, we always use float cos/sin to avoid a cast.
                # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
                # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                dtype = torch.float32

            self._update_cos_sin_cache(dtype, position_ids.device, max_s)

            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
            return cos.unsqueeze(1), sin.unsqueeze(1)

    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
        """Dynamic NTK rope scaling.

        When the requested length exceeds ``max_position_embeddings``, the
        rope base is enlarged and ``inv_freq`` is recomputed.
        """

        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            """Rebuild the tables, first rescaling the base for long sequences."""
            # Reset the tables if nothing is cached yet, if the sequence length
            # has grown, or if we're on a new device/dtype (possibly due to
            # tracing for instance).
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    newbase = self.base * (
                        (self.scaling_factor * seqlen / self.max_position_embeddings)
                        - (self.scaling_factor - 1)
                    ) ** (self.dim / (self.dim - 2))
                    self.inv_freq = _create_inv_freq(
                        self.dim, newbase, self.inv_freq.device
                    )
                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)

    # Inverse dim formula to find dim based on number of rotations
    def find_correction_dim(
        num_rotations, dim, base=10000, max_position_embeddings=2048
    ):
        return (
            dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
        ) / (2 * math.log(base))

    # Find dim range bounds based on rotations
    def find_correction_range(
        low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
    ):
        low = math.floor(
            find_correction_dim(low_rot, dim, base, max_position_embeddings)
        )
        high = math.ceil(
            find_correction_dim(high_rot, dim, base, max_position_embeddings)
        )
        return max(low, 0), min(high, dim - 1)  # Clamp values just in case

    def linear_ramp_mask(min, max, dim):
        """Return a float32 ramp over ``dim`` entries: 0 up to ``min``, 1 from ``max``.

        The parameters shadow the builtins ``min``/``max``; the names are kept
        for keyword compatibility.
        """
        if min == max:
            max += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    def get_mscale(scale=1):
        """Return the YaRN magnitude correction ``0.1 * ln(scale) + 1`` (1.0 if scale <= 1)."""
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0

    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
        """YaRN rope scaling.

        For long sequences, interpolated and extrapolated frequencies are
        blended per dimension, and the cos/sin tables are scaled by ``mscale``.
        """

        def __init__(
            self,
            dim,
            max_position_embeddings,
            base,
            device,
            scaling_factor,
            *,
            extrapolation_factor,
            attn_factor,
            beta_fast,
            beta_slow,
        ):
            inv_freq = _create_inv_freq(dim, base, device)
            super().__init__(inv_freq, scaling_factor)
            self.dim = dim
            self.max_position_embeddings = max_position_embeddings
            self.base = base
            self.extrapolation_factor = extrapolation_factor
            self.attn_factor = attn_factor
            self.beta_fast = beta_fast
            self.beta_slow = beta_slow
            self.mscale = float(
                get_mscale(self.scaling_factor) * self.attn_factor
            )  # Get n-d magnitude scaling corrected for interpolation

        def _update_cos_sin_cache(self, dtype, device, seqlen):
            """Rebuild the tables, switching to YaRN frequencies for long sequences."""
            # Reset the tables if nothing is cached yet, if the sequence length
            # has grown, or if we're on a new device/dtype (possibly due to
            # tracing for instance).
            if (
                self._cos_cached is None
                or seqlen > self._seq_len_cached
                or self._cos_cached.device != device
                or self._cos_cached.dtype != dtype
            ):
                if seqlen > self.max_position_embeddings:
                    inv_freq_extrapolation = _create_inv_freq(
                        self.dim, self.base, self.inv_freq.device
                    )
                    freqs = 1.0 / inv_freq_extrapolation
                    inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
                    low, high = find_correction_range(
                        self.beta_fast,
                        self.beta_slow,
                        self.dim,
                        self.base,
                        self.max_position_embeddings,
                    )
                    inv_freq_mask = (
                        1
                        - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
                    ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
                    inv_freq = (
                        inv_freq_interpolation * (1 - inv_freq_mask)
                        + inv_freq_extrapolation * inv_freq_mask
                    )

                    self.inv_freq = inv_freq
                    self.mscale = float(
                        get_mscale(self.scaling_factor) * self.attn_factor
                    )  # Get n-d magnitude scaling corrected for interpolation

                self._seq_len_cached = seqlen
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                # Don't do einsum, it converts fp32 to fp16
                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

except ImportError:
    pass