feat: mixtral (#1328)

OlivierDehaene 2023-12-11 14:43:40 +01:00 committed by GitHub
parent 9ecfa16b12
commit 3a521c92b3
11 changed files with 959 additions and 231 deletions

View File

@@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda
# Build megablocks
FROM kernel-builder as megablocks-builder
RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
@@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

View File

@@ -629,6 +629,9 @@ pub async fn run(
// Batch size buckets
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// Prometheus handler
let builder = PrometheusBuilder::new()
@@ -641,6 +644,8 @@ pub async fn run(
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
let prom_handle = builder
.install_recorder()

View File

@@ -16,6 +16,9 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-megablocks:
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
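Since the Mixtral path only activates when its optional kernels import cleanly, a quick sanity check after running make install-megablocks is to import them directly (illustrative snippet, not part of this commit; the module names are the ones imported by the new flash_mixtral_modeling.py):

import importlib

for mod in ("megablocks.ops", "stk"):
    try:
        importlib.import_module(mod)
        print(f"{mod}: ok")
    except ImportError as err:
        print(f"{mod}: missing ({err})")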

View File

@@ -1,4 +1,3 @@
import os
import torch
from loguru import logger
@@ -78,6 +77,18 @@ except ImportError as e:
if MISTRAL:
__all__.append(FlashMistral)
MIXTRAL = True
try:
from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
logger.warning(f"Could not import Mixtral model: {e}")
MIXTRAL = False
if MIXTRAL:
__all__.append(FlashMixtral)
def get_model(
model_id: str,
revision: Optional[str],
@@ -141,7 +152,6 @@ def get_model(
use_medusa = None
if "medusa_num_heads" in config_dict:
use_medusa = model_id
medusa_config = config_dict
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
@@ -292,7 +302,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
raise NotImplementedError("Mistral models requires flash attention v2")
if model_type == "mixtral":
if MIXTRAL:
return FlashMixtral(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
if model_type == "opt": if model_type == "opt":
return OPTSharded( return OPTSharded(

View File

@@ -34,14 +34,8 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
class LlamaConfig(PretrainedConfig):
def __init__(
@@ -95,75 +89,6 @@ class LlamaConfig(PretrainedConfig):
)
class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -363,10 +288,8 @@ class FlashLlamaLayer(nn.Module):
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -430,7 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

View File

@@ -35,13 +35,9 @@ from text_generation_server.utils.layers import (
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastRMSNorm
)
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mistral model requires flash attn v2")
@@ -100,76 +96,6 @@ class MistralConfig(PretrainedConfig):
**kwargs,
)
class MistralRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
@@ -371,10 +297,10 @@ class MistralLayer(nn.Module):
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = MistralRMSNorm(
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MistralRMSNorm(
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
@@ -440,7 +366,7 @@ class MistralModel(torch.nn.Module):
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = MistralRMSNorm(
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)

View File

@@ -0,0 +1,708 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
import numpy as np
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
from text_generation_server.utils.layers import (
FastLinear,
FastRMSNorm,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
)
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
raise ImportError("Mixtral model requires flash attn v2")
try:
import megablocks.ops as ops
except ImportError:
raise ImportError("Mixtral model requires megablocks to be installed")
try:
import stk
except ImportError:
raise ImportError("Mixtral model requires stk to be installed")
class MixtralConfig(PretrainedConfig):
model_type = "mixtral"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
num_experts_per_tok=2,
num_local_experts=8,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
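# Editorial note (not part of this commit): with the defaults above, Mixtral uses
# grouped-query attention (32 query heads sharing 8 key/value heads, i.e. 4 query heads
# per KV head), and sliding_window is what MixtralAttention stores as max_past and
# passes to flash attention as window_size_left during prefill.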
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
)
def _load_experts(config, prefix, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
return tensor
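# Editorial note (worked example, not part of this commit): with the default
# intermediate_size=14336 and an assumed world_size of 2, block_size is 7168, so each
# rank keeps rows [rank * 7168, (rank + 1) * 7168) of every expert's w1/w3 and the
# matching columns of w2 (transposed into the same layout), giving one merged tensor of
# shape (num_local_experts * 7168, hidden_size) = (57344, 4096) per matrix.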
class MixtralAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else 0
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size ** -0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
weights = weights.view(-1)
selected_experts = selected_experts.view(-1)
return selected_experts, weights
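# Editorial illustration (not part of this commit): tracing select_experts on a tiny
# input of 2 tokens and 4 experts with top_k=2. Both outputs are flattened to
# num_tokens * top_k entries and the weights are renormalized per token.
_example_gate_logits = torch.tensor([[1.0, 2.0, 0.5, -1.0],
                                     [0.0, 0.0, 3.0, 1.0]])
_example_experts, _example_weights = select_experts(_example_gate_logits, 2)
# _example_experts -> tensor([1, 0, 2, 3])
# _example_weights -> approximately tensor([0.73, 0.27, 0.88, 0.12])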
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module):
"""
Built on the paper and library Megablocks as described in
https://arxiv.org/abs/2211.15841. This implementation is
strictly equivalent to standard MoE with full capacity (no
dropped tokens). It's faster since it formulates MoE operations
in terms of block-sparse operations to accommodate imbalanced
assignments of tokens to experts, whereas standard MoE either
(1) drop tokens at the cost of reduced performance or (2) set
capacity factor to number of experts and thus waste computation
and memory on padding.
"""
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
self.offsets = None
self.offsets_block_rows = 0
self.process_group = weights.process_group
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
self.blocking = 128
self.quantize_scatter_num_bits = -1
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.ffn_dim % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.ffn_dim // self.blocking
if self.offsets is None or block_rows > self.offsets_block_rows:
self.offsets = torch.arange(
0,
block_rows * blocks_per_row + 1,
blocks_per_row,
dtype=torch.int32,
device=x.device,
)
self.offsets_block_rows = block_rows
offsets = self.offsets
else:
offsets = self.offsets[:block_rows]
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
blocks_per_row)
# For now, use meta init to save the device memory.
data = torch.empty(
column_indices.numel(),
self.blocking,
self.blocking,
dtype=x.dtype,
device="meta",
)
shape = (padded_tokens, self.ffn_dim * self.num_experts)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
return stk.Matrix(
shape,
data,
row_indices,
column_indices,
offsets,
False,
False,
False,
)
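# Editorial note (worked example, not part of this commit): with blocking=128, a padded
# token count of 256 and a per-shard ffn_dim of 7168 (14336 split over an assumed 2
# shards), block_rows = 2, blocks_per_row = 56 and offsets = arange(0, 113, 56)
# = [0, 56, 112]: every block-row owns exactly one expert's worth of 128x128 blocks.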
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
# selected_experts = selected_experts.int()
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
# and indices == how to sort tokens?
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
# List of size num_experts
padded_tokens_per_expert = round_up(tokens_per_expert,
self.blocking)
# padded_tokens_per_expert => [128, 0, 128, ...]
# Cumulative selected experts per token
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
padded_bins = promote_scalar(padded_bins)
# padded_bins => [128, 128, 256, ...]
# Calculate the bin bounds for the sorted tokens.
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
bins = promote_scalar(bins)
# bins => [3, 3, 5, ...]
return indices, bin_ids, bins, padded_bins, tokens_per_expert
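# Editorial illustration (not part of this commit): the same bookkeeping with plain
# torch ops standing in for the megablocks kernels (ops.sort / ops.histogram /
# ops.inclusive_cumsum), for selected_experts = [2, 0, 0, 2, 1], 3 experts and
# blocking = 128.
_sel = torch.tensor([2, 0, 0, 2, 1])
_bin_ids, _indices = torch.sort(_sel)                   # e.g. [0, 0, 1, 2, 2], [1, 2, 4, 0, 3]
_tokens_per_expert = torch.bincount(_sel, minlength=3)  # [2, 1, 2]
_padded = round_up(_tokens_per_expert, 128)             # [128, 128, 128]
_padded_bins = torch.cumsum(_padded, 0)                 # [128, 256, 384]
_bins = torch.cumsum(_tokens_per_expert, 0)             # [2, 3, 5]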
@torch.inference_mode()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
selected_experts, weights = select_experts(gate_logits, self.top_k)
(
indices,
bin_ids,
bins,
padded_bins,
_,
) = self.indices_and_padded_bins(selected_experts)
# Permute tokens and pad to prepare expert computation
# (top_k * sequence_length + padding, model_dim)
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
self.top_k)
# Create the sparse matrix topology
with torch.no_grad():
topo = self.topology(x, padded_bins)
# Perform the expert computation
# First Dense x Dense -> Sparse for w1 and w3,
# (top_k * sequence_length + padding, ffn_dim * n_experts)
x = stk.Matrix(
topo.size(),
self.act(stk.ops.sdd(x, self.w1, topo).data) *
stk.ops.sdd(x, self.w3, topo).data,
topo.row_indices,
topo.column_indices,
topo.offsets,
topo.column_indices_t,
topo.offsets_t,
topo.block_offsets_t,
)
# Then Sparse x Dense -> Dense for w2
# (top_k * sequence_length + padding, model_dim)
x = stk.ops.dsd(x, self.w2)
# Permute back and remove padding
# (sequence_length, model_dim)
x = ops.padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
self.top_k,
self.quantize_scatter_num_bits,
).view(*input_shape)
if self.process_group.size() > 1:
torch.distributed.all_reduce(x, group=self.process_group)
return x.view(*input_shape)
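# Editorial sketch (not part of this commit): a dense, loop-over-experts reference of
# what the block-sparse path above computes (full-capacity MoE, no dropped tokens).
# Per-expert weights keep their Hugging Face shapes: w1[e], w3[e]: (ffn_dim, hidden),
# w2[e]: (hidden, ffn_dim); act is SiLU for Mixtral.
def _dense_moe_reference(x, gate_logits, w1, w2, w3, act, top_k: int):
    probs = torch.nn.functional.softmax(gate_logits, dim=-1, dtype=torch.float)
    top_weights, top_experts = torch.topk(probs, top_k, dim=-1)
    top_weights = (top_weights / top_weights.sum(-1, keepdim=True)).to(x.dtype)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):          # tokens
        for j in range(top_k):           # experts selected for this token
            e = int(top_experts[t, j])
            h = act(x[t] @ w1[e].t()) * (x[t] @ w3[e].t())
            out[t] += top_weights[t, j] * (h @ w2[e].t())
    return out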
class MixtralLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = MixtralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
return block_sparse_moe_output, attn_res
class MixtralModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
MixtralLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = MixtralModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if prefill_cache_indices is not None:
# Slots also need to be sliced as they have the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
input_lengths = torch.clamp(input_lengths, max=self.max_past)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits

View File

@@ -6,7 +6,6 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,

View File

@@ -8,14 +8,13 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type
from typing import Optional, Tuple, Type, List
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
get_cache_manager,
set_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
@@ -46,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
@@ -100,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate:]
input_length = len(tokenized_input)
input_lengths.append(input_length)
@@ -278,14 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch):
)
class FlashMistral(FlashCausalLM):
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
config_cls,
model_cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
@@ -305,7 +306,7 @@ class FlashMistral(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
config = MistralConfig.from_pretrained(
config = config_cls.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
@@ -321,10 +322,10 @@ class FlashMistral(FlashCausalLM):
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id)
model = FlashMistralForCausalLM(config, weights)
model = model_cls(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashMistral, self).__init__(
super(BaseFlashMistral, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
@@ -396,3 +397,23 @@ class FlashMistral(FlashCausalLM):
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits
class FlashMistral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMistral, self).__init__(
config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
)

View File

@@ -0,0 +1,26 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
class FlashMixtral(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
super(FlashMixtral, self).__init__(
config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM,
model_id=model_id,
revision=revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code
)
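For orientation, a minimal usage sketch (not part of this commit; the checkpoint id is a placeholder and the dtype choice is an assumption) of how this wrapper is constructed. get_model does the equivalent when config.model_type is "mixtral":

import torch
from text_generation_server.models.flash_mixtral import FlashMixtral

model = FlashMixtral(
    model_id="<path-or-hub-id-of-a-mixtral-checkpoint>",
    revision=None,
    quantize=None,  # Mixtral expert weights do not support quantization yet
    dtype=torch.bfloat16,
    trust_remote_code=False,
)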

View File

@@ -44,15 +44,17 @@ elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
create_exllama_buffers,
set_device,
)
HAS_EXLLAMA = "1"
except ImportError:
@@ -112,7 +114,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
@classmethod
def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
@@ -136,9 +138,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
self.weight = nn.Parameter(weight)
@@ -162,9 +164,9 @@ class FastLinear(nn.Module):
class EETQLinear(nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
device = weight.device
@@ -183,13 +185,13 @@ class EETQLinear(nn.Module):
class Linear8bitLt(nn.Module):
def __init__(
self,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__()
assert (
@@ -526,9 +528,12 @@ class TensorParallelEmbedding(nn.Module):
try:
if IS_CUDA_SYSTEM:
import dropout_layer_norm
elif IS_ROCM_SYSTEM:
from vllm import layernorm_ops
else:
dropout_layer_norm = None
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
@@ -563,10 +568,81 @@ try:
residual = hidden_states
return normed_hidden_states, residual
class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
super().__init__()
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
@classmethod
def load(cls, prefix, weights, eps=1e-6):
weight = weights.get_tensor(f"{prefix}.weight")
return cls(weight, eps)
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, residual
elif IS_CUDA_SYSTEM:
# faster post attention rms norm
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
return normed_hidden_states, res
elif IS_ROCM_SYSTEM:
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
return out, residual
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
except ImportError:
pass
try:
if IS_CUDA_SYSTEM:
from flash_attn.layers.rotary import RotaryEmbedding
@@ -574,12 +650,14 @@ try:
elif IS_ROCM_SYSTEM:
from vllm import pos_encoding_ops
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
@@ -589,6 +667,7 @@ try:
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq, scaling_factor):
super().__init__()
@@ -606,12 +685,12 @@ try:
if IS_CUDA_SYSTEM:
rotary_dim = cos.shape[-1]
q1 = query[..., :rotary_dim]
q2 = query[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
k1 = key[..., :rotary_dim]
k2 = key[..., rotary_dim: 2 * rotary_dim]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif IS_ROCM_SYSTEM:
@@ -630,7 +709,8 @@ try:
True
)
else:
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
@classmethod
def static(cls, config, dim, base, device):
@@ -713,9 +793,9 @@ try:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -729,7 +809,7 @@ try:
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
"""
Return cos and sin for the asked position ids
@@ -747,6 +827,7 @@ try:
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device)
@@ -759,14 +840,14 @@ try:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device
@@ -783,8 +864,11 @@ try:
# Inverse dim formula to find dim based on number of rotations
import math
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
@@ -792,7 +876,8 @@ try:
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min, max, dim):
if min == max:
@@ -802,13 +887,16 @@ try:
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
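# Editorial note (worked example, not part of this commit): with typical Yarn settings,
# assumed here as base=10000, dim=128, max_position_embeddings=2048, beta_fast=32 and
# beta_slow=1, find_correction_range(32, 1, 128) comes out to roughly (16, 41), so only
# that middle band of rotary dimensions is blended between interpolation and
# extrapolation, and get_mscale(4.0) = 0.1 * ln(4) + 1 ≈ 1.14.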
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor, attn_factor, beta_fast, beta_slow):
inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
@@ -818,15 +906,16 @@ try:
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
inv_freq_extrapolation = _create_inv_freq(
@@ -834,13 +923,15 @@ try:
)
freqs = 1.0 / inv_freq_extrapolation
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.inv_freq = inv_freq
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)