From cc36128cdac2922d99f1b7ea3e55b5c02e414a9a Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:40:35 +0200 Subject: [PATCH] feat: support attention sinks --- launcher/src/main.rs | 27 ++ .../text_generation_server/models/__init__.py | 2 +- server/text_generation_server/models/bloom.py | 6 +- .../models/cache_manager.py | 31 +- .../custom_modeling/flash_llama_modeling.py | 24 +- .../custom_modeling/flash_mistral_modeling.py | 21 +- .../custom_modeling/flash_neox_modeling.py | 38 ++- .../custom_modeling/flash_rw_modeling.py | 37 ++- .../flash_santacoder_modeling.py | 23 +- .../idefics_image_processing.py | 6 +- .../models/flash_causal_lm.py | 87 +++++- .../models/flash_mistral.py | 289 +----------------- server/text_generation_server/models/model.py | 10 +- .../models/sliding_window.py | 41 +++ server/text_generation_server/utils/layers.py | 81 +++-- .../text_generation_server/utils/weights.py | 5 +- update_doc.py | 6 +- 17 files changed, 383 insertions(+), 351 deletions(-) create mode 100644 server/text_generation_server/models/sliding_window.py diff --git a/launcher/src/main.rs b/launcher/src/main.rs index b4fc86b7..ac1469f6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -333,6 +333,16 @@ struct Args { #[clap(long, env)] rope_factor: Option, + /// Sliding Window will only be used by flash attention optimized models + /// Limit the Paged Attention context window size + #[clap(long, env)] + sliding_window: Option, + + /// If `sliding_window` is set, always keep the first `attention_sinks` tokens in the context + /// See: [Efficient Streaming Language Models with Attention Sinks](https://arxiv.org/abs/2309.17453) + #[clap(long, env)] + attention_sinks: Option, + /// Outputs the logs in JSON format (useful for telemetry) #[clap(long, env)] json_output: bool, @@ -390,6 +400,8 @@ fn shard_manager( cuda_memory_fraction: f32, rope_scaling: Option, rope_factor: Option, + sliding_window: Option, + attention_sinks: Option, otlp_endpoint: Option, status_sender: mpsc::Sender, shutdown: Arc, @@ -495,6 +507,17 @@ fn shard_manager( envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); } + // Detect sliding window + // Sending as env instead of CLI args to not bloat everything + // those only can be used by flash attention models, so passing information around + // for all models will complexify code unnecessarily + if let Some(sliding_window) = sliding_window { + envs.push(("SLIDING_WINDOW".into(), sliding_window.to_string().into())); + } + if let Some(attention_sinks) = attention_sinks { + envs.push(("ATTENTION_SINKS".into(), attention_sinks.to_string().into())); + } + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container if let Some(huggingface_hub_cache) = huggingface_hub_cache { @@ -891,6 +914,8 @@ fn spawn_shards( let cuda_memory_fraction = args.cuda_memory_fraction; let rope_scaling = args.rope_scaling; let rope_factor = args.rope_factor; + let sliding_window = args.sliding_window; + let attention_sinks = args.attention_sinks; thread::spawn(move || { shard_manager( model_id, @@ -911,6 +936,8 @@ fn spawn_shards( cuda_memory_fraction, rope_scaling, rope_factor, + sliding_window, + attention_sinks, otlp_endpoint, status_sender, shutdown, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5b1b5715..4ee0cd00 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -297,7 +297,7 @@ def get_model( raise ValueError("awq quantization is not supported for AutoModel") elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): raise ValueError("4bit quantization is not supported for AutoModel") - elif (quantize == "eetq"): + elif quantize == "eetq": raise ValueError("Eetq quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 8e8daad3..c3876023 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer", + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + prefix="transformer", ) if config.quantize == "gptq": weights._set_gptq_params(model_id) diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 2e6ae086..35549f51 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -15,12 +15,14 @@ class CacheManager: num_layers: int, num_heads: int, head_size: int, + attention_sinks: int, repeat_slots: bool, dtype: torch.dtype, device: torch.device, ): self.block_size = BLOCK_SIZE self.num_blocks = num_blocks + self.attention_sinks = attention_sinks self.repeat_slots = repeat_slots element_size = torch.tensor([], dtype=dtype).element_size() @@ -82,8 +84,23 @@ class CacheManager: # Repeat slots in the case of context sliding window if needed_slots > len(all_slots) and self.repeat_slots: - repeats = math.ceil(needed_slots / len(all_slots)) - all_slots = all_slots.repeat(repeats) + repeats = math.ceil( + needed_slots / (len(all_slots) - self.attention_sinks) + ) + + if self.attention_sinks > 0: + # Remove attention sinks from the repeat to not override them + all_slots = torch.cat( + [ + all_slots, + all_slots[self.attention_sinks :].repeat(repeats - 1), + ] + ) + else: + all_slots = all_slots.repeat(repeats) + + elif needed_slots > len(all_slots): + raise RuntimeError("Out of available slots. This is a bug") allocated_slots = all_slots[:needed_slots] @@ -112,6 +129,7 @@ def set_cache_manager( num_layers: int, num_heads: int, head_size: int, + attention_sinks: int, repeat_slots: bool, dtype: torch.dtype, device: torch.device, @@ -122,7 +140,14 @@ def set_cache_manager( torch.cuda.empty_cache() CACHE_MANAGER = CacheManager( - num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device + num_blocks, + num_layers, + num_heads, + head_size, + attention_sinks, + repeat_slots, + dtype, + device, ) return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7c743a88..ff191ee5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -254,6 +254,7 @@ class FlashLlamaAttention(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -269,8 +270,13 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + vllm_cache_ops.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) # output tensor @@ -376,6 +382,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -390,6 +397,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) # faster post attention rms norm @@ -442,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -464,6 +473,7 @@ class FlashLlamaModel(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -492,8 +502,19 @@ class FlashLlamaForCausalLM(torch.nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif sliding_window != -1: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(sliding_window, max_s) + input_lengths = torch.clamp(input_lengths, max=sliding_window) + hidden_states = self.model( input_ids, position_ids, @@ -503,6 +524,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 77b7f230..b21b730b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -201,9 +201,6 @@ class MistralAttention(torch.nn.Module): weights, ): super().__init__() - self.max_past = ( - config.sliding_window if config.sliding_window is not None else 0 - ) self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads @@ -252,6 +249,7 @@ class MistralAttention(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + sliding_window, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -290,7 +288,7 @@ class MistralAttention(torch.nn.Module): cu_seqlen_prefill, max_s, self.softmax_scale, - window_size_left=self.max_past, + window_size_left=sliding_window, ) # Decode else: @@ -381,6 +379,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + sliding_window, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -396,6 +395,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + sliding_window, ) # faster post attention rms norm @@ -449,6 +449,7 @@ class MistralModel(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -472,6 +473,7 @@ class MistralModel(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + sliding_window, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -489,9 +491,6 @@ class FlashMistralForCausalLM(torch.nn.Module): prefix="lm_head", weights=weights, ) - self.max_past = config.sliding_window - if self.max_past is None: - raise ValueError("max_past cannot be None") def forward( self, @@ -504,16 +503,17 @@ class FlashMistralForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: if prefill_cache_indices is not None: # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] - else: + elif sliding_window != -1: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - max_s = min(self.max_past, max_s) - input_lengths = torch.clamp(input_lengths, max=self.max_past) + max_s = min(sliding_window, max_s) + input_lengths = torch.clamp(input_lengths, max=sliding_window) hidden_states = self.model( input_ids, @@ -525,6 +525,7 @@ class FlashMistralForCausalLM(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + sliding_window, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 9dc374df..abc66179 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -133,6 +133,7 @@ class FlashNeoxAttention(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -141,20 +142,28 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 0], cos, sin) self.rotary_emb(qkv[:, 1], cos, sin) + query, kv = qkv.split([1, 2], dim=1) + query = query.view(-1, self.num_heads, self.head_size) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + vllm_cache_ops.reshape_and_cache( - qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) # output tensor - attn_output = torch.empty_like(qkv[:, 0]) + attn_output = torch.empty_like(query) # Prefill if cu_seqlen_prefill is not None: # flash attention attention( - qkv[:, 0], - qkv[:, 1], - qkv[:, 2], + query, + kv[:, 0], + kv[:, 1], attn_output, cu_seqlen_prefill, max_s, @@ -166,7 +175,7 @@ class FlashNeoxAttention(torch.nn.Module): block_size = kv_cache[1].shape[3] vllm_attention_ops.single_query_cached_kv_attention( attn_output, - qkv[:, 0], + query, kv_cache[0], kv_cache[1], self.kv_head_mapping, @@ -245,6 +254,7 @@ class FlashNeoXLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -259,6 +269,7 @@ class FlashNeoXLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -283,6 +294,7 @@ class FlashNeoXLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, residual = self.post_attention_layernorm( @@ -337,6 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -359,6 +372,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, _ = self.final_layer_norm(hidden_states, residual) @@ -385,8 +399,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif sliding_window != -1: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(sliding_window, max_s) + input_lengths = torch.clamp(input_lengths, max=sliding_window) + hidden_states = self.gpt_neox( input_ids, position_ids, @@ -396,6 +421,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): slots, input_lengths, max_s, + prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 8419fa4f..89a15c5d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -174,6 +174,7 @@ class FlashRWAttention(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): qkv = self.query_key_value(hidden_states) @@ -191,8 +192,13 @@ class FlashRWAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + vllm_cache_ops.reshape_and_cache( - kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) # output @@ -294,6 +300,7 @@ class FlashRWLargeAttention(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -310,9 +317,14 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + vllm_cache_ops.reshape_and_cache( - kv[:, :, 0].contiguous(), - kv[:, :, 1].contiguous(), + kv_to_cache[:, :, 0].contiguous(), + kv_to_cache[:, :, 1].contiguous(), kv_cache[0], kv_cache[1], slots, @@ -428,6 +440,7 @@ class FlashRWLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -442,6 +455,7 @@ class FlashRWLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) mlp_output = self.mlp(ln_hidden_states) @@ -464,6 +478,7 @@ class FlashRWLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, residual = self.post_attention_layernorm( @@ -513,6 +528,7 @@ class FlashRWLargeLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -528,6 +544,7 @@ class FlashRWLargeLayer(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) # MLP. @@ -589,6 +606,7 @@ class FlashRWModel(FlashRWPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -611,6 +629,7 @@ class FlashRWModel(FlashRWPreTrainedModel): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -638,8 +657,19 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif sliding_window != -1: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(sliding_window, max_s) + input_lengths = torch.clamp(input_lengths, max=sliding_window) + hidden_states = self.transformer( input_ids, position_ids, @@ -649,6 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): slots, input_lengths, max_s, + prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2dd0a5ee..3a61d058 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -246,6 +246,7 @@ class FlashMQAttention(torch.nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ): qkv = self.c_attn(hidden_states) @@ -258,8 +259,13 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + if prefill_cache_indices is not None: + kv_to_cache = key_value[prefill_cache_indices] + else: + kv_to_cache = key_value + vllm_cache_ops.reshape_and_cache( - key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) # output @@ -367,6 +373,7 @@ class Block(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -420,6 +427,7 @@ class FlashSantacoderModel(nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices, ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -437,6 +445,7 @@ class FlashSantacoderModel(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -462,8 +471,19 @@ class FlashSantacoderForCausalLM(nn.Module): slots: torch.Tensor, input_lengths: torch.Tensor, max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + sliding_window: int, lm_head_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif sliding_window != -1: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(sliding_window, max_s) + input_lengths = torch.clamp(input_lengths, max=sliding_window) + hidden_states = self.transformer( input_ids, position_ids, @@ -473,6 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module): slots, input_lengths, max_s, + prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py index 21aa3ff3..0a6f24e0 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor): image = image_url_or_urls if image.startswith("http://") or image.startswith("https://"): - response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)) + response = requests.get( + image_url_or_urls, stream=True, headers=headers, timeout=(1, 5) + ) response.raise_for_status() content = response.content else: @@ -208,7 +210,7 @@ class IdeficsImageProcessor(BaseImageProcessor): image = Image.open(BytesIO(content)) # image.verify() except Exception: - raise ValueError(f"Could not load image from url {image_url_or_urls}") + raise ValueError(f"Could not load image from url {image_url_or_urls}") return image else: raise ValueError( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1fe40c0c..09bb078b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -9,7 +9,7 @@ import numpy as np from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase -from typing import Optional, Tuple, List, Type, Union, Dict +from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -24,6 +24,10 @@ from text_generation_server.models.cache_manager import ( set_cache_manager, BLOCK_SIZE, ) +from text_generation_server.models.sliding_window import ( + set_sliding_window_from_env, + get_sliding_window, +) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION @@ -47,6 +51,12 @@ class FlashCausalLMBatch(Batch): # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] + # Sliding window values + + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Paged Attention values # Set when creating the batch @@ -109,6 +119,8 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": + sliding_window = get_sliding_window() + batch_inputs = [] max_truncation = 0 for r in pb.requests: @@ -124,6 +136,7 @@ class FlashCausalLMBatch(Batch): needed_blocks_slots = [] start_slots = [] slot_indices = [] + prefill_cache_indices = [] input_lengths = [] prefix_offsets = [] @@ -187,8 +200,15 @@ class FlashCausalLMBatch(Batch): # Paged attention # Remove one as the first token des not have a past total_tokens = input_length + max_new_tokens - 1 + needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) + + # If using sliding window + if sliding_window is not None: + # Needed blocks can not go over SLIDING_WINDOW_BLOCKS + needed_blocks = min(needed_blocks, sliding_window.blocks) blocks += needed_blocks + needed_blocks_slots.append((needed_blocks, total_tokens)) start_slots.append(cumulative_max_length) @@ -199,6 +219,32 @@ class FlashCausalLMBatch(Batch): ) slot_indices.append(request_slot_indices) + # If using sliding window + if sliding_window is not None: + # Start of the sliding window cache + start_offset = max( + 0, + input_length - sliding_window.size + sliding_window.attention_sinks, + ) + + if sliding_window.attention_sinks > 0 and start_offset > 0: + # Attention sinks indices + request_attention_sinks_cache_indices = torch.arange( + cumulative_length, + cumulative_length + + min(sliding_window.attention_sinks, start_offset), + dtype=torch.int64, + ) + prefill_cache_indices.append(request_attention_sinks_cache_indices) + + # Create tensor to slice into the kv tensor in prefill + request_prefill_cache_indices = torch.arange( + cumulative_length + start_offset, + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs @@ -252,12 +298,26 @@ class FlashCausalLMBatch(Batch): position_ids = position_ids[0] slot_indices = slot_indices[0] + if len(prefill_cache_indices) > 1: + prefill_cache_indices = ( + torch.cat(prefill_cache_indices) if prefill_cache_indices else None + ) + else: + prefill_cache_indices = ( + prefill_cache_indices[0] if prefill_cache_indices else None + ) + cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int32 ) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) + prefill_cache_indices = ( + prefill_cache_indices.to(device) + if prefill_cache_indices is not None + else None + ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_lengths_tensor = torch.tensor( input_lengths, dtype=torch.int32, device=device @@ -309,6 +369,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + prefill_cache_indices=prefill_cache_indices, ) @tracer.start_as_current_span("filter") @@ -425,7 +486,7 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) - return type(self)( + return FlashCausalLMBatch( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -454,6 +515,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + prefill_cache_indices=None, ) @classmethod @@ -611,6 +673,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, + prefill_cache_indices=None, ) def __del__(self): @@ -636,11 +699,11 @@ class FlashCausalLM(Model): device: torch.device, rank: int = 0, world_size: int = 1, - sliding_window: Optional[int] = None, ): self.num_layers = num_layers self.num_kv_heads = num_kv_heads self.head_size = head_size + set_sliding_window_from_env() super(FlashCausalLM, self).__init__( model=model, @@ -650,7 +713,6 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, ) @property @@ -658,6 +720,8 @@ class FlashCausalLM(Model): return FlashCausalLMBatch def warmup(self, batch: FlashCausalLMBatch): + sliding_window = get_sliding_window() + torch.cuda.empty_cache() try: cache_manager = set_cache_manager( @@ -665,7 +729,8 @@ class FlashCausalLM(Model): self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, + sliding_window.attention_sinks if sliding_window is not None else 0, + True if sliding_window is not None else False, self.dtype, self.device, ) @@ -705,7 +770,8 @@ class FlashCausalLM(Model): self.num_layers, self.num_kv_heads, self.head_size, - self.sliding_window is not None, + sliding_window.attention_sinks if sliding_window is not None else 0, + True if sliding_window is not None else False, self.dtype, self.device, ) @@ -713,8 +779,10 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE) def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: + sliding_window = get_sliding_window() + # Model Forward - return self.model.forward( + logits = self.model.forward( input_ids=batch.input_ids, position_ids=batch.position_ids, cu_seqlen_prefill=batch.cu_seqlen_prefill, @@ -723,8 +791,13 @@ class FlashCausalLM(Model): slots=batch.slots[batch.slot_indices], input_lengths=batch.input_lengths_tensor, max_s=batch.max_seqlen, + prefill_cache_indices=batch.prefill_cache_indices, + sliding_window=sliding_window.size if sliding_window is not None else -1, lm_head_indices=batch.prefill_head_indices, ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits @tracer.start_as_current_span("generate_token") def generate_token( diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 919e4625..0c65be40 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,21 +1,14 @@ -import math import torch import torch.distributed -import numpy as np - -from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase from transformers.models.llama import LlamaTokenizerFast -from typing import Optional, Tuple, Type +from typing import Optional -from text_generation_server.pb import generate_pb2 from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE -from text_generation_server.models.cache_manager import ( - get_cache_manager, - set_cache_manager, +from text_generation_server.models.sliding_window import ( + set_sliding_window, + get_sliding_window, ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, @@ -25,255 +18,10 @@ from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - HeterogeneousNextTokenChooser, - StoppingCriteria, ) tracer = trace.get_tracer(__name__) -# Will be set in init -SLIDING_WINDOW: Optional[int] = None -SLIDING_WINDOW_BLOCKS: Optional[int] = None - - -# Adds windowing logic to FlashCausalLMBatch -@dataclass -class FlashMistralBatch(FlashCausalLMBatch): - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] = None - - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "FlashCausalLMBatch": - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - - batch_inputs = [] - max_truncation = 0 - for r in pb.requests: - batch_inputs.append(r.inputs) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation - )["input_ids"] - - position_ids = [] - cu_seqlen_prefill = [0] - needed_blocks_slots = [] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - - input_lengths = [] - prefix_offsets = [] - read_offsets = [] - all_input_ids = [] - requests_idx_mapping = {} - - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - next_token_chooser_parameters = [] - stopping_criterias = [] - top_n_tokens = [] - - # Cumulative length - cumulative_length = 0 - cumulative_max_length = 0 - prefill_out_cumulative_length = 0 - - blocks = 0 - max_seqlen = 0 - max_length = 0 - max_blocks = 0 - - # Parse batch - for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) - ): - # request id -> idx in list mapping - requests_idx_mapping[r.id] = i - - tokenized_input = tokenized_input[-r.truncate :] - - input_length = len(tokenized_input) - input_lengths.append(input_length) - - prefix_offsets.append(input_length - 5) - read_offsets.append(input_length) - - all_input_ids.append(tokenized_input) - - # Position ids - request_position_ids = torch.arange(0, input_length, dtype=torch.int32) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) - - next_token_chooser_parameters.append(r.parameters) - - stopping_criteria = StoppingCriteria.from_pb( - r.stopping_parameters, tokenizer - ) - max_new_tokens = stopping_criteria.max_new_tokens - stopping_criterias.append(stopping_criteria) - top_n_tokens.append(r.top_n_tokens) - - # Paged attention - # Remove one as the first token des not have a past - total_tokens = input_length + max_new_tokens - 1 - - # Needed blocks can not go over SLIDING_WINDOW_BLOCKS - needed_blocks = min( - math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS - ) - blocks += needed_blocks - - needed_blocks_slots.append((needed_blocks, total_tokens)) - start_slots.append(cumulative_max_length) - - request_slot_indices = torch.arange( - cumulative_max_length, - cumulative_max_length + input_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - SLIDING_WINDOW), - cumulative_length + input_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) - prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) - prefill_out_cumulative_length += input_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - - # Update - cumulative_length += input_length - cumulative_max_length += total_tokens - max_seqlen = max(max_seqlen, input_length) - max_blocks = max(max_blocks, needed_blocks) - max_length = max(max_length, input_length + max_new_tokens) - - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( - next_token_chooser_parameters, dtype, device - ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) - - # Padded all_input_ids_tensor - all_input_ids_tensor = np.zeros( - (len(all_input_ids), max_length), dtype=np.int64 - ) - for i, input_ids in enumerate(all_input_ids): - all_input_ids_tensor[i, : len(input_ids)] = input_ids - - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) - - if len(pb.requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - input_ids = all_input_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = prefill_cache_indices.to(device) - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - input_lengths_tensor = torch.tensor( - input_lengths, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - top_n_tokens_tensor = torch.tensor( - top_n_tokens, device=device, dtype=torch.int64 - ) - - return cls( - batch_id=pb.id, - requests=pb.requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - start_slots=start_slots, - slot_indices=slot_indices, - needed_blocks_slots=needed_blocks_slots, - block_tables=None, - block_tables_tensor=None, - slots=None, - max_seqlen=max_seqlen, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, - input_lengths=input_lengths, - input_lengths_tensor=input_lengths_tensor, - prefix_offsets=prefix_offsets, - read_offsets=read_offsets, - all_input_ids=all_input_ids, - all_input_ids_tensor=all_input_ids_tensor, - next_token_chooser=next_token_chooser, - stopping_criterias=stopping_criterias, - top_n_tokens=top_n_tokens, - top_n_tokens_tensor=top_n_tokens_tensor, - blocks=blocks, - max_blocks=max_blocks, - prefill_cache_indices=prefill_cache_indices, - ) - class FlashMistral(FlashCausalLM): def __init__( @@ -284,9 +32,6 @@ class FlashMistral(FlashCausalLM): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -308,8 +53,7 @@ class FlashMistral(FlashCausalLM): config.quantize = quantize # Set context windows - SLIDING_WINDOW = config.sliding_window - SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) + set_sliding_window(config.sliding_window, 0) torch.distributed.barrier(group=self.process_group) @@ -331,27 +75,4 @@ class FlashMistral(FlashCausalLM): device=device, rank=rank, world_size=world_size, - sliding_window=config.sliding_window, ) - - @property - def batch_type(self) -> Type[FlashMistralBatch]: - return FlashMistralBatch - - def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: - # Model Forward - logits = self.model.forward( - input_ids=batch.input_ids, - position_ids=batch.position_ids, - cu_seqlen_prefill=batch.cu_seqlen_prefill, - kv_cache=get_cache_manager().kv_cache, - block_tables=batch.block_tables_tensor, - slots=batch.slots[batch.slot_indices], - input_lengths=batch.input_lengths_tensor, - max_s=batch.max_seqlen, - prefill_cache_indices=batch.prefill_cache_indices, - lm_head_indices=batch.prefill_head_indices, - ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - return logits diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 17d2ea9b..775d073c 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -3,10 +3,11 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type -from transformers import PreTrainedTokenizerBase, PretrainedConfig +from transformers import PreTrainedTokenizerBase from text_generation_server.models.types import Batch, Generation from text_generation_server.pb.generate_pb2 import InfoResponse +from text_generation_server.models.sliding_window import get_sliding_window B = TypeVar("B", bound=Batch) @@ -21,7 +22,6 @@ class Model(ABC): device: torch.device, rank: int = 0, world_size: int = 1, - sliding_window: Optional[int] = None, ): self.model = model.eval() self.tokenizer = tokenizer @@ -31,7 +31,6 @@ class Model(ABC): self.device = device self.rank = rank self.world_size = world_size - self.sliding_window = sliding_window self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) @@ -42,14 +41,15 @@ class Model(ABC): @property def info(self) -> InfoResponse: - if self.requires_padding and self.sliding_window is not None: + sliding_window = get_sliding_window() + if self.requires_padding and sliding_window is not None: raise NotImplementedError("sliding_window is not implemented with padding") return InfoResponse( requires_padding=self.requires_padding, dtype=str(self.dtype), device_type=self.device.type, - window_size=self.sliding_window, + window_size=sliding_window.size if sliding_window is not None else None, ) @property diff --git a/server/text_generation_server/models/sliding_window.py b/server/text_generation_server/models/sliding_window.py new file mode 100644 index 00000000..2df39d8b --- /dev/null +++ b/server/text_generation_server/models/sliding_window.py @@ -0,0 +1,41 @@ +import os +import math + +from typing import Optional + +from text_generation_server.models.cache_manager import BLOCK_SIZE + +SLIDING_WINDOW: Optional["SlidingWindow"] = None + + +class SlidingWindow: + def __init__(self, size: int, attention_sinks: int): + self.size = size + self.blocks = math.ceil(size / BLOCK_SIZE) + self.attention_sinks = attention_sinks + + @classmethod + def from_env(cls) -> Optional["SlidingWindow"]: + sliding_window_env = os.getenv("SLIDING_WINDOW", None) + if sliding_window_env is not None: + return cls(int(sliding_window_env), int(os.getenv("ATTENTION_SINKS", 0))) + return None + + +def set_sliding_window(size: int, attention_sinks: int) -> SlidingWindow: + global SLIDING_WINDOW + SLIDING_WINDOW = SlidingWindow(size, attention_sinks) + return SLIDING_WINDOW + + +def set_sliding_window_from_env() -> Optional[SlidingWindow]: + global SLIDING_WINDOW + env_sliding_window = SlidingWindow.from_env() + if env_sliding_window is not None: + SLIDING_WINDOW = env_sliding_window + return SLIDING_WINDOW + + +def get_sliding_window() -> Optional[SlidingWindow]: + global SLIDING_WINDOW + return SLIDING_WINDOW diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index f38f130e..22c7b73a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -604,15 +604,16 @@ try: elif rope_scaling["type"] == "yarn": return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling["original_max_position_embeddings"], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, attn_factor=1, beta_fast=32, - beta_slow=1 - + beta_slow=1, ) else: raise NotImplementedError( @@ -645,15 +646,16 @@ try: elif rope_scaling["type"] == "yarn": return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], - max_position_embeddings=rope_scaling["original_max_position_embeddings"], + max_position_embeddings=rope_scaling[ + "original_max_position_embeddings" + ], base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, attn_factor=1, beta_fast=32, - beta_slow=1 - + beta_slow=1, ) else: raise NotImplementedError( @@ -734,19 +736,27 @@ try: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - # Inverse dim formula to find dim based on number of rotations import math - def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) + + def find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 + ): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) # Find dim range bounds based on rotations - def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): - low = math.floor(find_correction_dim( - low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim( - high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim-1) # Clamp values just in case + def find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 + ): + low = math.floor( + find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case def linear_ramp_mask(min, max, dim): if min == max: @@ -762,7 +772,19 @@ try: return 0.1 * math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): - def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow): + def __init__( + self, + dim, + max_position_embeddings, + base, + device, + scaling_factor, + *, + extrapolation_factor, + attn_factor, + beta_fast, + beta_slow, + ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) self.dim = dim @@ -772,7 +794,9 @@ try: self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow - self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, @@ -788,13 +812,26 @@ try: ) freqs = 1.0 / inv_freq_extrapolation inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) - low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + low, high = find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.max_position_embeddings, + ) + inv_freq_mask = ( + 1 + - linear_ramp_mask(low, high, self.dim // 2).float().to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) self.inv_freq = inv_freq - self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation - + self.mscale = float( + get_mscale(self.scaling_factor) * self.attn_factor + ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 2f330d9c..9c91662f 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -16,7 +16,7 @@ class Weights: dtype, process_group, aliases: Optional[Dict[str, List[str]]] = None, - prefix: Optional[str] = None + prefix: Optional[str] = None, ): routing = {} for filename in filenames: @@ -213,7 +213,8 @@ class Weights: bits, groupsize = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA - use_exllama = bits==4 and HAS_EXLLAMA and quantize == "gptq" + + use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] diff --git a/update_doc.py b/update_doc.py index 6206e211..6127418c 100644 --- a/update_doc.py +++ b/update_doc.py @@ -21,14 +21,14 @@ def main(): block = [] for line in lines: if line.startswith(" -") or line.startswith(" -"): - rendered_block = '\n'.join(block) + rendered_block = "\n".join(block) if header: final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" else: final_doc += f"```shell\n{rendered_block}\n```\n" block = [] tokens = line.split("<") - if len(tokens)>1: + if len(tokens) > 1: header = tokens[-1][:-1] else: header = line.split("--")[-1] @@ -36,7 +36,7 @@ def main(): block.append(line) - rendered_block = '\n'.join(block) + rendered_block = "\n".join(block) final_doc += f"## {header}\n```shell\n{rendered_block}\n```\n" block = []