From 932bdd93ff559702cd51f07311c78cf389ec0542 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 31 Jul 2023 15:38:47 +0200
Subject: [PATCH] Adding Rope scaling. (#741)

# What does this PR do?

- Adds Rope NTK scaling.

  Done because https://github.com/huggingface/text-generation-inference/pull/529 was closed.
  Took some code from https://github.com/huggingface/transformers/pull/24653

- `--rope-scaling` and `--rope-factor` are added separately. I considered having a single flag and parsing something like ("linear:4.0", or "dynamic"), but decided against it because it would push more parsing+validation a bit everywhere (both in the launcher and the server).

Fixes #512

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
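For reviewers, a quick sketch of how the two flags combine: the launcher resolves them into a single `(type, factor)` pair and forwards it to each shard as the `ROPE_SCALING` / `ROPE_FACTOR` environment variables, which the server turns back into a `rope_scaling` config. Roughly equivalent logic in Python (illustrative only; the helper name `resolve_rope` is not part of the patch):

```python
from typing import Optional, Tuple

def resolve_rope(
    scaling: Optional[str], factor: Optional[float]
) -> Optional[Tuple[str, float]]:
    """Illustrative mirror of the launcher's defaulting rules."""
    if scaling is None and factor is None:
        return None                # no flags: RoPE left untouched
    if factor is None:
        return (scaling, 1.0)      # --rope-scaling alone: factor defaults to 1.0
    if scaling is None:
        return ("linear", factor)  # --rope-factor alone: linear scaling assumed
    return (scaling, factor)       # both flags: used as-is

assert resolve_rope(None, 2.0) == ("linear", 2.0)
assert resolve_rope("dynamic", None) == ("dynamic", 1.0)
```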
---
 launcher/src/main.rs                          | 61 +++++++++++++
 .../custom_modeling/flash_llama_modeling.py   |  2 +-
 .../custom_modeling/flash_neox_modeling.py    |  2 +-
 .../custom_modeling/flash_rw_modeling.py      |  4 +-
 server/text_generation_server/utils/layers.py | 88 ++++++++++++++++---
 5 files changed, 142 insertions(+), 15 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2ad788a4..560ce181 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum RopeScaling {
+    Linear,
+    Dynamic,
+}
+
+impl std::fmt::Display for RopeScaling {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // To keep in sync with `server`.
+        match self {
+            RopeScaling::Linear => {
+                write!(f, "linear")
+            }
+            RopeScaling::Dynamic => {
+                write!(f, "dynamic")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -250,6 +270,26 @@ struct Args {
     #[clap(default_value = "1.0", long, env)]
     cuda_memory_fraction: f32,
 
+    /// Rope scaling will only be used for RoPE models
+    /// and allows rescaling the position rotary to accommodate
+    /// larger prompts.
+    ///
+    /// Goes together with `rope_factor`.
+    ///
+    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
+    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
+    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (nothing will be changed,
+    /// basically)
+    ///
+    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want
+    #[clap(long, env)]
+    rope_scaling: Option<RopeScaling>,
+
+    /// Rope scaling will only be used for RoPE models.
+    /// See `rope_scaling`.
+    #[clap(long, env)]
+    rope_factor: Option<f32>,
+
     /// Outputs the logs in JSON format (useful for telemetry)
     #[clap(long, env)]
     json_output: bool,
@@ -305,6 +345,8 @@ fn shard_manager(
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
     cuda_memory_fraction: f32,
+    rope_scaling: Option<RopeScaling>,
+    rope_factor: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
@@ -358,6 +400,12 @@ fn shard_manager(
         shard_args.push(revision)
     }
 
+    let rope = match (rope_scaling, rope_factor) {
+        (None, None) => None,
+        (Some(scaling), None) => Some((scaling, 1.0)),
+        (Some(scaling), Some(factor)) => Some((scaling, factor)),
+        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),
+    };
     // OpenTelemetry
     if let Some(otlp_endpoint) = otlp_endpoint {
         shard_args.push("--otlp-endpoint".to_string());
@@ -395,6 +443,15 @@ fn shard_manager(
         envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
     };
 
+    // Detect rope scaling
+    // Sent as env vars instead of CLI args to avoid bloating everything;
+    // these can only be used by RoPE models, so passing the information
+    // around for all models would complicate the code unnecessarily.
+    if let Some((scaling, factor)) = rope {
+        envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
+        envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -784,6 +841,8 @@ fn spawn_shards(
         let watermark_gamma = args.watermark_gamma;
         let watermark_delta = args.watermark_delta;
         let cuda_memory_fraction = args.cuda_memory_fraction;
+        let rope_scaling = args.rope_scaling;
+        let rope_factor = args.rope_factor;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -802,6 +861,8 @@ fn spawn_shards(
                 watermark_gamma,
                 watermark_delta,
                 cuda_memory_fraction,
+                rope_scaling,
+                rope_factor,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b6285856..2c22ea46 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -186,7 +186,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size**-0.5
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index e7c8ced4..9dc374df 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -102,7 +102,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
-            prefix=f"{prefix}.rotary_emb", weights=weights
+            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 3570b283..14caa23d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -133,7 +133,7 @@ class FlashRWAttention(torch.nn.Module):
         self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
@@ -247,7 +247,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            config=config, dim=self.head_size, base=10000.0, device=weights.device
         )
 
         self.softmax_scale = self.head_size ** (-0.5)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 183cf2c1..fc92ebe6 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -381,33 +381,65 @@ try:
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 
-    class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
-            super().__init__()
+    def _create_inv_freq(dim, base, device):
+        inv_freq = 1.0 / (
+            base
+            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+        return inv_freq
+
+    def _get_rope_config(config):
+        if os.getenv("ROPE_SCALING", None) is not None:
+            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            return rope_scaling
+        return getattr(config, "rope_scaling", None)
+
+    class PositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq, scaling_factor):
+            super().__init__()
             self.inv_freq = inv_freq
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
             self._cos_k_cached = None
             self._sin_k_cached = None
+            self.scaling_factor = scaling_factor
+            self.dynamic_args = None
 
         @classmethod
-        def static(cls, dim, base, device):
-            inv_freq = 1.0 / (
-                base
-                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
-            )
-            return cls(inv_freq)
+        def static(cls, config, dim, base, device):
+            inv_freq = _create_inv_freq(dim, base, device)
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         @classmethod
-        def load(cls, prefix, weights):
+        def load(cls, config, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return cls(inv_freq)
+
+            scaling_factor = None
+            rope_scaling = _get_rope_config(config)
+            if rope_scaling is not None:
+                scaling_factor = rope_scaling["factor"]
+                if rope_scaling["type"] == "linear":
+                    pass
+                elif rope_scaling["type"] == "dynamic":
+                    return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                else:
+                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+            return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
@@ -419,8 +451,11 @@ try:
             ):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
                 # Don't do einsum, it converts fp32 to fp16
                 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
@@ -446,5 +481,36 @@ try:
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x
 
+    class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
+            inv_freq = _create_inv_freq(dim, base, device)
+            super().__init__(inv_freq, scaling_factor)
+            self.dim = dim
+            self.max_position_embeddings = max_position_embeddings
+            self.base = base
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                if seqlen > self.max_position_embeddings:
+                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                self._seq_len_cached = seqlen
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
+
 except ImportError:
     pass
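For context on the dynamic ("NTK-aware") path above: the rotary base is only recomputed once the requested sequence length exceeds `config.max_position_embeddings`, using the same exponent as in `DynamicPositionRotaryEmbedding._update_cos_sin_cache`. A standalone sketch of that recomputation in plain PyTorch, outside the server classes (the function name `ntk_scaled_inv_freq` is illustrative only, not part of the patch):

```python
import torch

def ntk_scaled_inv_freq(
    dim: int, base: float, seqlen: int, max_position_embeddings: int, scaling_factor: float
) -> torch.Tensor:
    """Recompute RoPE inverse frequencies the way the dynamic branch does (illustrative)."""
    if seqlen > max_position_embeddings:
        # Inflate the base so long prompts keep the rotary frequencies in-distribution.
        base = base * (
            (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Example: a 4096-token prompt on a model trained for 2048 positions, factor 1.0.
inv_freq = ntk_scaled_inv_freq(
    dim=128, base=10000.0, seqlen=4096, max_position_embeddings=2048, scaling_factor=1.0
)
```

Linear scaling, by contrast, leaves `inv_freq` untouched and simply divides the position indices `t` by the factor before the cos/sin tables are built.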