Adding Rope scaling. (#741)

# What does this PR do?


- Adds Rope NTK scaling.

Done because
https://github.com/huggingface/text-generation-inference/pull/529 was
closed
Took some code from
https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2023-07-31 15:38:47 +02:00 committed by GitHub
parent b9633c46d0
commit 932bdd93ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 142 additions and 15 deletions

View File

@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
Dynamic,
}
impl std::fmt::Display for RopeScaling {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
match self {
RopeScaling::Linear => {
write!(f, "linear")
}
RopeScaling::Dynamic => {
write!(f, "dynamic")
}
}
}
}
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
@ -250,6 +270,26 @@ struct Args {
#[clap(default_value = "1.0", long, env)]
cuda_memory_fraction: f32,
/// Rope scaling will only be used for RoPE models
/// and allow rescaling the position rotary to accomodate for
/// larger prompts.
///
/// Goes together with `rope_factor`.
///
/// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
/// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
/// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
/// basically)
///
/// `--rope-scaling linear --rope-factor` fully describes the scaling you want
#[clap(long, env)]
rope_scaling: Option<RopeScaling>,
/// Rope scaling will only be used for RoPE models
/// See `rope_scaling`
#[clap(long, env)]
rope_factor: Option<f32>,
/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
@ -305,6 +345,8 @@ fn shard_manager(
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -358,6 +400,12 @@ fn shard_manager(
shard_args.push(revision)
}
let rope = match (rope_scaling, rope_factor) {
(None, None) => None,
(Some(scaling), None) => Some((scaling, 1.0)),
(Some(scaling), Some(factor)) => Some((scaling, factor)),
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
@ -395,6 +443,15 @@ fn shard_manager(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};
// Detect rope scaling
// Sending as env instead of CLI args to not bloat everything
// those only can be used by RoPE models, so passing information around
// for all models will complexify code unnecessarily
if let Some((scaling, factor)) = rope {
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -784,6 +841,8 @@ fn spawn_shards(
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
thread::spawn(move || {
shard_manager(
model_id,
@ -802,6 +861,8 @@ fn spawn_shards(
watermark_gamma,
watermark_delta,
cuda_memory_fraction,
rope_scaling,
rope_factor,
otlp_endpoint,
status_sender,
shutdown,

View File

@ -186,7 +186,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size**-0.5

View File

@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)

View File

@ -381,33 +381,65 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
@classmethod
def load(cls, prefix, weights):
def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
@ -419,8 +451,11 @@ try:
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
@ -446,5 +481,36 @@ try:
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
pass