2024-05-13 04:44:30 -06:00
import os
import torch
from torch import nn
2024-07-19 09:23:20 -06:00
from loguru import logger
2024-05-13 04:44:30 -06:00
2024-06-25 05:20:57 -06:00
from text_generation_server . utils . import_utils import SYSTEM
2024-05-13 04:44:30 -06:00
if SYSTEM == " cuda " :
from flash_attn . layers . rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == " rocm " :
MI300 compatibility (#1764)
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
https://github.com/ROCm/vllm/commit/3489ce7936c5de588916ae3047c44c23c0b0c308
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that
the tuning does not have to be rerun on each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
2024-05-17 07:30:47 -06:00
from vllm . _C import ops
2024-06-25 05:20:57 -06:00
elif SYSTEM == " ipex " :
2024-05-23 06:11:08 -06:00
import intel_extension_for_pytorch as ipex
2024-05-13 04:44:30 -06:00
def _create_inv_freq ( dim , base , device ) :
inv_freq = 1.0 / (
base * * ( torch . arange ( 0 , dim , 2 , device = device , dtype = torch . float32 ) / dim )
)
return inv_freq
def _get_rope_config ( config ) :
if os . getenv ( " ROPE_SCALING " , None ) is not None :
rope_scaling = {
" type " : os . environ [ " ROPE_SCALING " ] ,
" factor " : float ( os . environ [ " ROPE_FACTOR " ] ) ,
}
return rope_scaling
return getattr ( config , " rope_scaling " , None )
class PositionRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied in place to query/key tensors.

    Inverse frequencies are stored in ``inv_freq``. ``_update_cos_sin_cache``
    builds the cos/sin lookup tables lazily, and ``get_cos_sin`` indexes them
    by position. ``forward`` dispatches to a backend kernel chosen by
    ``SYSTEM`` (``cuda``, ``rocm`` or ``ipex``).
    """

    def __init__(self, inv_freq, scaling_factor):
        """
        Args:
            inv_freq: 1-D tensor of inverse frequencies. ``torch.outer`` in
                ``_update_cos_sin_cache`` requires it to be 1-D.
            scaling_factor: ``None``, or a linear position-scaling factor.
                When set, positions are divided by it before cos/sin are
                computed.
        """
        super().__init__()

        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.scaling_factor = scaling_factor
        self.dynamic_args = None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Rotate ``query`` and ``key`` with the precomputed ``cos``/``sin`` tables.

        Nothing is returned. The CUDA path passes views of ``query``/``key`` as
        the kernel outputs, and the ROCm kernel call is documented as in-place.
        Callers therefore read the rotated values from the same tensors.

        Raises:
            ValueError: if ``SYSTEM`` is not ``cuda``, ``rocm`` or ``ipex``.
        """
        # Such controlflows may add some overhead.
        if SYSTEM == "cuda":
            rotary_dim = cos.shape[-1]
            q1 = query[..., :rotary_dim]
            q2 = query[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)

            k1 = key[..., :rotary_dim]
            k2 = key[..., rotary_dim : 2 * rotary_dim]

            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        elif SYSTEM == "rocm":
            # NOTE: On ROCm systems, we use a RoPE implementation adapted from vLLM
            # which launches a single kernel for both query/key, contrary to the
            # flash-attn implementation used on NVIDIA systems.
            # Compiling flash-attn rotary on ROCm, it appears hipcc is unable to
            # unroll loops, resulting in an even slower inference compared to eager:
            # https://github.com/pytorch/pytorch/issues/113773

            head_size = query.shape[-1]

            # Inplace operation, updating query and key.
            ops.rotary_embedding(query, key, head_size, cos, sin, True)
        elif SYSTEM == "ipex":
            ipex.llm.functional.rotary_embedding(
                query, key, sin, cos, query.size(-1), True
            )
        else:
            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )

    @classmethod
    def static(cls, config, dim, base, device):
        """Build a rotary embedding from ``dim``/``base`` plus the config's rope scaling.

        The rope-scaling ``type`` comes from ``_get_rope_config(config)``. No
        scaling and ``"linear"`` return ``cls``. ``"dynamic"`` returns a
        ``DynamicPositionRotaryEmbedding``, ``"yarn"`` a
        ``YarnPositionRotaryEmbedding``, and ``"su"``/``"longrope"`` a
        ``SuRotaryEmbedding``.

        Raises:
            NotImplementedError: for any other rope-scaling type.
        """
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            if rope_scaling["type"] == "linear":
                # Linear scaling divides positions by the factor (applied in
                # `_update_cos_sin_cache`). `load` already honours it; without
                # this assignment the factor would be silently ignored here.
                scaling_factor = rope_scaling["factor"]
            elif rope_scaling["type"] == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                scaling_factor = rope_scaling["factor"]
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            elif rope_scaling["type"] in ["su", "longrope"]:
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    @classmethod
    def load(cls, config, prefix, weights):
        """Build a rotary embedding from a stored ``{prefix}.inv_freq`` tensor.

        The tensor is always read in float32: ``weights.dtype`` is temporarily
        switched and then restored.

        Raises:
            NotImplementedError: for a rope-scaling type other than ``linear``,
                ``dynamic`` or ``yarn``.
        """
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            if rope_scaling["type"] == "linear":
                pass
            elif rope_scaling["type"] == "dynamic":
                # NOTE(review): base is hard-coded to 10000.0 on this path rather
                # than taken from the config — confirm this matches the
                # checkpoints that store inv_freq.
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_scaling["type"] == "yarn":
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """(Re)build the cos/sin tables for positions ``0 .. seqlen - 1`` when needed.

        The tables are rebuilt when none exist yet, when ``seqlen`` exceeds the
        cached length, or when the requested device/dtype differs from the
        cached one. Otherwise the existing (possibly longer) tables are kept.
        """
        # Reset the tables if none exist yet, the sequence length has grown,
        # or we're on a new device/dtype (possibly due to tracing for instance).
        # The explicit `None` check avoids dereferencing an empty cache when
        # the first call has seqlen == 0.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                # Linear position scaling.
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16.
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids.

        Rows of the cached tables are gathered with ``index_select`` along
        dim 0, and a singleton dimension is inserted at dim 1 of each result.
        """
        if SYSTEM == "rocm":
            # For ROCm, we always use float cos/sin to avoid a cast.
            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
            dtype = torch.float32

        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on ROCm + vLLM RoPE implementation,
        # but we leave it as is to avoid yet another controlflow.
        return cos.unsqueeze(1), sin.unsqueeze(1)
class SuRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding for "su"/"longrope" scaling with two frequency sets.

    Positions ``[0, original_max_position_embeddings)`` are rotated with
    ``short_inv_freq`` and later positions with ``long_inv_freq``. The
    resulting cos/sin tables are multiplied by ``scaling_factor``.
    """

    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
    ):
        # Skips PositionRotaryEmbedding.__init__ (which takes a single inv_freq)
        # and runs nn.Module.__init__ directly. The cache attributes the base
        # class relies on are initialised by hand below.
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.dynamic_args = None

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """(Re)build the split short/long cos/sin tables for ``seqlen`` positions.

        The rebuild condition matches the base class: no cache yet, a longer
        ``seqlen``, or a device/dtype change.
        """
        # Reset the tables if none exist yet, the sequence length has grown,
        # or we're on a new device/dtype (possibly due to tracing for instance).
        # The `None` check avoids dereferencing an empty cache when the first
        # call has seqlen == 0.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen

            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
            short_freqs = torch.outer(
                t[: self.original_max_position_embeddings],
                self.short_inv_freq.to(device=t.device),
            )
            long_freqs = torch.outer(
                t[self.original_max_position_embeddings :],
                self.long_inv_freq.to(device=t.device),
            )

            freqs = torch.cat([short_freqs, long_freqs])

            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
2024-05-13 04:44:30 -06:00
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding whose base is enlarged for sequences longer than
    ``max_position_embeddings``.

    Unlike the base class, positions are not divided by ``scaling_factor``.
    The factor only enters the rescaled-base formula in
    ``_update_cos_sin_cache``.
    """

    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """(Re)build cos/sin tables, rescaling the base when ``seqlen`` is too long.

        When ``seqlen > max_position_embeddings``, ``inv_freq`` is recomputed
        with ``base * (f * seqlen / max_pos - (f - 1)) ** (dim / (dim - 2))``,
        where ``f`` is ``scaling_factor``. The updated ``inv_freq`` is kept for
        later calls.
        """
        # Reset the tables if none exist yet, the sequence length has grown,
        # or we're on a new device/dtype (possibly due to tracing for instance).
        # The `None` check avoids dereferencing an empty cache when the first
        # call has seqlen == 0.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16.
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))

            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the fractional dimension index that completes ``num_rotations``.

    RoPE component ``i`` has wavelength ``2 * pi * base ** (2 * i / dim)``.
    Solving ``max_position_embeddings / wavelength == num_rotations`` for ``i``
    gives ``dim * ln(L / (2 * pi * r)) / (2 * ln(base))``.
    """
    positions_per_turn = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(positions_per_turn)
    return numerator / (2 * math.log(base))
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer ``(low, high)`` dimension bounds for the given rotation counts.

    ``low`` is the floor of ``find_correction_dim(low_rot, ...)``, clamped to
    at least 0. ``high`` is the ceiling of ``find_correction_dim(high_rot, ...)``,
    clamped to at most ``dim - 1``.
    """
    lower_bound = max(
        math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)),
        0,
    )
    # Clamp values just in case.
    upper_bound = min(
        math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)),
        dim - 1,
    )
    return lower_bound, upper_bound
def linear_ramp_mask(min, max, dim):
    """Return a float32 tensor of length ``dim`` ramping linearly from 0 to 1.

    Entry ``j`` is ``(j - min) / (max - min)`` clamped to ``[0, 1]``. When
    ``min == max``, ``max`` is nudged by 0.001 to avoid dividing by zero.
    """
    start, stop = min, max
    if start == stop:
        stop += 0.001  # Prevent singularity
    positions = torch.arange(dim, dtype=torch.float32)
    return ((positions - start) / (stop - start)).clamp(0, 1)
2024-07-19 09:23:20 -06:00
def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    """Return the magnitude factor ``0.1 * mscale * ln(scale) + 1`` (used by
    ``YarnPositionRotaryEmbedding``).

    A ``scale`` at or below 1 yields exactly ``1.0``.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
2024-05-13 04:44:30 -06:00
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding with YaRN-style frequency blending and magnitude scaling.

    Inverse frequencies are a per-dimension mix of "interpolated" frequencies
    (divided by ``scaling_factor``) and unmodified "extrapolated" ones. The
    mix is weighted by a linear ramp between the correction bounds derived
    from ``beta_fast``/``beta_slow``. The cos/sin tables are then multiplied
    by ``self.mscale``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        # The base class stores `scaling_factor` as `self.scaling_factor`.
        super().__init__(inv_freq, scaling_factor)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim

        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """(Re)build the blended ``inv_freq`` and the mscale-scaled cos/sin tables.

        The rebuild condition matches the base class: no cache yet, a longer
        ``seqlen``, or a device/dtype change.
        """
        # Reset the tables if none exist yet, the sequence length has grown,
        # or we're on a new device/dtype (possibly due to tracing for instance).
        # The `None` check avoids dereferencing an empty cache when the first
        # call has seqlen == 0.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            # The blended frequencies do not depend on seqlen, so they are
            # rebuilt unconditionally on every refresh.
            inv_freq_extrapolation = _create_inv_freq(
                self.dim, self.base, self.inv_freq.device
            )
            freqs = 1.0 / inv_freq_extrapolation
            inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
            low, high = find_correction_range(
                self.beta_fast,
                self.beta_slow,
                self.dim,
                self.base,
                self.max_position_embeddings,
            )

            # Get n-d rotational scaling corrected for extrapolation. The mask is
            # placed on the same device as the frequencies it is mixed with.
            inv_freq_mask = (
                1
                - linear_ramp_mask(low, high, self.dim // 2)
                .float()
                .to(inv_freq_extrapolation.device)
            ) * self.extrapolation_factor
            self.inv_freq = (
                inv_freq_interpolation * (1 - inv_freq_mask)
                + inv_freq_extrapolation * inv_freq_mask
            )

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16.
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))

            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)