feat: starcoder2 (#1605)

OlivierDehaene 2024-02-28 12:07:08 +01:00 committed by GitHub
parent 97e22369f4
commit b40e833493
9 changed files with 1601 additions and 16 deletions

View File

@@ -0,0 +1,94 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40844727,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27905273,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6118164,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68652344,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4619141,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.7993164,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.63134766,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23278809,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2294922,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}

View File

@@ -0,0 +1,394 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": 0,
"tokens": [
{
"id": 2284,
"logprob": -0.296875,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.28125,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -0.79248047,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.61816406,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.0619812,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -0.4091797,
"special": false,
"text": "def"
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7670,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 444,
"logprob": -0.21655273,
"special": false,
"text": "name"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 731,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 332,
"logprob": -0.034698486,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 655,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 494,
"logprob": -0.20141602,
"special": false,
"text": " +"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 16013,
"logprob": 0.0,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7670,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 100,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 400,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 444,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 49,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 11505,
"logprob": 0.0,
"special": false,
"text": " age"
},
{
"id": 731,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 303,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 655,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 3021,
"logprob": -0.5761719,
"special": false,
"text": " \","
},
{
"id": 863,
"logprob": 0.0,
"special": false,
"text": " you"
},
{
"id": 904,
"logprob": 0.0,
"special": false,
"text": " are"
},
{
"id": 332,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 494,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 615,
"logprob": 0.0,
"special": false,
"text": " str"
},
{
"id": 45,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 400,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 46,
"logprob": 0.0,
"special": false,
"text": ")"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
}

View File

@@ -0,0 +1,378 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 610,
"logprob": null,
"text": "def"
},
{
"id": 1489,
"logprob": -5.2617188,
"text": " print"
},
{
"id": 100,
"logprob": -0.38476562,
"text": "_"
},
{
"id": 7670,
"logprob": -7.640625,
"text": "hello"
}
],
"seed": null,
"tokens": [
{
"id": 2284,
"logprob": -0.92626953,
"special": false,
"text": "():"
},
{
"id": 303,
"logprob": -0.40722656,
"special": false,
"text": "\n "
},
{
"id": 1489,
"logprob": -0.27954102,
"special": false,
"text": " print"
},
{
"id": 459,
"logprob": -0.6142578,
"special": false,
"text": "(\""
},
{
"id": 8302,
"logprob": -0.68310547,
"special": false,
"text": "Hello"
},
{
"id": 10914,
"logprob": -1.4570312,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.80126953,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.6303711,
"special": false,
"text": "\n"
},
{
"id": 222,
"logprob": -0.23327637,
"special": false,
"text": "\n"
},
{
"id": 610,
"logprob": -1.2304688,
"special": false,
"text": "def"
}
],
"top_tokens": null
},
"generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
]

View File

@@ -0,0 +1,55 @@
import pytest
@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
await flash_starcoder2_handle.health(300)
return flash_starcoder2_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
response = await flash_starcoder2.generate(
"def print_hello",
max_new_tokens=60,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 60
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(
flash_starcoder2, generate_load, response_snapshot
):
responses = await generate_load(
flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
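For reference, the greedy test above corresponds to an ordinary request against a running server. A minimal sketch, assuming a server launched locally for bigcode/starcoder2-3b and the text_generation Python client shipped in this repository; the endpoint URL and port are placeholders:

import asyncio

from text_generation import AsyncClient


async def main():
    # Hypothetical local endpoint; point this at wherever the launcher is serving.
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "def print_hello",
        max_new_tokens=10,
        decoder_input_details=True,
    )
    # Mirrors the snapshot above: 10 generated tokens ending in a trailing "def".
    print(response.details.generated_tokens, repr(response.generated_text))


asyncio.run(main())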

View File

@@ -230,7 +230,6 @@ message WarmupRequest {
uint32 max_total_tokens = 4;
}
/// Empty response
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;

View File

@@ -64,6 +64,7 @@ try:
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
@@ -80,6 +81,7 @@ if FLASH_ATTENTION:
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashPhi)
__all__.append(FlashStarcoder2)
MAMBA_AVAILABLE = True
try:
@@ -184,6 +186,16 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_id.startswith("facebook/galactica"):
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if (
model_type == "gpt_bigcode"
or model_type == "gpt2"
@@ -401,6 +413,18 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "starcoder2":
sliding_window = config_dict.get("sliding_window", -1)
if (
(sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
) or HAS_FLASH_ATTN_V2_CUDA:
return FlashStarcoder2(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_type == "opt":
return OPTSharded(

View File

@@ -0,0 +1,545 @@
# coding=utf-8
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
SpeculativeHead,
get_linear,
FastRMSNorm,
FastLayerNorm,
)
class Starcoder2Config(PretrainedConfig):
model_type = "starcoder2"
def __init__(
self,
vocab_size=49152,
hidden_size=3072,
intermediate_size=12288,
num_hidden_layers=30,
num_attention_heads=24,
num_key_value_heads=2,
mlp_type="default",
hidden_act="gelu_pytorch_tanh",
max_position_embeddings=4096,
initializer_range=0.018042,
norm_type="layer_norm",
norm_epsilon=1e-5,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
rope_theta=10000.0,
sliding_window=None,
attention_dropout=0.0,
residual_dropout=0.0,
embedding_dropout=0.0,
use_bias: bool = True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.use_bias = use_bias
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.mlp_type = mlp_type
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.norm_type = norm_type
self.norm_epsilon = norm_epsilon
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.residual_dropout = residual_dropout
self.embedding_dropout = embedding_dropout
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=config.use_bias,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
if config.use_bias:
w = [
weights.get_sharded(f"{p}.bias", dim=0)
for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
]
bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
else:
bias = None
return TensorParallelColumnLinear(
get_linear(weight, bias=bias, quantize=config.quantize)
)
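The shape assertion in _load_gqa can be sanity-checked against the Starcoder2Config defaults above. A standalone sketch of the fused-QKV arithmetic, assuming a single shard (process_group.size() == 1):

import torch

# Defaults from Starcoder2Config above, single shard.
hidden_size = 3072
num_heads = 24
num_key_value_heads = 2
head_size = hidden_size // num_heads  # 128

# q_proj, k_proj and v_proj are concatenated along dim 0: one block of `num_heads`
# query heads plus 2 * `num_key_value_heads` key/value heads.
fused_rows = (num_heads + 2 * num_key_value_heads) * head_size  # (24 + 4) * 128 = 3584
weight = torch.empty(fused_rows, hidden_size)
assert list(weight.shape) == [(num_heads + 2 * num_key_value_heads) * head_size, hidden_size]

# At forward time the fused output splits back into query and kv, exactly as in
# Starcoder2Attention.forward below.
qkv = torch.empty(1, fused_rows)
query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
print(query.shape, kv.shape)  # torch.Size([1, 3072]) torch.Size([1, 512])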
class Starcoder2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=config.use_bias,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None:
kv_to_cache = kv[prefill_cache_indices]
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Standard (non-gated) MLP: c_fc up-projection followed by c_proj down-projection
self.c_fc = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.c_fc",
weights=weights,
bias=config.use_bias,
)
self.c_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=config.use_bias,
)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=config.use_bias,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=config.use_bias,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
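For illustration, a minimal standalone sketch of the gated-MLP split used by Starcoder2GatedMLP (the variant selected when config.mlp_type == "gated"); plain nn.Linear stands in for the tensor-parallel layers and the sizes are toy values:

import torch
from torch import nn

# Toy sizes; the config defaults above are hidden_size=3072, intermediate_size=12288.
hidden_size, intermediate_size = 8, 32

# gate_proj and up_proj are fused into a single column-parallel matmul in the real layer.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=True)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)


def act(x):
    # gelu_pytorch_tanh, as selected by the hidden_act handling above
    return nn.functional.gelu(x, approximate="tanh")


x = torch.randn(4, hidden_size)
gate_up = gate_up_proj(x).view(-1, 2, intermediate_size)
# Element-wise gating: activation(gate) * up, then project back down.
out = down_proj(act(gate_up[:, 0]) * gate_up[:, 1])
print(out.shape)  # torch.Size([4, 8])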
STARCODER2_NORMALIZATION_CLASSES = {
"layer_norm": FastLayerNorm,
"rms_norm": FastRMSNorm,
}
STARCODER2_MLP_CLASSES = {
"default": Starcoder2MLP,
"gated": Starcoder2GatedMLP,
}
class Starcoder2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
)
self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
config.norm_type
].load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.norm_epsilon,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
# faster post attention norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Starcoder2Layer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid indexing it in every layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
prefill_cache_indices,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = Starcoder2Model(config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
weights=weights,
)
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window
self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device)
if self.max_past is not None
else None
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as they have the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
true_max_s,
prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits
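The prefill_cache_indices and input_lengths handling in this forward is easier to see in isolation. A small sketch of the sliding-window arithmetic, assuming a window of 4 tokens and a single 10-token prefill (mirroring the per-request torch.arange built in the FlashMistralBatch changes later in this diff):

import torch

sliding_window = 4  # assumed; starcoder2 configs may set this much larger or not at all
input_length = 10
cumulative_length = 0  # single sequence in the batch

# Only the last `sliding_window` positions of the prefill are written to the KV cache.
prefill_cache_indices = torch.arange(
    cumulative_length + max(0, input_length - sliding_window),
    cumulative_length + input_length,
    dtype=torch.int64,
)
print(prefill_cache_indices)  # tensor([6, 7, 8, 9])

# In decode mode there are no prefill indices; input_lengths is clamped instead so that
# paged attention never looks further back than the window.
input_lengths = torch.tensor([25, 3])
print(torch.clamp(input_lengths, max=sliding_window))  # tensor([4, 3])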

View File

@@ -8,7 +8,7 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type, List
from typing import Optional, Tuple, Type
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
@@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
MEM_POOL = torch.cuda.graph_pool_handle()
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
SLIDING_WINDOW = sliding_window
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
def get_sliding_windows() -> Tuple[int, int]:
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
@@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
sliding_window, sliding_window_blocks = get_sliding_windows()
batch_inputs = []
max_truncation = 0
@@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
if SLIDING_WINDOW_BLOCKS is not None:
needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
if sliding_window_blocks is not None:
needed_blocks = min(needed_blocks, sliding_window_blocks)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if SLIDING_WINDOW is not None:
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - SLIDING_WINDOW),
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
@@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if SLIDING_WINDOW is not None:
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if SLIDING_WINDOW is not None:
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
@@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
@@ -298,9 +310,6 @@ class BaseFlashMistral(FlashCausalLM):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
@@ -324,8 +333,9 @@ class BaseFlashMistral(FlashCausalLM):
# Set context windows
if config.sliding_window is not None:
SLIDING_WINDOW = config.sliding_window
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group)
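A minimal standalone sketch of the setter/getter pattern introduced above: the model constructor publishes the window once with set_sliding_window, and batch construction reads it back with get_sliding_windows instead of re-declaring the globals. The BLOCK_SIZE and window values below are assumptions for illustration (the server's BLOCK_SIZE constant lives in cache_manager):

import math
from typing import Optional, Tuple

BLOCK_SIZE = 16  # assumed
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None


def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks


def get_sliding_windows() -> Tuple[Optional[int], Optional[int]]:
    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS


# Model __init__ side (BaseFlashMistral / FlashStarcoder2):
config_sliding_window = 4096  # hypothetical config value
set_sliding_window(config_sliding_window, math.ceil(config_sliding_window / BLOCK_SIZE))

# Batch-building side (FlashMistralBatch above):
sliding_window, sliding_window_blocks = get_sliding_windows()
print(sliding_window, sliding_window_blocks)  # 4096 256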

View File

@@ -0,0 +1,86 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashStarcoder2 is only available on GPU")
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = Starcoder2Config.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows
if config.sliding_window is not None:
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashStarcoder2ForCausalLM(config, weights)
self.cuda_graphs = {}
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
)