diff --git a/Dockerfile b/Dockerfile index 02540f81..b6c5b2ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile # Build specific version of vllm RUN make build-vllm-cuda +# Build megablocks +FROM kernel-builder as megablocks-builder + +RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e + # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base @@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins curl \ && rm -rf /var/lib/apt/lists/* -# Copy conda with PyTorch installed -COPY --from=pytorch-install /opt/conda /opt/conda +# Copy conda with PyTorch and Megablocks installed +COPY --from=megablocks-builder /opt/conda /opt/conda # Copy build artifacts from flash attention builder COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages diff --git a/router/src/server.rs b/router/src/server.rs index 5f41fd5e..fe1b8309 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -629,6 +629,9 @@ pub async fn run( // Batch size buckets let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size")); let batch_size_buckets: Vec = (0..1024).map(|x| (x + 1) as f64).collect(); + // Speculated tokens buckets + let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens")); + let skipped_buckets: Vec = (0..shard_info.speculate + 1).map(|x| x as f64).collect(); // Prometheus handler let builder = PrometheusBuilder::new() @@ -641,6 +644,8 @@ pub async fn run( .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets) .unwrap() .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets) + .unwrap() + .set_buckets_for_metric(skipped_matcher, &skipped_buckets) .unwrap(); let prom_handle = builder .install_recorder() diff --git a/server/Makefile b/server/Makefile index 2810a528..b1926828 100644 --- a/server/Makefile +++ b/server/Makefile @@ -16,6 +16,9 @@ gen-server: find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; touch text_generation_server/pb/__init__.py +install-megablocks: + pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e + install: gen-server pip install pip --upgrade pip install -r requirements_cuda.txt diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 27e3897d..0172d32c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,4 +1,3 @@ -import os import torch from loguru import logger @@ -78,6 +77,18 @@ except ImportError as e: if MISTRAL: __all__.append(FlashMistral) +MIXTRAL = True +try: + from text_generation_server.models.flash_mixtral import FlashMixtral +except ImportError as e: + logger.warning(f"Could not import Mixtral model: {e}") + MIXTRAL = False + +if MIXTRAL: + __all__.append(FlashMixtral) + + + def get_model( model_id: str, revision: Optional[str], @@ -141,7 +152,6 @@ def get_model( use_medusa = None if "medusa_num_heads" in config_dict: use_medusa = model_id - medusa_config = config_dict model_id = config_dict["base_model_name_or_path"] revision = "main" speculate_medusa = config_dict["medusa_num_heads"] @@ -292,7 +302,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - raise NotImplementedError("Mistral model requires flash attention v2") + raise NotImplementedError("Mistral models requires flash attention v2") + + if model_type == "mixtral": + if MIXTRAL: + return FlashMixtral( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks") if model_type == "opt": return OPTSharded( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 4aeb447d..d06b87eb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -34,14 +34,8 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelHead, get_linear, + FastRMSNorm ) -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM - -if IS_CUDA_SYSTEM: - import dropout_layer_norm -elif IS_ROCM_SYSTEM: - from vllm import layernorm_ops - class LlamaConfig(PretrainedConfig): def __init__( @@ -95,75 +89,6 @@ class LlamaConfig(PretrainedConfig): ) -class LlamaRMSNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = eps - - def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif IS_CUDA_SYSTEM: - # faster post attention rms norm - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif IS_ROCM_SYSTEM: - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - layernorm_ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual - else: - raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") - - def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) @@ -363,10 +288,8 @@ class FlashLlamaLayer(nn.Module): ) self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = LlamaRMSNorm( - prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = LlamaRMSNorm( + self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) + self.post_attention_layernorm = FastRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, @@ -430,7 +353,7 @@ class FlashLlamaModel(torch.nn.Module): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = LlamaRMSNorm( + self.norm = FastRMSNorm.load( prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 959949f0..4e56b188 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -35,13 +35,9 @@ from text_generation_server.utils.layers import ( PositionRotaryEmbedding, TensorParallelHead, get_linear, + FastRMSNorm ) -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM -if IS_CUDA_SYSTEM: - import dropout_layer_norm -elif IS_ROCM_SYSTEM: - from vllm import layernorm_ops if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM: raise ImportError("Mistral model requires flash attn v2") @@ -100,76 +96,6 @@ class MistralConfig(PretrainedConfig): **kwargs, ) - -class MistralRMSNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = eps - - def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: - if residual is not None: - hidden_states += residual - residual = hidden_states - - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt( - variance + self.variance_epsilon - ) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states, residual - elif IS_CUDA_SYSTEM: - # faster post attention rms norm - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - None, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - True, # Activate RMSNorm - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - elif IS_ROCM_SYSTEM: - # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. - if residual is not None: - hidden_states += residual - residual = hidden_states - - out = torch.empty_like(hidden_states) - layernorm_ops.rms_norm( - out, - hidden_states, - self.weight.data, - self.variance_epsilon, - ) - return out, residual - else: - raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") - - def load_attention(config, prefix, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) @@ -371,10 +297,10 @@ class MistralLayer(nn.Module): ) self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = MistralRMSNorm( + self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) - self.post_attention_layernorm = MistralRMSNorm( + self.post_attention_layernorm = FastRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, @@ -440,7 +366,7 @@ class MistralModel(torch.nn.Module): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = MistralRMSNorm( + self.norm = FastRMSNorm.load( prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py new file mode 100644 index 00000000..66753d5a --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -0,0 +1,708 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +import numpy as np + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA +from text_generation_server.utils.layers import ( + FastLinear, + FastRMSNorm, + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + +if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM: + raise ImportError("Mixtral model requires flash attn v2") + +try: + import megablocks.ops as ops +except ImportError: + raise ImportError("Mixtral model requires megablocks to be installed") + +try: + import stk +except ImportError: + raise ImportError("Mixtral model requires stk to be installed") + + +class MixtralConfig(PretrainedConfig): + model_type = "mixtral" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + num_experts_per_tok=2, + num_local_experts=8, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def promote_scalar(x: torch.Tensor) -> torch.Tensor: + return x.view(1) if len(x.size()) == 0 else x + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +def _load_experts(config, prefix, mat, weights): + if config.quantize is not None: + raise NotImplementedError("Mixtral does not support weight quantization yet.") + + assert mat in ["w1", "w2", "w3"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.intermediate_size % world_size == 0 + ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device) + + for i in range(config.num_local_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "w2": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device) + return tensor + + +class MixtralAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else 0 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size ** -0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +@torch.jit.script +def select_experts(gate_logits: torch.Tensor, top_k: int): + # all_probs: (sequence_length, n_experts) and upcast for softmax + all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float) + # weights, selected_experts: (sequence_length, top-k) + weights, selected_experts = torch.topk(all_probs, top_k, dim=-1) + weights /= weights.sum(dim=-1, keepdim=True) + weights = weights.view(-1) + selected_experts = selected_experts.view(-1) + + return selected_experts, weights + + +@torch.jit.script +def round_up(x: torch.Tensor, value: int): + return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + + +class BlockSparseMoE(nn.Module): + """ + Built on the paper and library Megablocks as described in + https://arxiv.org/abs/2211.15841. This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, prefix, config: MixtralConfig, weights): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size // weights.process_group.size() + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + act = config.hidden_act + if "gelu" in act: + self.act = lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + elif "silu" in act: + self.act = torch.nn.functional.silu + else: + self.act = ACT2FN[act] + + # gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim) + self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t() + self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights) + self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t() + + self.offsets = None + self.offsets_block_rows = 0 + + self.process_group = weights.process_group + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + self.blocking = 128 + self.quantize_scatter_num_bits = -1 + + def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + assert self.ffn_dim % self.blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_dim // self.blocking + if self.offsets is None or block_rows > self.offsets_block_rows: + self.offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + self.offsets_block_rows = block_rows + offsets = self.offsets + else: + offsets = self.offsets[:block_rows] + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology(padded_bins, self.blocking, block_rows, + blocks_per_row) + + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=x.dtype, + device="meta", + ) + shape = (padded_tokens, self.ffn_dim * self.num_experts) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + False, + False, + False, + ) + + def indices_and_padded_bins(self, selected_experts: torch.Tensor): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # selected_experts = selected_experts.int() + + # returns bin_ids == num of experts for this sequence ? == unique selected experts? + # and indices == how to sort tokens? + bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit) + # bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k] + # indices => [14, 32, 33, ...] => [num_tokens * top_k] + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(selected_experts, self.num_experts) + # tokens_per_expert => [3, 0, 2, ...] => [num_experts] + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + + # List of size num_experts + padded_tokens_per_expert = round_up(tokens_per_expert, + self.blocking) + # padded_tokens_per_expert => [128, O, 128, ...] + + # Cumulative selected experts per token + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + # padded_bins => [128, 128, 256, ...] + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + # bins => [3, 3, 5, ...] + + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + # gate_logits: (sequence_length, n_experts) + gate_logits = self.gate(x) + selected_experts, weights = select_experts(gate_logits, self.top_k) + + ( + indices, + bin_ids, + bins, + padded_bins, + _, + ) = self.indices_and_padded_bins(selected_experts) + + # Permute tokens and pad to prepare expert computation + # (top_k * sequence_length + padding, model_dim) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, + self.top_k) + + # Create the sparse matrix topology + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation + # First Dense x Dense -> Sparse for w1 and w3, + # (top_k * sequence_length + padding, ffn_dim * n_experts) + x = stk.Matrix( + topo.size(), + self.act(stk.ops.sdd(x, self.w1, topo).data) * + stk.ops.sdd(x, self.w3, topo).data, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Then Sparse x Dense -> Dense for w2 + # (top_k * sequence_length + padding, model_dim) + x = stk.ops.dsd(x, self.w2) + + # Permute back and remove padding + # (sequence_length, model_dim) + x = ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + self.top_k, + self.quantize_scatter_num_bits, + ).view(*input_shape) + + if self.process_group.size() > 1: + torch.distributed.all_reduce(x, group=self.process_group) + + return x.view(*input_shape) + + +class MixtralLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + + self.self_attn = MixtralAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output) + + return block_sparse_moe_output, attn_res + + +class MixtralModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + MixtralLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashMixtralForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = MixtralModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + self.max_past = config.sliding_window + if self.max_past is None: + raise ValueError("max_past cannot be None") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + else: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(self.max_past, max_s) + input_lengths = torch.clamp(input_lengths, max=self.max_past) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c3c7617a..cd93d32a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -6,7 +6,6 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index e103d9fc..5ce37164 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -8,14 +8,13 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase from transformers.models.llama import LlamaTokenizerFast -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, List from text_generation_server.pb import generate_pb2 from text_generation_server.models import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE from text_generation_server.models.cache_manager import ( get_cache_manager, - set_cache_manager, ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, @@ -46,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch): @classmethod def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, ) -> "FlashCausalLMBatch": global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -100,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch): # Parse batch for i, (r, tokenized_input) in enumerate( - zip(pb.requests, batch_tokenized_inputs) + zip(pb.requests, batch_tokenized_inputs) ): # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate :] + tokenized_input = tokenized_input[-r.truncate:] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -278,14 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch): ) -class FlashMistral(FlashCausalLM): +class BaseFlashMistral(FlashCausalLM): def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, + self, + config_cls, + model_cls, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, ): global SLIDING_WINDOW global SLIDING_WINDOW_BLOCKS @@ -305,7 +306,7 @@ class FlashMistral(FlashCausalLM): trust_remote_code=trust_remote_code, ) - config = MistralConfig.from_pretrained( + config = config_cls.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize @@ -321,10 +322,10 @@ class FlashMistral(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id) - model = FlashMistralForCausalLM(config, weights) + model = model_cls(config, weights) torch.distributed.barrier(group=self.process_group) - super(FlashMistral, self).__init__( + super(BaseFlashMistral, self).__init__( model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -396,3 +397,23 @@ class FlashMistral(FlashCausalLM): if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits + + +class FlashMistral(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(FlashMistral, self).__init__( + config_cls=MistralConfig, + model_cls=FlashMistralForCausalLM, + model_id=model_id, + revision=revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code + ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py new file mode 100644 index 00000000..c45ae50f --- /dev/null +++ b/server/text_generation_server/models/flash_mixtral.py @@ -0,0 +1,26 @@ +import torch + +from typing import Optional + +from text_generation_server.models.flash_mistral import BaseFlashMistral +from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM + + +class FlashMixtral(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + super(FlashMixtral, self).__init__( + config_cls=MixtralConfig, + model_cls=FlashMixtralForCausalLM, + model_id=model_id, + revision=revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code + ) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index a93ccd0e..d533016d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -18,7 +18,7 @@ except ImportError: from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM HAS_AWQ = True try: @@ -43,16 +43,18 @@ if os.getenv("DISABLE_EXLLAMA") == "True": elif CAN_EXLLAMA: try: if V2: - from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, + from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear, + create_exllama_buffers, + set_device, ) + HAS_EXLLAMA = "2" else: from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear, - create_exllama_buffers, - set_device, - ) + create_exllama_buffers, + set_device, + ) + HAS_EXLLAMA = "1" except ImportError: @@ -112,7 +114,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st @classmethod def load_conv2d_no_bias( - cls, prefix, weights, in_channels, out_channels, kernel_size, stride + cls, prefix, weights, in_channels, out_channels, kernel_size, stride ): weight = weights.get_tensor(f"{prefix}.weight") with init_empty_weights(): @@ -136,9 +138,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias class FastLinear(nn.Module): def __init__( - self, - weight, - bias, + self, + weight, + bias, ) -> None: super().__init__() self.weight = nn.Parameter(weight) @@ -162,9 +164,9 @@ class FastLinear(nn.Module): class EETQLinear(nn.Module): def __init__( - self, - weight, - bias, + self, + weight, + bias, ) -> None: super().__init__() device = weight.device @@ -183,13 +185,13 @@ class EETQLinear(nn.Module): class Linear8bitLt(nn.Module): def __init__( - self, - weight, - bias, - has_fp16_weights=True, - memory_efficient_backward=False, - threshold=0.0, - index=None, + self, + weight, + bias, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, ): super().__init__() assert ( @@ -526,9 +528,12 @@ class TensorParallelEmbedding(nn.Module): try: if IS_CUDA_SYSTEM: import dropout_layer_norm + elif IS_ROCM_SYSTEM: + from vllm import layernorm_ops else: dropout_layer_norm = None + class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: @@ -563,10 +568,81 @@ try: residual = hidden_states return normed_hidden_states, residual + + + class FastRMSNorm(nn.Module): + def __init__(self, weight: torch.Tensor, eps: float): + super().__init__() + + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + return cls(weight, eps) + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + elif IS_CUDA_SYSTEM: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + elif IS_ROCM_SYSTEM: + # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not. + if residual is not None: + hidden_states += residual + residual = hidden_states + + out = torch.empty_like(hidden_states) + layernorm_ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out, residual + else: + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + except ImportError: pass - try: if IS_CUDA_SYSTEM: from flash_attn.layers.rotary import RotaryEmbedding @@ -574,12 +650,14 @@ try: elif IS_ROCM_SYSTEM: from vllm import pos_encoding_ops + def _create_inv_freq(dim, base, device): inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) return inv_freq + def _get_rope_config(config): if os.getenv("ROPE_SCALING", None) is not None: rope_scaling = { @@ -589,6 +667,7 @@ try: return rope_scaling return getattr(config, "rope_scaling", None) + class PositionRotaryEmbedding(nn.Module): def __init__(self, inv_freq, scaling_factor): super().__init__() @@ -606,12 +685,12 @@ try: if IS_CUDA_SYSTEM: rotary_dim = cos.shape[-1] q1 = query[..., :rotary_dim] - q2 = query[..., rotary_dim : 2 * rotary_dim] + q2 = query[..., rotary_dim: 2 * rotary_dim] rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) k1 = key[..., :rotary_dim] - k2 = key[..., rotary_dim : 2 * rotary_dim] + k2 = key[..., rotary_dim: 2 * rotary_dim] rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) elif IS_ROCM_SYSTEM: @@ -630,7 +709,8 @@ try: True ) else: - raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") + raise ValueError( + "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.") @classmethod def static(cls, config, dim, base, device): @@ -713,9 +793,9 @@ try: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) @@ -729,7 +809,7 @@ try: self._sin_cached = torch.sin(freqs).to(dtype) def get_cos_sin( - self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype + self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype ): """ Return cos and sin for the asked position ids @@ -747,6 +827,7 @@ try: # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) + class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): inv_freq = _create_inv_freq(dim, base, device) @@ -755,18 +836,18 @@ try: self.max_position_embeddings = max_position_embeddings self.base = base - def _update_cos_sin_cache(self, dtype, device, seqlen): + def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): if seqlen > self.max_position_embeddings: newbase = self.base * ( - (self.scaling_factor * seqlen / self.max_position_embeddings) - - (self.scaling_factor - 1) + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) self.inv_freq = _create_inv_freq( self.dim, newbase, self.inv_freq.device @@ -783,8 +864,11 @@ try: # Inverse dim formula to find dim based on number of rotations import math + + def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): - return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + # Find dim range bounds based on rotations def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): @@ -792,7 +876,8 @@ try: low_rot, dim, base, max_position_embeddings)) high = math.ceil(find_correction_dim( high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim-1) # Clamp values just in case + return max(low, 0), min(high, dim - 1) # Clamp values just in case + def linear_ramp_mask(min, max, dim): if min == max: @@ -802,13 +887,16 @@ try: ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func + def get_mscale(scale=1): if scale <= 1: return 1.0 return 0.1 * math.log(scale) + 1.0 + class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): - def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow): + def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, + attn_factor, beta_fast, beta_slow): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) self.dim = dim @@ -818,15 +906,16 @@ try: self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow - self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(get_mscale( + self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if ( - seqlen > self._seq_len_cached - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype ): if seqlen > self.max_position_embeddings: inv_freq_extrapolation = _create_inv_freq( @@ -834,13 +923,15 @@ try: ) freqs = 1.0 / inv_freq_extrapolation inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs) - low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, + self.max_position_embeddings) + inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to( + device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.inv_freq = inv_freq - self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation - + self.mscale = float(get_mscale( + self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)