import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
    SYSTEM,
)

# Monkey patching: attach checkpoint-loading helpers to torch.nn.LayerNorm.
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = torch.nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = torch.nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

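# Illustrative usage of the patched loaders (hypothetical prefixes; `weights`
# is any object exposing `get_tensor(name) -> torch.Tensor`, as used above):
#
#     ln = nn.LayerNorm.load(prefix="model.layers.0.ln_1", weights=weights, eps=1e-5)
#     ln_final = nn.LayerNorm.load_no_bias(prefix="model.ln_f", weights=weights, eps=1e-5)
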
if SYSTEM == "cuda":
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual
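
    # The fused dropout_add_ln_fwd call above is used in place of the eager
    # fallback (residual addition followed by LayerNorm). A plain-PyTorch
    # sketch of that same computation, for reference only:
    #
    #     added = hidden_states if residual is None else hidden_states + residual
    #     normed = torch.nn.functional.layer_norm(
    #         added, self.normalized_shape, self.weight, self.bias, self.eps
    #     )
    #     return normed, added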

elif SYSTEM == "rocm":
    from vllm._C import ops

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super().forward(hidden_states), residual

elif SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            out = ipex.llm.functional.add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.eps,
                residual is not None,
            )
            return out, residual if residual is not None else hidden_states


class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        if SYSTEM == "ipex":
            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                residual is not None,
            )
            return out, residual if residual is not None else hidden_states
        elif hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
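            # The branch above is the eager reference path: RMSNorm computes
            # weight * x * rsqrt(mean(x**2, dim=-1) + eps), with the reduction
            # done in float32 before casting back to the weight dtype.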
        elif SYSTEM == "cuda":
            # Faster post-attention RMSNorm via the fused kernel.
            (
                normed_hidden_states,
                res,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
        elif SYSTEM == "rocm":
            # We use the vLLM RMSNorm kernel, which can be compiled for ROCm,
            # instead of the Flash Attention ones, which cannot.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            out = torch.empty_like(hidden_states)
            ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return out, residual
        else:
            raise ValueError(
                "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )
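

# Illustrative usage of FastRMSNorm (hypothetical prefix; `weights` stands for
# any loader object exposing `get_tensor(name) -> torch.Tensor`, as assumed by
# the `load` classmethods above):
#
#     norm = FastRMSNorm.load(prefix="model.layers.0.input_layernorm", weights=weights, eps=1e-6)
#     normed, residual = norm(hidden_states)                    # first call: residual starts as None
#     normed, residual = norm(next_hidden_states, residual)     # later calls thread the residual through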