diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 648d28ab..c2f12189 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen - if seqlen > self.original_max_position_embeddings: - inv_freq = self.long_inv_freq - else: - inv_freq = self.short_inv_freq - t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype) - if self.scaling_factor is not None: - t /= self.scaling_factor - # Don't do einsum, it converts fp32 to fp16 - # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq.to(device=t.device)) - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) + t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype) + short_freqs = torch.outer( + t[: self.original_max_position_embeddings], + self.short_inv_freq.to(device=t.device), + ) + long_freqs = torch.outer( + t[self.original_max_position_embeddings :], + self.long_inv_freq.to(device=t.device), + ) + + freqs = torch.cat([short_freqs, long_freqs]) + + self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype) class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 32b573a9..6a52c1d7 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -8,7 +8,6 @@ from typing import Optional from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_phi_modeling import ( FlashPhiForCausalLM, - PhiConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM): trust_remote_code=trust_remote_code, ) - config = PhiConfig.from_pretrained( + config = AutoConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize