Add support for Deepseek V2 (#2224)

Deepseek V2 is a MoE model from Deepseek. Relevant variations
compared to other models:

- Grouped top-K in expert selection.
- mscale in yarn is calculated using the `mscale` and `mscale_all_dim`
  configuration options.
- `mscale_all_dim` is also used in scaling attention softmax.
- Permuting of the query/key representations before applying rotary
  embeddings.
- Some projections cannot be sharded (`q_a_proj`, `kv_a_proj_with_mqa`).
  So, we need weight loads that supports quantized weights. To this
  end `{Weights,WeightLoader}.get_weight` was added.
- The query/key head dimensionality differs from that of the value,
  so we need to pad during attention.
- Heads with size 192, needs an extension to our paged attention
  fork and we need to ensure that the KV cache is allocated with the
  correct size.
- Shared experts.
This commit is contained in:
Daniël de Kok 2024-07-19 17:23:20 +02:00 committed by GitHub
parent 68a9685f1b
commit e52be9bba2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1836 additions and 51 deletions

View File

@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware
## Supported Models ## Supported Models
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.84375,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.34375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.8359375,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.0859375,
"special": false,
"text": " is"
},
{
"id": 254,
"logprob": -1.5390625,
"special": false,
"text": " the"
},
{
"id": 1022,
"logprob": -1.1875,
"special": false,
"text": " first"
},
{
"id": 3458,
"logprob": -0.35546875,
"special": false,
"text": " step"
},
{
"id": 279,
"logprob": -0.8828125,
"special": false,
"text": " in"
},
{
"id": 254,
"logprob": -0.71484375,
"special": false,
"text": " the"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is the first step in the"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 2143,
"logprob": -1.828125,
"special": false,
"text": " sent"
},
{
"id": 10081,
"logprob": -0.36914062,
"special": false,
"text": " successfully"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 185,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 1380,
"logprob": -0.38671875,
"special": false,
"text": "We"
},
{
"id": 543,
"logprob": -0.12695312,
"special": false,
"text": " will"
},
{
"id": 752,
"logprob": -0.20117188,
"special": false,
"text": " get"
},
{
"id": 279,
"logprob": 0.0,
"special": false,
"text": " in"
},
{
"id": 5402,
"logprob": 0.0,
"special": false,
"text": " touch"
},
{
"id": 366,
"logprob": 0.0,
"special": false,
"text": " with"
}
],
"top_tokens": null
},
"generated_text": "Test request sent successfully.\nWe will get in touch with"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.1875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 100000,
"logprob": null,
"text": "<begin▁of▁sentence>"
},
{
"id": 3533,
"logprob": -9.625,
"text": "Test"
},
{
"id": 3102,
"logprob": -11.25,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 185,
"logprob": -1.5546875,
"special": false,
"text": "\n"
},
{
"id": 549,
"logprob": -2.8125,
"special": false,
"text": "The"
},
{
"id": 1727,
"logprob": -2.375,
"special": false,
"text": " test"
},
{
"id": 3102,
"logprob": -0.890625,
"special": false,
"text": " request"
},
{
"id": 317,
"logprob": -1.1484375,
"special": false,
"text": " is"
},
{
"id": 245,
"logprob": -1.5390625,
"special": false,
"text": " a"
},
{
"id": 3102,
"logprob": -2.609375,
"special": false,
"text": " request"
},
{
"id": 327,
"logprob": -0.75,
"special": false,
"text": " for"
},
{
"id": 245,
"logprob": -1.1171875,
"special": false,
"text": " a"
},
{
"id": 1727,
"logprob": -0.90625,
"special": false,
"text": " test"
}
],
"top_tokens": null
},
"generated_text": "\nThe test request is a request for a test"
}
]

View File

@ -0,0 +1,63 @@
import pytest
@pytest.fixture(scope="module")
def flash_deepseek_v2_handle(launcher):
with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_deepseek_v2(flash_deepseek_v2_handle):
await flash_deepseek_v2_handle.health(300)
return flash_deepseek_v2_handle.client
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot):
response = await flash_deepseek_v2.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot):
response = await flash_deepseek_v2.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_deepseek_v2_load(
flash_deepseek_v2, generate_load, response_snapshot
):
responses = await generate_load(
flash_deepseek_v2, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -1,14 +1,14 @@
commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
build-vllm-cuda: build-vllm-cuda:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \ pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/Narsil/vllm.git vllm; \ git clone https://github.com/Narsil/vllm.git vllm; \
fi fi
cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build
install-vllm-cuda: build-vllm-cuda install-vllm-cuda: build-vllm-cuda
cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .
build-vllm-rocm: build-vllm-rocm:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \

View File

@ -34,6 +34,30 @@ class Exl2Weight(Weight):
class Exl2WeightsLoader(WeightsLoader): class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights.""" """Loader for exl2-quantized weights."""
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_weights_col_packed( def get_weights_col_packed(
self, self,
weights: Weights, weights: Weights,
@ -43,46 +67,12 @@ class Exl2WeightsLoader(WeightsLoader):
raise RuntimeError("Column-packed weights are not supported for exl") raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str): def get_weights_col(self, weights: Weights, prefix: str):
try: # Sharding is not yet supported, so we return the weights as-is.
q_weight = weights.get_tensor(f"{prefix}.q_weight") return self.get_weights(weights, prefix)
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2") raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str): def get_weights_row(self, weights: Weights, prefix: str):
try: # Sharding is not yet supported, so we return the weights as-is.
q_weight = weights.get_tensor(f"{prefix}.q_weight") return self.get_weights(weights, prefix)
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)

View File

@ -134,6 +134,115 @@ class GPTQWeightsLoader(WeightsLoader):
self.quantize = quantize self.quantize = quantize
self.sym = sym self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_col_packed( def get_weights_col_packed(
self, self,
weights: Weights, weights: Weights,

View File

@ -33,6 +33,35 @@ class MarlinWeightsLoader(WeightsLoader):
self.bits = bits self.bits = bits
self.is_marlin_24 = is_marlin_24 self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed( def get_weights_col_packed(
self, self,
weights: Weights, weights: Weights,

View File

@ -1,6 +1,7 @@
import os import os
import torch import torch
from torch import nn from torch import nn
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
@ -97,6 +98,8 @@ class PositionRotaryEmbedding(nn.Module):
) )
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[ max_position_embeddings=rope_scaling[
@ -109,6 +112,8 @@ class PositionRotaryEmbedding(nn.Module):
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
) )
elif rope_scaling["type"] in ["su", "longrope"]: elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor( short_factor = torch.tensor(
@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
) )
elif rope_scaling["type"] == "yarn": elif rope_scaling["type"] == "yarn":
mscale = rope_scaling.get("mscale", 1.0)
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
return YarnPositionRotaryEmbedding( return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0], dim=2 * inv_freq.shape[0],
max_position_embeddings=rope_scaling[ max_position_embeddings=rope_scaling[
@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module):
attn_factor=1, attn_factor=1,
beta_fast=32, beta_fast=32,
beta_slow=1, beta_slow=1,
mscale=mscale,
mscale_all_dim=mscale_all_dim,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim):
return ramp_func return ramp_func
def get_mscale(scale=1): def get_mscale(scale: float = 1.0, mscale: float = 1.0):
if scale <= 1: if scale <= 1:
return 1.0 return 1.0
return 0.1 * math.log(scale) + 1.0 return 0.1 * mscale * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
attn_factor, attn_factor,
beta_fast, beta_fast,
beta_slow, beta_slow,
mscale: float,
mscale_all_dim: float,
): ):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor) super().__init__(inv_freq, scaling_factor)
@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.mscale_all_dim = mscale_all_dim
self.scaling_factor = scaling_factor
self.mscale = float( self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor get_mscale(self.scaling_factor, mscale)
/ get_mscale(self.scaling_factor, mscale_all_dim)
* self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation ) # Get n-d magnitude scaling corrected for interpolation
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
or self._cos_cached.device != device or self._cos_cached.device != device
or self._cos_cached.dtype != dtype or self._cos_cached.dtype != dtype
): ):
if seqlen > self.max_position_embeddings: if seqlen > self.max_position_embeddings or True:
inv_freq_extrapolation = _create_inv_freq( inv_freq_extrapolation = _create_inv_freq(
self.dim, self.base, self.inv_freq.device self.dim, self.base, self.inv_freq.device
) )
@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
self.base, self.base,
self.max_position_embeddings, self.max_position_embeddings,
) )
inv_freq_mask = ( inv_freq_mask = (
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
) )
self.inv_freq = inv_freq self.inv_freq = inv_freq
self.mscale = float(
get_mscale(self.scaling_factor) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
self._seq_len_cached = seqlen self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)

View File

@ -61,6 +61,10 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
) )
@ -141,6 +145,11 @@ if MAMBA_AVAILABLE:
class ModelType(enum.Enum): class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
IDEFICS2 = { IDEFICS2 = {
"type": "idefics2", "type": "idefics2",
"name": "Idefics 2", "name": "Idefics 2",
@ -459,7 +468,40 @@ def get_model(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
) )
if model_type == MAMBA: if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == MAMBA:
return Mamba( return Mamba(
model_id, model_id,
revision, revision,

View File

@ -0,0 +1,983 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
class DeepseekV2Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _load_experts(config, prefix: str, mat: str, weights: Weights):
if config.quantize is not None:
raise NotImplementedError(
"Deepseek V2 does not support weight quantization yet."
)
assert mat in ["gate_proj", "up_proj", "down_proj"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.moe_intermediate_size % world_size == 0
), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.moe_intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.n_routed_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.n_routed_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "down_proj":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
class DeepseekV2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
quantize=config.quantize,
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
quantize=config.quantize,
)
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
class DeepseekV2MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config: DeepseekV2Config, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.n_routed_experts = config.n_routed_experts
self.n_expert_group = config.n_group
self.topk_group = config.topk_group
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.routed_scaling_factor = config.routed_scaling_factor
gate_proj = _load_experts(
config, f"{prefix}.experts", "gate_proj", weights
).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
)
self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
self.down_proj = (
_load_experts(config, f"{prefix}.experts", "down_proj", weights)
.view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
self.top_k,
renormalize=self.norm_topk_prob,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
out = (
fused_experts(
x,
self.gate_up_proj,
self.down_proj,
topk_weights,
topk_ids,
inplace=True,
)
* self.routed_scaling_factor
)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DenseMoE(nn.Module):
def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
self.n_routed_experts = config.n_routed_experts
self.n_expert_group = config.n_group
self.topk_group = config.topk_group
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
#
# Seems like no one quantizes the gate.
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.experts = [
DeepseekV2MLP(
f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
)
for i in range(self.n_routed_experts)
]
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
# gate_logits: (sequence_length, n_experts)
router_logits = self.gate(x)
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
self.top_k,
renormalize=self.norm_topk_prob,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def moe_infer_gpu(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
):
weights = torch.zeros(
topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
)
weights.scatter_(1, topk_ids, topk_weight)
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i, expert in enumerate(self.experts):
# Add expert output to out with masking
out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
return out
class DeepseekV2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
else:
self.mlp = DeepseekV2MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV2Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV2Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV2ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV2Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
# Functions below are from vLLM:
#
# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
#
# Remove after we have synced our version with upstream.
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
) -> Dict[str, int]:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_moe_configs,
invoke_fused_moe_kernel,
moe_align_block_size,
)
M, _ = hidden_states.shape
E, N, _ = w1.shape
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)

View File

@ -839,7 +839,9 @@ class FlashCausalLM(Model):
default_dtype=torch.float16, default_dtype=torch.float16,
aliases=None, aliases=None,
# Used for Santacoder override of config # Used for Santacoder override of config
num_kv_heads=None, num_kv_heads: Optional[int] = None,
# Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
): ):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
@ -922,7 +924,11 @@ class FlashCausalLM(Model):
else num_kv_heads else num_kv_heads
) )
assert self.num_kv_heads > 0 assert self.num_kv_heads > 0
if head_size is None:
self.head_size = config.hidden_size // config.num_attention_heads self.head_size = config.hidden_size // config.num_attention_heads
else:
self.head_size = head_size
self.cuda_graphs = {} self.cuda_graphs = {}
self.kv_cache = [] self.kv_cache = []

View File

@ -21,6 +21,13 @@ class WeightsLoader(ABC):
with the format, etc. with the format, etc.
""" """
@abstractmethod
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
...
@abstractmethod @abstractmethod
def get_weights_col_packed( def get_weights_col_packed(
self, self,
@ -104,6 +111,9 @@ class DefaultWeightsLoader(WeightsLoader):
and/or concatenation. and/or concatenation.
""" """
def get_weights(self, weights: "Weights", prefix: str):
return weights.get_tensor(f"{prefix}.weight")
def get_weights_col_packed( def get_weights_col_packed(
self, self,
weights: "Weights", weights: "Weights",
@ -299,6 +309,9 @@ class Weights:
return tensor return tensor
def get_weights(self, prefix: str):
return self.weights_loader.get_weights(self, prefix)
def get_weights_col_packed_qkv( def get_weights_col_packed_qkv(
self, self,
prefix: str, prefix: str,