feat: mixtral (#1328)
This commit is contained in:
parent
9ecfa16b12
commit
3a521c92b3
|
@ -154,6 +154,11 @@ COPY server/Makefile-vllm Makefile
|
||||||
# Build specific version of vllm
|
# Build specific version of vllm
|
||||||
RUN make build-vllm-cuda
|
RUN make build-vllm-cuda
|
||||||
|
|
||||||
|
# Build megablocks
|
||||||
|
FROM kernel-builder as megablocks-builder
|
||||||
|
|
||||||
|
RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
|
||||||
|
|
||||||
# Text Generation Inference base image
|
# Text Generation Inference base image
|
||||||
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
|
FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
|
||||||
|
|
||||||
|
@ -175,8 +180,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
||||||
curl \
|
curl \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Copy conda with PyTorch installed
|
# Copy conda with PyTorch and Megablocks installed
|
||||||
COPY --from=pytorch-install /opt/conda /opt/conda
|
COPY --from=megablocks-builder /opt/conda /opt/conda
|
||||||
|
|
||||||
# Copy build artifacts from flash attention builder
|
# Copy build artifacts from flash attention builder
|
||||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
|
@ -629,6 +629,9 @@ pub async fn run(
|
||||||
// Batch size buckets
|
// Batch size buckets
|
||||||
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
||||||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||||
|
// Speculated tokens buckets
|
||||||
|
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||||
|
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||||
|
|
||||||
// Prometheus handler
|
// Prometheus handler
|
||||||
let builder = PrometheusBuilder::new()
|
let builder = PrometheusBuilder::new()
|
||||||
|
@ -641,6 +644,8 @@ pub async fn run(
|
||||||
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
||||||
|
.unwrap()
|
||||||
|
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let prom_handle = builder
|
let prom_handle = builder
|
||||||
.install_recorder()
|
.install_recorder()
|
||||||
|
|
|
@ -16,6 +16,9 @@ gen-server:
|
||||||
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||||
touch text_generation_server/pb/__init__.py
|
touch text_generation_server/pb/__init__.py
|
||||||
|
|
||||||
|
install-megablocks:
|
||||||
|
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
|
||||||
|
|
||||||
install: gen-server
|
install: gen-server
|
||||||
pip install pip --upgrade
|
pip install pip --upgrade
|
||||||
pip install -r requirements_cuda.txt
|
pip install -r requirements_cuda.txt
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -78,6 +77,18 @@ except ImportError as e:
|
||||||
if MISTRAL:
|
if MISTRAL:
|
||||||
__all__.append(FlashMistral)
|
__all__.append(FlashMistral)
|
||||||
|
|
||||||
|
MIXTRAL = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import Mixtral model: {e}")
|
||||||
|
MIXTRAL = False
|
||||||
|
|
||||||
|
if MIXTRAL:
|
||||||
|
__all__.append(FlashMixtral)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
|
@ -141,7 +152,6 @@ def get_model(
|
||||||
use_medusa = None
|
use_medusa = None
|
||||||
if "medusa_num_heads" in config_dict:
|
if "medusa_num_heads" in config_dict:
|
||||||
use_medusa = model_id
|
use_medusa = model_id
|
||||||
medusa_config = config_dict
|
|
||||||
model_id = config_dict["base_model_name_or_path"]
|
model_id = config_dict["base_model_name_or_path"]
|
||||||
revision = "main"
|
revision = "main"
|
||||||
speculate_medusa = config_dict["medusa_num_heads"]
|
speculate_medusa = config_dict["medusa_num_heads"]
|
||||||
|
@ -292,7 +302,18 @@ def get_model(
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
raise NotImplementedError("Mistral model requires flash attention v2")
|
raise NotImplementedError("Mistral models requires flash attention v2")
|
||||||
|
|
||||||
|
if model_type == "mixtral":
|
||||||
|
if MIXTRAL:
|
||||||
|
return FlashMixtral(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
|
||||||
|
|
||||||
if model_type == "opt":
|
if model_type == "opt":
|
||||||
return OPTSharded(
|
return OPTSharded(
|
||||||
|
|
|
@ -34,14 +34,8 @@ from text_generation_server.utils.layers import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
TensorParallelHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
|
FastRMSNorm
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
|
||||||
|
|
||||||
if IS_CUDA_SYSTEM:
|
|
||||||
import dropout_layer_norm
|
|
||||||
elif IS_ROCM_SYSTEM:
|
|
||||||
from vllm import layernorm_ops
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaConfig(PretrainedConfig):
|
class LlamaConfig(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -95,75 +89,6 @@ class LlamaConfig(PretrainedConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LlamaRMSNorm(nn.Module):
|
|
||||||
def __init__(self, prefix, weights, eps=1e-6):
|
|
||||||
"""
|
|
||||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
|
||||||
self.weight = nn.Parameter(weight)
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
|
||||||
if hidden_states.shape[-1] > 8192:
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(
|
|
||||||
variance + self.variance_epsilon
|
|
||||||
)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states, residual
|
|
||||||
elif IS_CUDA_SYSTEM:
|
|
||||||
# faster post attention rms norm
|
|
||||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.variance_epsilon,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
True, # Activate RMSNorm
|
|
||||||
)
|
|
||||||
if res is None:
|
|
||||||
res = hidden_states
|
|
||||||
|
|
||||||
return normed_hidden_states, res
|
|
||||||
elif IS_ROCM_SYSTEM:
|
|
||||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
|
||||||
layernorm_ops.rms_norm(
|
|
||||||
out,
|
|
||||||
hidden_states,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return out, residual
|
|
||||||
else:
|
|
||||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
return _load_gqa(config, prefix, weights)
|
||||||
|
@ -363,10 +288,8 @@ class FlashLlamaLayer(nn.Module):
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
self.input_layernorm = LlamaRMSNorm(
|
self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
)
|
|
||||||
self.post_attention_layernorm = LlamaRMSNorm(
|
|
||||||
prefix=f"{prefix}.post_attention_layernorm",
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
|
@ -430,7 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = LlamaRMSNorm(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -35,13 +35,9 @@ from text_generation_server.utils.layers import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
TensorParallelHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
|
FastRMSNorm
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
|
||||||
|
|
||||||
if IS_CUDA_SYSTEM:
|
|
||||||
import dropout_layer_norm
|
|
||||||
elif IS_ROCM_SYSTEM:
|
|
||||||
from vllm import layernorm_ops
|
|
||||||
|
|
||||||
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
|
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
|
||||||
raise ImportError("Mistral model requires flash attn v2")
|
raise ImportError("Mistral model requires flash attn v2")
|
||||||
|
@ -100,76 +96,6 @@ class MistralConfig(PretrainedConfig):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MistralRMSNorm(nn.Module):
|
|
||||||
def __init__(self, prefix, weights, eps=1e-6):
|
|
||||||
"""
|
|
||||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
|
||||||
self.weight = nn.Parameter(weight)
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
|
||||||
if hidden_states.shape[-1] > 8192:
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(
|
|
||||||
variance + self.variance_epsilon
|
|
||||||
)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states, residual
|
|
||||||
elif IS_CUDA_SYSTEM:
|
|
||||||
# faster post attention rms norm
|
|
||||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
self.weight,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
0.0,
|
|
||||||
self.variance_epsilon,
|
|
||||||
1.0,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
False,
|
|
||||||
True, # Activate RMSNorm
|
|
||||||
)
|
|
||||||
if res is None:
|
|
||||||
res = hidden_states
|
|
||||||
|
|
||||||
return normed_hidden_states, res
|
|
||||||
elif IS_ROCM_SYSTEM:
|
|
||||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
|
||||||
if residual is not None:
|
|
||||||
hidden_states += residual
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
out = torch.empty_like(hidden_states)
|
|
||||||
layernorm_ops.rms_norm(
|
|
||||||
out,
|
|
||||||
hidden_states,
|
|
||||||
self.weight.data,
|
|
||||||
self.variance_epsilon,
|
|
||||||
)
|
|
||||||
return out, residual
|
|
||||||
else:
|
|
||||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_attention(config, prefix, weights):
|
def load_attention(config, prefix, weights):
|
||||||
if config.num_attention_heads != config.num_key_value_heads:
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
return _load_gqa(config, prefix, weights)
|
return _load_gqa(config, prefix, weights)
|
||||||
|
@ -371,10 +297,10 @@ class MistralLayer(nn.Module):
|
||||||
)
|
)
|
||||||
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
self.input_layernorm = MistralRMSNorm(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
self.post_attention_layernorm = MistralRMSNorm(
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.post_attention_layernorm",
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
eps=config.rms_norm_eps,
|
eps=config.rms_norm_eps,
|
||||||
|
@ -440,7 +366,7 @@ class MistralModel(torch.nn.Module):
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.norm = MistralRMSNorm(
|
self.norm = FastRMSNorm.load(
|
||||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,708 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
|
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
FastLinear,
|
||||||
|
FastRMSNorm,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
TensorParallelHead,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
|
||||||
|
raise ImportError("Mixtral model requires flash attn v2")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import megablocks.ops as ops
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Mixtral model requires megablocks to be installed")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import stk
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Mixtral model requires stk to be installed")
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralConfig(PretrainedConfig):
|
||||||
|
model_type = "mixtral"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-05,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
sliding_window=4096,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
num_local_experts=8,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.num_local_experts = num_local_experts
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x.view(1) if len(x.size()) == 0 else x
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
quantize=config.quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.quantize not in ["gptq", "awq"]:
|
||||||
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(
|
||||||
|
get_linear(weight, bias=None, quantize=config.quantize)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_experts(config, prefix, mat, weights):
|
||||||
|
if config.quantize is not None:
|
||||||
|
raise NotImplementedError("Mixtral does not support weight quantization yet.")
|
||||||
|
|
||||||
|
assert mat in ["w1", "w2", "w3"]
|
||||||
|
|
||||||
|
world_size = weights.process_group.size()
|
||||||
|
rank = weights.process_group.rank()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
config.intermediate_size % world_size == 0
|
||||||
|
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
|
||||||
|
|
||||||
|
block_size = config.intermediate_size // world_size
|
||||||
|
start = rank * block_size
|
||||||
|
stop = (rank + 1) * block_size
|
||||||
|
|
||||||
|
tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
|
||||||
|
dtype=weights.dtype,
|
||||||
|
device=weights.device)
|
||||||
|
|
||||||
|
for i in range(config.num_local_experts):
|
||||||
|
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
|
||||||
|
|
||||||
|
if mat == "w2":
|
||||||
|
expert_slice = slice_[:, start:stop].t().contiguous()
|
||||||
|
else:
|
||||||
|
expert_slice = slice_[start:stop]
|
||||||
|
tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.max_past = (
|
||||||
|
config.sliding_window if config.sliding_window is not None else 0
|
||||||
|
)
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size ** -0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
|
paged_attention.reshape_and_cache(
|
||||||
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
|
)
|
||||||
|
|
||||||
|
# output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# flash attention
|
||||||
|
flash_attn.attention(
|
||||||
|
query,
|
||||||
|
torch.select(kv, dim=1, index=0),
|
||||||
|
torch.select(kv, dim=1, index=1),
|
||||||
|
attn_output,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
max_s,
|
||||||
|
self.softmax_scale,
|
||||||
|
window_size_left=self.max_past,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
paged_attention.attention(
|
||||||
|
attn_output,
|
||||||
|
query,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def select_experts(gate_logits: torch.Tensor, top_k: int):
|
||||||
|
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||||
|
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||||
|
# weights, selected_experts: (sequence_length, top-k)
|
||||||
|
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
|
||||||
|
weights /= weights.sum(dim=-1, keepdim=True)
|
||||||
|
weights = weights.view(-1)
|
||||||
|
selected_experts = selected_experts.view(-1)
|
||||||
|
|
||||||
|
return selected_experts, weights
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def round_up(x: torch.Tensor, value: int):
|
||||||
|
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
|
||||||
|
|
||||||
|
|
||||||
|
class BlockSparseMoE(nn.Module):
|
||||||
|
"""
|
||||||
|
Built on the paper and library Megablocks as described in
|
||||||
|
https://arxiv.org/abs/2211.15841. This implementation is
|
||||||
|
strictly equivalent to standard MoE with full capacity (no
|
||||||
|
dropped tokens). It's faster since it formulates MoE operations
|
||||||
|
in terms of block-sparse operations to accomodate imbalanced
|
||||||
|
assignments of tokens to experts, whereas standard MoE either
|
||||||
|
(1) drop tokens at the cost of reduced performance or (2) set
|
||||||
|
capacity factor to number of experts and thus waste computation
|
||||||
|
and memory on padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prefix, config: MixtralConfig, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_dim = config.hidden_size
|
||||||
|
self.ffn_dim = config.intermediate_size // weights.process_group.size()
|
||||||
|
self.num_experts = config.num_local_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
|
act = config.hidden_act
|
||||||
|
if "gelu" in act:
|
||||||
|
self.act = lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate="tanh"
|
||||||
|
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none",
|
||||||
|
)
|
||||||
|
elif "silu" in act:
|
||||||
|
self.act = torch.nn.functional.silu
|
||||||
|
else:
|
||||||
|
self.act = ACT2FN[act]
|
||||||
|
|
||||||
|
# gating
|
||||||
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
|
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||||
|
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).t()
|
||||||
|
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||||
|
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).t()
|
||||||
|
|
||||||
|
self.offsets = None
|
||||||
|
self.offsets_block_rows = 0
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
# Calculate the number of bits needed to represent the expert indices
|
||||||
|
# so that we can pass it to radix sort.
|
||||||
|
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
||||||
|
self.blocking = 128
|
||||||
|
self.quantize_scatter_num_bits = -1
|
||||||
|
|
||||||
|
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
|
||||||
|
padded_tokens, _ = x.size()
|
||||||
|
assert padded_tokens % self.blocking == 0
|
||||||
|
assert self.ffn_dim % self.blocking == 0
|
||||||
|
|
||||||
|
# Offsets for the sparse matrix. All rows have the
|
||||||
|
# same number of nonzero blocks dictated by the
|
||||||
|
# dimensionality of a single expert.
|
||||||
|
block_rows = padded_tokens // self.blocking
|
||||||
|
blocks_per_row = self.ffn_dim // self.blocking
|
||||||
|
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||||
|
self.offsets = torch.arange(
|
||||||
|
0,
|
||||||
|
block_rows * blocks_per_row + 1,
|
||||||
|
blocks_per_row,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
self.offsets_block_rows = block_rows
|
||||||
|
offsets = self.offsets
|
||||||
|
else:
|
||||||
|
offsets = self.offsets[:block_rows]
|
||||||
|
|
||||||
|
# Indices for the sparse matrix. The indices for
|
||||||
|
# the intermediate matrix are dynamic depending
|
||||||
|
# on the mapping of tokens to experts.
|
||||||
|
column_indices = ops.topology(padded_bins, self.blocking, block_rows,
|
||||||
|
blocks_per_row)
|
||||||
|
|
||||||
|
# For now, use meta init to save the device memory.
|
||||||
|
data = torch.empty(
|
||||||
|
column_indices.numel(),
|
||||||
|
self.blocking,
|
||||||
|
self.blocking,
|
||||||
|
dtype=x.dtype,
|
||||||
|
device="meta",
|
||||||
|
)
|
||||||
|
shape = (padded_tokens, self.ffn_dim * self.num_experts)
|
||||||
|
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
||||||
|
return stk.Matrix(
|
||||||
|
shape,
|
||||||
|
data,
|
||||||
|
row_indices,
|
||||||
|
column_indices,
|
||||||
|
offsets,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||||
|
# Sort the expert ids to produce the scatter/gather
|
||||||
|
# indices for the permutation.
|
||||||
|
# selected_experts = selected_experts.int()
|
||||||
|
|
||||||
|
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||||
|
# and indices == how to sort tokens?
|
||||||
|
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
|
||||||
|
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
|
||||||
|
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
|
||||||
|
|
||||||
|
# Histogram the expert ids to identify the number of
|
||||||
|
# tokens routed to each expert.
|
||||||
|
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
|
||||||
|
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
|
||||||
|
|
||||||
|
# Round the token counts up to the block size used in
|
||||||
|
# the matrix muliplications. Caculate the starting
|
||||||
|
# position of each bin.
|
||||||
|
|
||||||
|
# List of size num_experts
|
||||||
|
padded_tokens_per_expert = round_up(tokens_per_expert,
|
||||||
|
self.blocking)
|
||||||
|
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||||
|
|
||||||
|
# Cumulative selected experts per token
|
||||||
|
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
||||||
|
padded_bins = promote_scalar(padded_bins)
|
||||||
|
# padded_bins => [128, 128, 256, ...]
|
||||||
|
|
||||||
|
# Calculate the bin bounds for the sorted tokens.
|
||||||
|
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
||||||
|
bins = promote_scalar(bins)
|
||||||
|
# bins => [3, 3, 5, ...]
|
||||||
|
|
||||||
|
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (sequence_length, model_dim)
|
||||||
|
gate_logits: (sequence_length, n_experts)
|
||||||
|
"""
|
||||||
|
# optional reshape
|
||||||
|
input_shape = x.shape
|
||||||
|
x = x.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
# gate_logits: (sequence_length, n_experts)
|
||||||
|
gate_logits = self.gate(x)
|
||||||
|
selected_experts, weights = select_experts(gate_logits, self.top_k)
|
||||||
|
|
||||||
|
(
|
||||||
|
indices,
|
||||||
|
bin_ids,
|
||||||
|
bins,
|
||||||
|
padded_bins,
|
||||||
|
_,
|
||||||
|
) = self.indices_and_padded_bins(selected_experts)
|
||||||
|
|
||||||
|
# Permute tokens and pad to prepare expert computation
|
||||||
|
# (top_k * sequence_length + padding, model_dim)
|
||||||
|
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
|
||||||
|
self.top_k)
|
||||||
|
|
||||||
|
# Create the sparse matrix topology
|
||||||
|
with torch.no_grad():
|
||||||
|
topo = self.topology(x, padded_bins)
|
||||||
|
|
||||||
|
# Perform the expert computation
|
||||||
|
# First Dense x Dense -> Sparse for w1 and w3,
|
||||||
|
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||||
|
x = stk.Matrix(
|
||||||
|
topo.size(),
|
||||||
|
self.act(stk.ops.sdd(x, self.w1, topo).data) *
|
||||||
|
stk.ops.sdd(x, self.w3, topo).data,
|
||||||
|
topo.row_indices,
|
||||||
|
topo.column_indices,
|
||||||
|
topo.offsets,
|
||||||
|
topo.column_indices_t,
|
||||||
|
topo.offsets_t,
|
||||||
|
topo.block_offsets_t,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then Sparse x Dense -> Dense for w2
|
||||||
|
# (top_k * sequence_length + padding, model_dim)
|
||||||
|
x = stk.ops.dsd(x, self.w2)
|
||||||
|
|
||||||
|
# Permute back and remove padding
|
||||||
|
# (sequence_length, model_dim)
|
||||||
|
x = ops.padded_scatter(
|
||||||
|
x,
|
||||||
|
indices,
|
||||||
|
bin_ids,
|
||||||
|
weights,
|
||||||
|
bins,
|
||||||
|
padded_bins,
|
||||||
|
self.top_k,
|
||||||
|
self.quantize_scatter_num_bits,
|
||||||
|
).view(*input_shape)
|
||||||
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(x, group=self.process_group)
|
||||||
|
|
||||||
|
return x.view(*input_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralLayer(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"model.layers.{layer_id}"
|
||||||
|
|
||||||
|
self.self_attn = MixtralAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
|
||||||
|
|
||||||
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = FastRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
|
attn_output, res
|
||||||
|
)
|
||||||
|
|
||||||
|
block_sparse_moe_output = self.block_sparse_moe(normed_attn_res_output)
|
||||||
|
|
||||||
|
return block_sparse_moe_output, attn_res
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix="model.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MixtralLayer(
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = FastRMSNorm.load(
|
||||||
|
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMixtralForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = MixtralModel(config, weights)
|
||||||
|
self.lm_head = TensorParallelHead.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.max_past = config.sliding_window
|
||||||
|
if self.max_past is None:
|
||||||
|
raise ValueError("max_past cannot be None")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(self.max_past, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=self.max_past)
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
return logits
|
|
@ -6,7 +6,6 @@ from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.flash_attn import attention
|
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
|
|
@ -8,14 +8,13 @@ from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.models.llama import LlamaTokenizerFast
|
from transformers.models.llama import LlamaTokenizerFast
|
||||||
from typing import Optional, Tuple, Type
|
from typing import Optional, Tuple, Type, List
|
||||||
|
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.models import FlashCausalLM
|
from text_generation_server.models import FlashCausalLM
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||||
from text_generation_server.models.cache_manager import (
|
from text_generation_server.models.cache_manager import (
|
||||||
get_cache_manager,
|
get_cache_manager,
|
||||||
set_cache_manager,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
FlashMistralForCausalLM,
|
FlashMistralForCausalLM,
|
||||||
|
@ -46,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.Batch,
|
pb: generate_pb2.Batch,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "FlashCausalLMBatch":
|
) -> "FlashCausalLMBatch":
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
global SLIDING_WINDOW_BLOCKS
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
@ -100,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, (r, tokenized_input) in enumerate(
|
for i, (r, tokenized_input) in enumerate(
|
||||||
zip(pb.requests, batch_tokenized_inputs)
|
zip(pb.requests, batch_tokenized_inputs)
|
||||||
):
|
):
|
||||||
# request id -> idx in list mapping
|
# request id -> idx in list mapping
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
|
|
||||||
tokenized_input = tokenized_input[-r.truncate :]
|
tokenized_input = tokenized_input[-r.truncate:]
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
@ -278,14 +277,16 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashMistral(FlashCausalLM):
|
class BaseFlashMistral(FlashCausalLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
config_cls,
|
||||||
revision: Optional[str] = None,
|
model_cls,
|
||||||
quantize: Optional[str] = None,
|
model_id: str,
|
||||||
dtype: Optional[torch.dtype] = None,
|
revision: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
global SLIDING_WINDOW
|
global SLIDING_WINDOW
|
||||||
global SLIDING_WINDOW_BLOCKS
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
@ -305,7 +306,7 @@ class FlashMistral(FlashCausalLM):
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = MistralConfig.from_pretrained(
|
config = config_cls.from_pretrained(
|
||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
@ -321,10 +322,10 @@ class FlashMistral(FlashCausalLM):
|
||||||
if config.quantize in ["gptq", "awq"]:
|
if config.quantize in ["gptq", "awq"]:
|
||||||
weights._set_gptq_params(model_id)
|
weights._set_gptq_params(model_id)
|
||||||
|
|
||||||
model = FlashMistralForCausalLM(config, weights)
|
model = model_cls(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashMistral, self).__init__(
|
super(BaseFlashMistral, self).__init__(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_layers=len(model.model.layers),
|
num_layers=len(model.model.layers),
|
||||||
|
@ -396,3 +397,23 @@ class FlashMistral(FlashCausalLM):
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMistral(BaseFlashMistral):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
super(FlashMistral, self).__init__(
|
||||||
|
config_cls=MistralConfig,
|
||||||
|
model_cls=FlashMistralForCausalLM,
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from text_generation_server.models.flash_mistral import BaseFlashMistral
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import MixtralConfig, FlashMixtralForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMixtral(BaseFlashMistral):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
super(FlashMixtral, self).__init__(
|
||||||
|
config_cls=MixtralConfig,
|
||||||
|
model_cls=FlashMixtralForCausalLM,
|
||||||
|
model_id=model_id,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code
|
||||||
|
)
|
|
@ -18,7 +18,7 @@ except ImportError:
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||||
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||||
|
|
||||||
HAS_AWQ = True
|
HAS_AWQ = True
|
||||||
try:
|
try:
|
||||||
|
@ -43,16 +43,18 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||||
elif CAN_EXLLAMA:
|
elif CAN_EXLLAMA:
|
||||||
try:
|
try:
|
||||||
if V2:
|
if V2:
|
||||||
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
|
from text_generation_server.utils.gptq.exllamav2 import (QuantLinear as ExllamaQuantLinear,
|
||||||
create_exllama_buffers,
|
create_exllama_buffers,
|
||||||
set_device,
|
set_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
HAS_EXLLAMA = "2"
|
HAS_EXLLAMA = "2"
|
||||||
else:
|
else:
|
||||||
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
|
from text_generation_server.utils.gptq.exllama import (Ex4bitLinear as ExllamaQuantLinear,
|
||||||
create_exllama_buffers,
|
create_exllama_buffers,
|
||||||
set_device,
|
set_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
HAS_EXLLAMA = "1"
|
HAS_EXLLAMA = "1"
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -112,7 +114,7 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_conv2d_no_bias(
|
def load_conv2d_no_bias(
|
||||||
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
|
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
|
||||||
):
|
):
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
|
@ -136,9 +138,9 @@ torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||||
|
|
||||||
class FastLinear(nn.Module):
|
class FastLinear(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(weight)
|
self.weight = nn.Parameter(weight)
|
||||||
|
@ -162,9 +164,9 @@ class FastLinear(nn.Module):
|
||||||
|
|
||||||
class EETQLinear(nn.Module):
|
class EETQLinear(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
device = weight.device
|
device = weight.device
|
||||||
|
@ -183,13 +185,13 @@ class EETQLinear(nn.Module):
|
||||||
|
|
||||||
class Linear8bitLt(nn.Module):
|
class Linear8bitLt(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight,
|
weight,
|
||||||
bias,
|
bias,
|
||||||
has_fp16_weights=True,
|
has_fp16_weights=True,
|
||||||
memory_efficient_backward=False,
|
memory_efficient_backward=False,
|
||||||
threshold=0.0,
|
threshold=0.0,
|
||||||
index=None,
|
index=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert (
|
assert (
|
||||||
|
@ -526,9 +528,12 @@ class TensorParallelEmbedding(nn.Module):
|
||||||
try:
|
try:
|
||||||
if IS_CUDA_SYSTEM:
|
if IS_CUDA_SYSTEM:
|
||||||
import dropout_layer_norm
|
import dropout_layer_norm
|
||||||
|
elif IS_ROCM_SYSTEM:
|
||||||
|
from vllm import layernorm_ops
|
||||||
else:
|
else:
|
||||||
dropout_layer_norm = None
|
dropout_layer_norm = None
|
||||||
|
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||||
|
@ -563,10 +568,81 @@ try:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
return normed_hidden_states, residual
|
return normed_hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class FastRMSNorm(nn.Module):
|
||||||
|
def __init__(self, weight: torch.Tensor, eps: float):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(weight)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, prefix, weights, eps=1e-6):
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
return cls(weight, eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, residual=None):
|
||||||
|
if hidden_states.shape[-1] > 8192:
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(
|
||||||
|
variance + self.variance_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert into half-precision if necessary
|
||||||
|
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
|
||||||
|
return self.weight * hidden_states, residual
|
||||||
|
elif IS_CUDA_SYSTEM:
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
self.weight,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
self.variance_epsilon,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
True, # Activate RMSNorm
|
||||||
|
)
|
||||||
|
if res is None:
|
||||||
|
res = hidden_states
|
||||||
|
|
||||||
|
return normed_hidden_states, res
|
||||||
|
elif IS_ROCM_SYSTEM:
|
||||||
|
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
out = torch.empty_like(hidden_states)
|
||||||
|
layernorm_ops.rms_norm(
|
||||||
|
out,
|
||||||
|
hidden_states,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return out, residual
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if IS_CUDA_SYSTEM:
|
if IS_CUDA_SYSTEM:
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding
|
from flash_attn.layers.rotary import RotaryEmbedding
|
||||||
|
@ -574,12 +650,14 @@ try:
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
from vllm import pos_encoding_ops
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
|
||||||
def _create_inv_freq(dim, base, device):
|
def _create_inv_freq(dim, base, device):
|
||||||
inv_freq = 1.0 / (
|
inv_freq = 1.0 / (
|
||||||
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||||
)
|
)
|
||||||
return inv_freq
|
return inv_freq
|
||||||
|
|
||||||
|
|
||||||
def _get_rope_config(config):
|
def _get_rope_config(config):
|
||||||
if os.getenv("ROPE_SCALING", None) is not None:
|
if os.getenv("ROPE_SCALING", None) is not None:
|
||||||
rope_scaling = {
|
rope_scaling = {
|
||||||
|
@ -589,6 +667,7 @@ try:
|
||||||
return rope_scaling
|
return rope_scaling
|
||||||
return getattr(config, "rope_scaling", None)
|
return getattr(config, "rope_scaling", None)
|
||||||
|
|
||||||
|
|
||||||
class PositionRotaryEmbedding(nn.Module):
|
class PositionRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, inv_freq, scaling_factor):
|
def __init__(self, inv_freq, scaling_factor):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -606,12 +685,12 @@ try:
|
||||||
if IS_CUDA_SYSTEM:
|
if IS_CUDA_SYSTEM:
|
||||||
rotary_dim = cos.shape[-1]
|
rotary_dim = cos.shape[-1]
|
||||||
q1 = query[..., :rotary_dim]
|
q1 = query[..., :rotary_dim]
|
||||||
q2 = query[..., rotary_dim : 2 * rotary_dim]
|
q2 = query[..., rotary_dim: 2 * rotary_dim]
|
||||||
|
|
||||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||||
|
|
||||||
k1 = key[..., :rotary_dim]
|
k1 = key[..., :rotary_dim]
|
||||||
k2 = key[..., rotary_dim : 2 * rotary_dim]
|
k2 = key[..., rotary_dim: 2 * rotary_dim]
|
||||||
|
|
||||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||||
elif IS_ROCM_SYSTEM:
|
elif IS_ROCM_SYSTEM:
|
||||||
|
@ -630,7 +709,8 @@ try:
|
||||||
True
|
True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
raise ValueError(
|
||||||
|
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def static(cls, config, dim, base, device):
|
def static(cls, config, dim, base, device):
|
||||||
|
@ -713,9 +793,9 @@ try:
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
@ -729,7 +809,7 @@ try:
|
||||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
|
|
||||||
def get_cos_sin(
|
def get_cos_sin(
|
||||||
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Return cos and sin for the asked position ids
|
Return cos and sin for the asked position ids
|
||||||
|
@ -747,6 +827,7 @@ try:
|
||||||
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
|
||||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
|
@ -755,18 +836,18 @@ try:
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.base = base
|
self.base = base
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
if seqlen > self.max_position_embeddings:
|
if seqlen > self.max_position_embeddings:
|
||||||
newbase = self.base * (
|
newbase = self.base * (
|
||||||
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
(self.scaling_factor * seqlen / self.max_position_embeddings)
|
||||||
- (self.scaling_factor - 1)
|
- (self.scaling_factor - 1)
|
||||||
) ** (self.dim / (self.dim - 2))
|
) ** (self.dim / (self.dim - 2))
|
||||||
self.inv_freq = _create_inv_freq(
|
self.inv_freq = _create_inv_freq(
|
||||||
self.dim, newbase, self.inv_freq.device
|
self.dim, newbase, self.inv_freq.device
|
||||||
|
@ -783,8 +864,11 @@ try:
|
||||||
|
|
||||||
# Inverse dim formula to find dim based on number of rotations
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
||||||
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
|
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
||||||
|
|
||||||
|
|
||||||
# Find dim range bounds based on rotations
|
# Find dim range bounds based on rotations
|
||||||
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
||||||
|
@ -792,7 +876,8 @@ try:
|
||||||
low_rot, dim, base, max_position_embeddings))
|
low_rot, dim, base, max_position_embeddings))
|
||||||
high = math.ceil(find_correction_dim(
|
high = math.ceil(find_correction_dim(
|
||||||
high_rot, dim, base, max_position_embeddings))
|
high_rot, dim, base, max_position_embeddings))
|
||||||
return max(low, 0), min(high, dim-1) # Clamp values just in case
|
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||||
|
|
||||||
|
|
||||||
def linear_ramp_mask(min, max, dim):
|
def linear_ramp_mask(min, max, dim):
|
||||||
if min == max:
|
if min == max:
|
||||||
|
@ -802,13 +887,16 @@ try:
|
||||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
return ramp_func
|
return ramp_func
|
||||||
|
|
||||||
|
|
||||||
def get_mscale(scale=1):
|
def get_mscale(scale=1):
|
||||||
if scale <= 1:
|
if scale <= 1:
|
||||||
return 1.0
|
return 1.0
|
||||||
return 0.1 * math.log(scale) + 1.0
|
return 0.1 * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor,*, extrapolation_factor, attn_factor, beta_fast, beta_slow):
|
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor, *, extrapolation_factor,
|
||||||
|
attn_factor, beta_fast, beta_slow):
|
||||||
inv_freq = _create_inv_freq(dim, base, device)
|
inv_freq = _create_inv_freq(dim, base, device)
|
||||||
super().__init__(inv_freq, scaling_factor)
|
super().__init__(inv_freq, scaling_factor)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
@ -818,15 +906,16 @@ try:
|
||||||
self.attn_factor = attn_factor
|
self.attn_factor = attn_factor
|
||||||
self.beta_fast = beta_fast
|
self.beta_fast = beta_fast
|
||||||
self.beta_slow = beta_slow
|
self.beta_slow = beta_slow
|
||||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
self.mscale = float(get_mscale(
|
||||||
|
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||||
# Reset the tables if the sequence length has changed,
|
# Reset the tables if the sequence length has changed,
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
# or if we're on a new device (possibly due to tracing for instance)
|
||||||
if (
|
if (
|
||||||
seqlen > self._seq_len_cached
|
seqlen > self._seq_len_cached
|
||||||
or self._cos_cached.device != device
|
or self._cos_cached.device != device
|
||||||
or self._cos_cached.dtype != dtype
|
or self._cos_cached.dtype != dtype
|
||||||
):
|
):
|
||||||
if seqlen > self.max_position_embeddings:
|
if seqlen > self.max_position_embeddings:
|
||||||
inv_freq_extrapolation = _create_inv_freq(
|
inv_freq_extrapolation = _create_inv_freq(
|
||||||
|
@ -834,13 +923,15 @@ try:
|
||||||
)
|
)
|
||||||
freqs = 1.0 / inv_freq_extrapolation
|
freqs = 1.0 / inv_freq_extrapolation
|
||||||
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)
|
||||||
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.max_position_embeddings)
|
low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base,
|
||||||
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
self.max_position_embeddings)
|
||||||
|
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(
|
||||||
|
device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
|
||||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||||
|
|
||||||
self.inv_freq = inv_freq
|
self.inv_freq = inv_freq
|
||||||
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
self.mscale = float(get_mscale(
|
||||||
|
self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation
|
||||||
|
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
|
Loading…
Reference in New Issue