diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index cf61e47b..f38f130e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -601,6 +601,19 @@ try:
                         device=inv_freq.device,
                         scaling_factor=scaling_factor,
                     )
+                elif rope_scaling["type"] == "yarn":
+                    return YarnPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                        extrapolation_factor=1,
+                        attn_factor=1,
+                        beta_fast=32,
+                        beta_slow=1,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@@ -629,6 +642,19 @@ try:
                         device=inv_freq.device,
                         scaling_factor=scaling_factor,
                     )
+                elif rope_scaling["type"] == "yarn":
+                    return YarnPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=rope_scaling["original_max_position_embeddings"],
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                        extrapolation_factor=1,
+                        attn_factor=1,
+                        beta_fast=32,
+                        beta_slow=1,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@@ -708,5 +734,76 @@ try:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
+
+    # Inverse dim formula to find dim based on number of rotations
+    import math
+    def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    # Find dim range bounds based on rotations
+    def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+        low = math.floor(find_correction_dim(
+            low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(
+            high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+    def linear_ramp_mask(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    def get_mscale(scale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * math.log(scale) + 1.0
+
+    class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, attn_factor, beta_fast, beta_slow):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+            self.extrapolation_factor = extrapolation_factor
+            self.attn_factor = attn_factor
+            self.beta_fast = beta_fast
+            self.beta_slow = beta_slow
+            self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                if seqlen > self.max_position_embeddings:
+                    inv_freq_extrapolation = _create_inv_freq(
+                        self.dim, self.base, self.inv_freq.device
+                    )
+                    freqs = 1.0 / inv_freq_extrapolation
+                    inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
+                    low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
+                    inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
+                    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+
+                    self.inv_freq = inv_freq
+                    self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor)  # Get n-d magnitude scaling corrected for interpolation
+
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
+
 
 except ImportError:
     pass
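For reviewers, here is a minimal standalone sketch of the frequency blending that YarnPositionRotaryEmbedding._update_cos_sin_cache performs for long sequences. The helpers mirror the ones added in this patch (with the min/max parameters renamed to avoid shadowing builtins); the dim, scaling_factor, and original_max_position_embeddings values are illustrative assumptions only, not taken from any particular model config. In the server itself, scaling_factor and original_max_position_embeddings are expected to come from the model's rope_scaling config, as in the dispatch above.

    import math
    import torch

    # Helpers mirrored from the patch above.
    def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_mask(lo, hi, dim):
        if lo == hi:
            hi += 0.001  # Prevent singularity
        return torch.clamp((torch.arange(dim, dtype=torch.float32) - lo) / (hi - lo), 0, 1)

    # Illustrative values (assumptions, not from a real config).
    dim, base, original_max_pos, scaling_factor = 128, 10000.0, 2048, 16.0
    beta_fast, beta_slow, extrapolation_factor, attn_factor = 32, 1, 1, 1

    # Plain RoPE frequencies (extrapolation) and position-interpolated frequencies.
    inv_freq_extrapolation = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    inv_freq_interpolation = inv_freq_extrapolation / scaling_factor

    # Blend per dimension: low-frequency dims are interpolated, high-frequency dims keep extrapolation.
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_pos)
    inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2)) * extrapolation_factor
    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

    # Attention magnitude correction applied to cos/sin caches:
    # 0.1 * ln(16) + 1 ≈ 1.277 for a 16x context stretch.
    mscale = (0.1 * math.log(scaling_factor) + 1.0) * attn_factor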