fix(layers): fix SuRotaryEmbedding (#2060)
* fix(layers): fix SuRotaryEmbedding
* change arange
* remove logs
Parent: 521de6cacd
Commit: 90184df79c
@@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
-            if seqlen > self.original_max_position_embeddings:
-                inv_freq = self.long_inv_freq
-            else:
-                inv_freq = self.short_inv_freq
-            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
-            if self.scaling_factor is not None:
-                t /= self.scaling_factor
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-
-            freqs = torch.outer(t, inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
+
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
+
+            freqs = torch.cat([short_freqs, long_freqs])
+
+            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)


 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
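Summary of the hunk above: the old code picked a single inv_freq table (short or long) for the whole sequence and divided the position indices by scaling_factor; the fix uses the short frequencies for positions below original_max_position_embeddings and the long frequencies for the rest, concatenates the two ranges, and multiplies the cached cos/sin by scaling_factor. The following is a minimal standalone sketch of that cache computation, based only on the diff; the function name, signature, and free-function form are illustrative, not the repository's API, and assume precomputed short_inv_freq/long_inv_freq tensors.

# Illustrative sketch of the fixed cos/sin cache logic, not the library's API.
import torch

def build_su_cos_sin_cache(
    seqlen: int,
    original_max_position_embeddings: int,
    short_inv_freq: torch.Tensor,   # assumed precomputed, shape (dim // 2,)
    long_inv_freq: torch.Tensor,    # assumed precomputed, shape (dim // 2,)
    scaling_factor: float,
    dtype: torch.dtype = torch.float16,
    device: torch.device | str = "cpu",
):
    # One position-index tensor, kept in the inv_freq dtype (fp32) so the
    # outer products are computed in full precision before the final cast.
    t = torch.arange(seqlen, device=device, dtype=short_inv_freq.dtype)

    # Positions inside the original window use the short frequencies,
    # positions beyond it use the long ones.
    short_freqs = torch.outer(
        t[:original_max_position_embeddings], short_inv_freq.to(device=t.device)
    )
    long_freqs = torch.outer(
        t[original_max_position_embeddings:], long_inv_freq.to(device=t.device)
    )
    freqs = torch.cat([short_freqs, long_freqs])

    # The scaling factor now scales cos/sin directly instead of dividing
    # the position indices.
    cos = (torch.cos(freqs) * scaling_factor).to(dtype)
    sin = (torch.sin(freqs) * scaling_factor).to(dtype)
    return cos, sin

If seqlen is shorter than original_max_position_embeddings, the long slice is empty and the concatenation degenerates to the short table, matching the behavior implied by the new code.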
@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_phi_modeling import (
     FlashPhiForCausalLM,
-    PhiConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-        config = PhiConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
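The flash_phi change drops the custom PhiConfig import and loads the configuration through transformers' AutoConfig, so the config class is resolved from the checkpoint itself. A hedged sketch of that pattern follows; the model id and quantize value are illustrative, not taken from the repository.

# Illustrative sketch of AutoConfig-based loading; values are placeholders.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "microsoft/phi-2",            # placeholder model id
    revision=None,
    trust_remote_code=True,
)
config.quantize = None            # the quantization mode is attached to the config afterwards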