Adding yarn support. (#1099)
# What does this PR do? Fixes #1017 Not sure if there's a mistake here but - NousResearch/Yarn-Llama-2-7b-128k seems to be working fine - TheBloke/Yarn-Llama-2-13B-128K-GPTQ outputs garbage <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
87f43814e3
commit
3c373dcc53
|
@ -601,6 +601,19 @@ try:
|
||||||
device=inv_freq.device,
|
device=inv_freq.device,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
)
|
)
|
||||||
|
elif rope_scaling["type"] == "yarn":
|
||||||
|
return YarnPositionRotaryEmbedding(
|
||||||
|
dim=2 * inv_freq.shape[0],
|
||||||
|
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||||
|
base=10000.0,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
extrapolation_factor=1,
|
||||||
|
attn_factor=1,
|
||||||
|
beta_fast=32,
|
||||||
|
beta_slow=1
|
||||||
|
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
|
@ -629,6 +642,19 @@ try:
|
||||||
device=inv_freq.device,
|
device=inv_freq.device,
|
||||||
scaling_factor=scaling_factor,
|
scaling_factor=scaling_factor,
|
||||||
)
|
)
|
||||||
|
elif rope_scaling["type"] == "yarn":
|
||||||
|
return YarnPositionRotaryEmbedding(
|
||||||
|
dim=2 * inv_freq.shape[0],
|
||||||
|
max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||||
|
base=10000.0,
|
||||||
|
device=inv_freq.device,
|
||||||
|
scaling_factor=scaling_factor,
|
||||||
|
extrapolation_factor=1,
|
||||||
|
attn_factor=1,
|
||||||
|
beta_fast=32,
|
||||||
|
beta_slow=1
|
||||||
|
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||||
|
@ -708,5 +734,76 @@ try:
|
||||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
|
import math
|
||||||
|
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
|
||||||
|
|
||||||
|
# Find dim range bounds based on rotations
|
||||||
|
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
low = math.floor(find_correction_dim(
|
||||||
|
low_rot, dim, base, max_position_embeddings))
|
||||||
|
high = math.ceil(find_correction_dim(
|
||||||
|
high_rot, dim, base, max_position_embeddings))
|
||||||
|
return max(low, 0), min(high, dim-1) # Clamp values just in case
|
||||||
|
|
||||||
|
def linear_ramp_mask(min, max, dim):
|
||||||
|
if min == max:
|
||||||
|
max += 0.001 # Prevent singularity
|
||||||
|
|
||||||
|
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
||||||
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
|
return ramp_func
|
||||||
|
|
||||||
|
def get_mscale(scale=1):
|
||||||
|
if scale <= 1:
|
||||||
|
return 1.0
|
||||||
|
return 0.1 * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
|
||||||
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
super().__init__(inv_freq, scaling_factor)
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
self.extrapolation_factor = extrapolation_factor
|
||||||
|
self.attn_factor = attn_factor
|
||||||
|
self.beta_fast = beta_fast
|
||||||
|
self.beta_slow = beta_slow
|
||||||
|
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
|
# Reset the tables if the sequence length has changed,
|
||||||
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
|
if (
|
||||||
|
seqlen > self._seq_len_cached
|
||||||
|
or self._cos_cached.device != device
|
||||||
|
or self._cos_cached.dtype != dtype
|
||||||
|
):
|
||||||
|
if seqlen > self.max_position_embeddings:
|
||||||
|
inv_freq_extrapolation = _create_inv_freq(
|
||||||
|
self.dim, self.base, self.inv_freq.device
|
||||||
|
)
|
||||||
|
freqs = 1.0 / inv_freq_extrapolation
|
||||||
|
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||||
|
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
|
||||||
|
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||||
|
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||||
|
|
||||||
|
self.inv_freq = inv_freq
|
||||||
|
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
|
|
||||||
|
self._seq_len_cached = seqlen
|
||||||
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
|
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||||
|
self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
|
||||||
|
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue