fix(layers): fix SuRotaryEmbedding (#2060)

* fix(layers): fix SuRotaryEmbedding

* change arange

* remove logs
OlivierDehaene 2024-06-12 18:24:47 +02:00 committed by GitHub
parent 521de6cacd
commit 90184df79c
2 changed files with 15 additions and 14 deletions


@@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
-            if seqlen > self.original_max_position_embeddings:
-                inv_freq = self.long_inv_freq
-            else:
-                inv_freq = self.short_inv_freq
-            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
-            if self.scaling_factor is not None:
-                t /= self.scaling_factor
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
 
-            freqs = torch.outer(t, inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
+            freqs = torch.cat([short_freqs, long_freqs])
+
+            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
 
 
 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
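
For context, a standalone sketch (not from this repository) of what the updated cache computation does: positions up to original_max_position_embeddings are rotated with the short inverse frequencies, positions beyond it with the long ones, and the scaling factor is applied to the cos/sin values rather than used to divide the positions. The function name, tensor sizes, and scaling factor below are made up for illustration.

# Illustrative sketch only; variable names mirror the attributes in the diff above,
# but the function, dimensions, and scaling factor are hypothetical.
import torch

def build_cos_sin_cache(
    seqlen: int,
    short_inv_freq: torch.Tensor,
    long_inv_freq: torch.Tensor,
    original_max_position_embeddings: int,
    scaling_factor: float,
    dtype: torch.dtype = torch.float16,
):
    # Positions are enumerated over the full requested sequence length.
    t = torch.arange(seqlen, dtype=short_inv_freq.dtype)
    # Positions inside the original context window use the short frequencies ...
    short_freqs = torch.outer(t[:original_max_position_embeddings], short_inv_freq)
    # ... and positions past it use the long frequencies.
    long_freqs = torch.outer(t[original_max_position_embeddings:], long_inv_freq)
    freqs = torch.cat([short_freqs, long_freqs])
    # The scaling factor multiplies cos/sin instead of rescaling the positions.
    cos = (torch.cos(freqs) * scaling_factor).to(dtype)
    sin = (torch.sin(freqs) * scaling_factor).to(dtype)
    return cos, sin

# Example with made-up sizes: 8 rotary dims, a 16-token original window, and a
# 32-token request, so the last 16 positions fall into the "long" range.
dim = 8
inv = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
cos, sin = build_cos_sin_cache(
    seqlen=32,
    short_inv_freq=inv,
    long_inv_freq=inv * 0.5,
    original_max_position_embeddings=16,
    scaling_factor=1.2,
)
print(cos.shape, sin.shape)  # torch.Size([32, 4]) torch.Size([32, 4])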


@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_phi_modeling import (
     FlashPhiForCausalLM,
-    PhiConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        config = PhiConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
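
The second file swaps the model-specific PhiConfig for transformers' generic AutoConfig, which resolves the concrete config class from the checkpoint's config.json. A minimal illustration, assuming network access; the model id is only an example and is not taken from this commit.

# Illustration only: AutoConfig picks the right config class for the checkpoint,
# so no Phi-specific import is needed.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "microsoft/phi-2",  # example model id, not part of this commit
    revision="main",
    trust_remote_code=True,
)
print(type(config).__name__)  # resolved automatically, e.g. PhiConfig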