From c86f58d37cfff019fea878d8f2bf9b4da26c1d8e Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 21 Feb 2024 14:15:22 +0100 Subject: [PATCH] feat: add support for Gemma (#1583) --- integration-tests/conftest.py | 3 + .../test_flash_gemma/test_flash_gemma.json | 89 +++ .../test_flash_gemma_all_params.json | 89 +++ .../test_flash_gemma_load.json | 358 ++++++++++ integration-tests/models/test_flash_gemma.py | 61 ++ .../text_generation_server/models/__init__.py | 25 + .../custom_modeling/flash_gemma_modeling.py | 609 ++++++++++++++++++ .../models/flash_gemma.py | 104 +++ 8 files changed, 1338 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json create mode 100644 integration-tests/models/test_flash_gemma.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py create mode 100644 server/text_generation_server/models/flash_gemma.py diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e0228894..80457bc2 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension): exclude=None, matcher=None, ): + if isinstance(data, Response): + data = data.dict() + if isinstance(data, List): data = [d.dict() for d in data] diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json new file mode 100644 index 00000000..80f0d053 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.8671875, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.4375, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8203125, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23242188, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.08544922, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.9375, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.671875, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.40429688, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.1875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json new file mode 100644 index 00000000..8253dc96 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 7539, + "logprob": -0.73046875, + "special": false, + "text": " forms" + }, + { + "id": 708, + "logprob": 0.0, + "special": false, + "text": " are" + }, + { + "id": 671, + "logprob": -1.703125, + "special": false, + "text": " an" + }, + { + "id": 8727, + "logprob": 0.0, + "special": false, + "text": " essential" + }, + { + "id": 1702, + "logprob": 0.0, + "special": false, + "text": " part" + }, + { + "id": 576, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 573, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 11859, + "logprob": -1.6953125, + "special": false, + "text": " lab" + }, + { + "id": 2185, + "logprob": -1.3125, + "special": false, + "text": " process" + }, + { + "id": 578, + "logprob": -1.5, + "special": false, + "text": " and" + } + ], + "top_tokens": null + }, + "generated_text": "Test request forms are an essential part of the lab process and" +} diff --git a/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json new file mode 100644 index 00000000..e69ee25d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2, + "logprob": null, + "text": "" + }, + { + "id": 2015, + "logprob": -10.0, + "text": "Test" + }, + { + "id": 3853, + "logprob": -10.875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 1736, + "logprob": -2.09375, + "special": false, + "text": " form" + }, + { + "id": 109, + "logprob": -1.9140625, + "special": false, + "text": "\n\n" + }, + { + "id": 651, + "logprob": -2.453125, + "special": false, + "text": "The" + }, + { + "id": 2121, + "logprob": -1.8984375, + "special": false, + "text": " test" + }, + { + "id": 3853, + "logprob": -0.23535156, + "special": false, + "text": " request" + }, + { + "id": 1736, + "logprob": -0.091308594, + "special": false, + "text": " form" + }, + { + "id": 603, + "logprob": -0.96875, + "special": false, + "text": " is" + }, + { + "id": 1671, + "logprob": -1.6484375, + "special": false, + "text": " used" + }, + { + "id": 577, + "logprob": -0.43164062, + "special": false, + "text": " to" + }, + { + "id": 3853, + "logprob": -1.2421875, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": " form\n\nThe test request form is used to request" + } +] diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py new file mode 100644 index 00000000..2822b5e2 --- /dev/null +++ b/integration-tests/models/test_flash_gemma.py @@ -0,0 +1,61 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_gemma_handle(launcher): + with launcher("gg-hf/gemma-2b", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_gemma(flash_gemma_handle): + await flash_gemma_handle.health(300) + return flash_gemma_handle.client + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_all_params(flash_gemma, response_snapshot): + response = await flash_gemma.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): + responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index da7d8416..abab3486 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -52,6 +52,9 @@ try: from text_generation_server.models.flash_llama import ( FlashLlama, ) + from text_generation_server.models.flash_gemma import ( + FlashGemma, + ) from text_generation_server.models.flash_santacoder import ( FlashSantacoderSharded, ) @@ -312,6 +315,28 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + if model_type == "gemma": + if FLASH_ATTENTION: + return FlashGemma( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + use_medusa=use_medusa, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate") + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]: if sharded: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py new file mode 100644 index 00000000..4a08bc2a --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -0,0 +1,609 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +import os +from shutil import copyfile + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple +from tokenizers import processors +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import logging + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, + FastRMSNorm, +) + +GemmaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = { + "vocab_file": "tokenizer.model", + "tokenizer_file": "tokenizer.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class GemmaTokenizerFast(PreTrainedTokenizerFast): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + slow_tokenizer_class = GemmaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + use_default_system_prompt=use_default_system_prompt, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.use_default_system_prompt = use_default_system_prompt + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError("add_bos_token = True but bos_token = None") + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError("add_eos_token = True but eos_token = None") + + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary( + self, save_directory: str, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @property + def default_chat_template(self): + raise NotImplementedError + + # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + +class GemmaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=256128, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.head_dim = head_dim + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class GemmaFastRMSNorm(FastRMSNorm): + @classmethod + def load(cls, prefix, weights, eps=1e-6): + weight = weights.get_tensor(f"{prefix}.weight") + 1 + return cls(weight, eps) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.head_dim + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class FlashGemmaAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.head_size = config.head_dim + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class GemmaMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashGemmaLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashGemmaAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaFastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashGemmaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + embed_norm = config.hidden_size**0.5 + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.embed_tokens.weight *= embed_norm + + self.layers = nn.ModuleList( + [ + FlashGemmaLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = GemmaFastRMSNorm.load( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashGemmaForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashGemmaModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py new file mode 100644 index 00000000..220b3992 --- /dev/null +++ b/server/text_generation_server/models/flash_gemma.py @@ -0,0 +1,104 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from typing import Optional + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + GemmaTokenizerFast, + FlashGemmaForCausalLM, + GemmaConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashGemma(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + use_medusa: Optional[str] = None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError("FlashGemma is only available on GPU") + + tokenizer = GemmaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + use_fast=True, + from_slow=False, + ) + + config = GemmaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = FlashGemmaForCausalLM(config, weights) + if use_medusa: + from text_generation_server.utils.medusa import MedusaModel + from huggingface_hub import hf_hub_download + import json + import os + from pathlib import Path + + is_local_model = ( + Path(use_medusa).exists() and Path(use_medusa).is_dir() + ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None + + if not is_local_model: + medusa_config = hf_hub_download( + use_medusa, revision=revision, filename="config.json" + ) + medusa_head = hf_hub_download( + use_medusa, revision=revision, filename="medusa_lm_head.pt" + ) + else: + medusa_config = str(Path(use_medusa) / "config.json") + medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") + + with open(medusa_config, "r") as f: + config = json.load(f) + medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" + weights = Weights( + [medusa_sf], device, dtype, process_group=self.process_group + ) + lm_head = model.lm_head + model.lm_head = MedusaModel(config, weights, lm_head) + + torch.distributed.barrier(group=self.process_group) + super(FlashGemma, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + )