From b40e833493808ed80b0bd6d8a68252fff01d307a Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 28 Feb 2024 12:07:08 +0100 Subject: [PATCH] feat: starcoder2 (#1605) --- .../test_flash_starcoder2.json | 94 +++ .../test_flash_starcoder2_default_params.json | 394 +++++++++++++ .../test_flash_starcoder2_load.json | 378 ++++++++++++ .../models/test_flash_starcoder2.py | 55 ++ proto/generate.proto | 1 - .../text_generation_server/models/__init__.py | 24 + .../flash_starcoder2_modeling.py | 545 ++++++++++++++++++ .../models/flash_mistral.py | 40 +- .../models/flash_starcoder2.py | 86 +++ 9 files changed, 1601 insertions(+), 16 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json create mode 100644 integration-tests/models/test_flash_starcoder2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py create mode 100644 server/text_generation_server/models/flash_starcoder2.py diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json new file mode 100644 index 00000000..36a2ff4d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json @@ -0,0 +1,94 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40844727, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27905273, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6118164, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68652344, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4619141, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.7993164, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.63134766, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23278809, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2294922, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json new file mode 100644 index 00000000..38117272 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -0,0 +1,394 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 60, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2284, + "logprob": -0.296875, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.28125, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -0.79248047, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.61816406, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.0619812, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -0.4091797, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": -0.21655273, + "special": false, + "text": "name" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": -0.034698486, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": -0.20141602, + "special": false, + "text": " +" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 16013, + "logprob": 0.0, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 7670, + "logprob": 0.0, + "special": false, + "text": "hello" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 100, + "logprob": 0.0, + "special": false, + "text": "_" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 444, + "logprob": 0.0, + "special": false, + "text": "name" + }, + { + "id": 49, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 11505, + "logprob": 0.0, + "special": false, + "text": " age" + }, + { + "id": 731, + "logprob": 0.0, + "special": false, + "text": "):" + }, + { + "id": 303, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": 0.0, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": 0.0, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": 0.0, + "special": false, + "text": "Hello" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 655, + "logprob": 0.0, + "special": false, + "text": " name" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 3021, + "logprob": -0.5761719, + "special": false, + "text": " \"," + }, + { + "id": 863, + "logprob": 0.0, + "special": false, + "text": " you" + }, + { + "id": 904, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 332, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 494, + "logprob": 0.0, + "special": false, + "text": " +" + }, + { + "id": 615, + "logprob": 0.0, + "special": false, + "text": " str" + }, + { + "id": 45, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 400, + "logprob": 0.0, + "special": false, + "text": "age" + }, + { + "id": 46, + "logprob": 0.0, + "special": false, + "text": ")" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json new file mode 100644 index 00000000..9e82d4be --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json @@ -0,0 +1,378 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 610, + "logprob": null, + "text": "def" + }, + { + "id": 1489, + "logprob": -5.2617188, + "text": " print" + }, + { + "id": 100, + "logprob": -0.38476562, + "text": "_" + }, + { + "id": 7670, + "logprob": -7.640625, + "text": "hello" + } + ], + "seed": null, + "tokens": [ + { + "id": 2284, + "logprob": -0.92626953, + "special": false, + "text": "():" + }, + { + "id": 303, + "logprob": -0.40722656, + "special": false, + "text": "\n " + }, + { + "id": 1489, + "logprob": -0.27954102, + "special": false, + "text": " print" + }, + { + "id": 459, + "logprob": -0.6142578, + "special": false, + "text": "(\"" + }, + { + "id": 8302, + "logprob": -0.68310547, + "special": false, + "text": "Hello" + }, + { + "id": 10914, + "logprob": -1.4570312, + "special": false, + "text": " World" + }, + { + "id": 16013, + "logprob": -0.80126953, + "special": false, + "text": "!\")" + }, + { + "id": 222, + "logprob": -0.6303711, + "special": false, + "text": "\n" + }, + { + "id": 222, + "logprob": -0.23327637, + "special": false, + "text": "\n" + }, + { + "id": 610, + "logprob": -1.2304688, + "special": false, + "text": "def" + } + ], + "top_tokens": null + }, + "generated_text": "():\n print(\"Hello World!\")\n\ndef" + } +] diff --git a/integration-tests/models/test_flash_starcoder2.py b/integration-tests/models/test_flash_starcoder2.py new file mode 100644 index 00000000..ea665b6c --- /dev/null +++ b/integration-tests/models/test_flash_starcoder2.py @@ -0,0 +1,55 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder2_handle(launcher): + with launcher("bigcode/starcoder2-3b", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder2(flash_starcoder2_handle): + await flash_starcoder2_handle.health(300) + return flash_starcoder2_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot): + response = await flash_starcoder2.generate( + "def print_hello", + max_new_tokens=60, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 60 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder2_load( + flash_starcoder2, generate_load, response_snapshot +): + responses = await generate_load( + flash_starcoder2, "def print_hello", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/proto/generate.proto b/proto/generate.proto index 0490029f..6351e37f 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -230,7 +230,6 @@ message WarmupRequest { uint32 max_total_tokens = 4; } -/// Empty response message WarmupResponse { /// Maximum number of tokens supported by the model optional uint32 max_supported_total_tokens = 1; diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3208275c..e2edbfa9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -64,6 +64,7 @@ try: from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.flash_mixtral import FlashMixtral from text_generation_server.models.flash_phi import FlashPhi + from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA except ImportError as e: @@ -80,6 +81,7 @@ if FLASH_ATTENTION: __all__.append(FlashMistral) __all__.append(FlashMixtral) __all__.append(FlashPhi) + __all__.append(FlashStarcoder2) MAMBA_AVAILABLE = True try: @@ -184,6 +186,16 @@ def get_model( trust_remote_code=trust_remote_code, ) + if model_id.startswith("facebook/galactica"): + return GalacticaSharded( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if ( model_type == "gpt_bigcode" or model_type == "gpt2" @@ -401,6 +413,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + if model_type == "starcoder2": + sliding_window = config_dict.get("sliding_window", -1) + if ( + (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION + ) or HAS_FLASH_ATTN_V2_CUDA: + return FlashStarcoder2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "opt": return OPTSharded( diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py new file mode 100644 index 00000000..ed77af78 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -0,0 +1,545 @@ +# coding=utf-8 +# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + SpeculativeHead, + get_linear, + FastRMSNorm, + FastLayerNorm, +) + + +class Starcoder2Config(PretrainedConfig): + model_type = "starcoder2" + + def __init__( + self, + vocab_size=49152, + hidden_size=3072, + intermediate_size=12288, + num_hidden_layers=30, + num_attention_heads=24, + num_key_value_heads=2, + mlp_type="default", + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=4096, + initializer_range=0.018042, + norm_type="layer_norm", + norm_epsilon=1e-5, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + rope_theta=10000.0, + sliding_window=None, + attention_dropout=0.0, + residual_dropout=0.0, + embedding_dropout=0.0, + use_bias: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.use_bias = use_bias + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_type = mlp_type + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_type = norm_type + self.norm_epsilon = norm_epsilon + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.residual_dropout = residual_dropout + self.embedding_dropout = embedding_dropout + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=config.use_bias, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + if config.use_bias: + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + else: + bias = None + + return TensorParallelColumnLinear( + get_linear(weight, bias=bias, quantize=config.quantize) + ) + + +class Starcoder2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else -1 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=config.use_bias, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class Starcoder2MLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.c_fc = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.c_fc", + weights=weights, + bias=config.use_bias, + ) + self.c_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=config.use_bias, + ) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + return self.c_proj(hidden_states) + + +class Starcoder2GatedMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=config.use_bias, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=config.use_bias, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +STARCODER2_NORMALIZATION_CLASSES = { + "layer_norm": FastLayerNorm, + "rms_norm": FastRMSNorm, +} + +STARCODER2_MLP_CLASSES = { + "default": Starcoder2MLP, + "gated": Starcoder2GatedMLP, +} + + +class Starcoder2Layer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = Starcoder2Attention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + + self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) + + self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon + ) + self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[ + config.norm_type + ].load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.norm_epsilon, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class Starcoder2Model(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + Starcoder2Layer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( + prefix="model.norm", weights=weights, eps=config.norm_epsilon + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + true_max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, true_max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashStarcoder2ForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = Starcoder2Model(config, weights) + try: + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + except RuntimeError: + self.lm_head = SpeculativeHead.load( + config, + prefix="model.embed_tokens", + weights=weights, + ) + + self.max_past = config.sliding_window + self.max_past_tensor = ( + torch.tensor(config.sliding_window, device=weights.device) + if self.max_past is not None + else None + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index d3c0da9c..fd5c18e0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from opentelemetry import trace from transformers import PreTrainedTokenizerBase from transformers.models.llama import LlamaTokenizerFast -from typing import Optional, Tuple, Type, List +from typing import Optional, Tuple, Type from text_generation_server.pb import generate_pb2 from text_generation_server.models import FlashCausalLM @@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None MEM_POOL = torch.cuda.graph_pool_handle() +def set_sliding_window(sliding_window: int, sliding_window_blocks: int): + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + SLIDING_WINDOW = sliding_window + SLIDING_WINDOW_BLOCKS = sliding_window_blocks + + +def get_sliding_windows() -> Tuple[int, int]: + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS + + # Adds windowing logic to FlashCausalLMBatch @dataclass class FlashMistralBatch(FlashCausalLMBatch): @@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS + sliding_window, sliding_window_blocks = get_sliding_windows() batch_inputs = [] max_truncation = 0 @@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch): # Needed blocks can not go over SLIDING_WINDOW_BLOCKS needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) - if SLIDING_WINDOW_BLOCKS is not None: - needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS) + if sliding_window_blocks is not None: + needed_blocks = min(needed_blocks, sliding_window_blocks) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) @@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch): slot_indices.append(request_slot_indices) # Create tensor to slice into the kv tensor in prefill - if SLIDING_WINDOW is not None: + if sliding_window is not None: request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, input_length - SLIDING_WINDOW), + cumulative_length + max(0, input_length - sliding_window), cumulative_length + input_length, dtype=torch.int64, ) @@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch): input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) - if SLIDING_WINDOW is not None: + if sliding_window is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: input_ids = all_input_ids[0] position_ids = position_ids[0] slot_indices = slot_indices[0] - if SLIDING_WINDOW is not None: + if sliding_window is not None: prefill_cache_indices = prefill_cache_indices[0] cu_seqlen_prefill = torch.tensor( @@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch): position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) prefill_cache_indices = ( - prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None + prefill_cache_indices.to(device) if sliding_window is not None else None ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_lengths_tensor = torch.tensor( @@ -298,9 +310,6 @@ class BaseFlashMistral(FlashCausalLM): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - global SLIDING_WINDOW - global SLIDING_WINDOW_BLOCKS - self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") @@ -324,8 +333,9 @@ class BaseFlashMistral(FlashCausalLM): # Set context windows if config.sliding_window is not None: - SLIDING_WINDOW = config.sliding_window - SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) + set_sliding_window( + config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) + ) torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py new file mode 100644 index 00000000..2f6ae757 --- /dev/null +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -0,0 +1,86 @@ +import math + +import torch + +from typing import Optional + +from transformers.models.gpt2 import GPT2TokenizerFast + +from text_generation_server.models.cache_manager import BLOCK_SIZE +from text_generation_server.models.flash_mistral import ( + BaseFlashMistral, + set_sliding_window, +) +from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + Starcoder2Config, + FlashStarcoder2ForCausalLM, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + + +# Starcoder2 has the same base as Mistral +class FlashStarcoder2(BaseFlashMistral): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + tokenizer = GPT2TokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = Starcoder2Config.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.use_medusa = use_medusa + + # Set context windows + if config.sliding_window is not None: + set_sliding_window( + config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) + ) + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = FlashStarcoder2ForCausalLM(config, weights) + + self.cuda_graphs = {} + + torch.distributed.barrier(group=self.process_group) + super(BaseFlashMistral, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + sliding_window=config.sliding_window, + )