feat: add support for Gemma (#1583)
This commit is contained in:
parent
fa8a8e05af
commit
c86f58d37c
|
@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
exclude=None,
|
||||
matcher=None,
|
||||
):
|
||||
if isinstance(data, Response):
|
||||
data = data.dict()
|
||||
|
||||
if isinstance(data, List):
|
||||
data = [d.dict() for d in data]
|
||||
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.09375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.8671875,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 651,
|
||||
"logprob": -2.4375,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8203125,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.23242188,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.08544922,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.9375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1671,
|
||||
"logprob": -1.671875,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.40429688,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -1.1875,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 7539,
|
||||
"logprob": -0.73046875,
|
||||
"special": false,
|
||||
"text": " forms"
|
||||
},
|
||||
{
|
||||
"id": 708,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " are"
|
||||
},
|
||||
{
|
||||
"id": 671,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": " an"
|
||||
},
|
||||
{
|
||||
"id": 8727,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " essential"
|
||||
},
|
||||
{
|
||||
"id": 1702,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " part"
|
||||
},
|
||||
{
|
||||
"id": 576,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 573,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 11859,
|
||||
"logprob": -1.6953125,
|
||||
"special": false,
|
||||
"text": " lab"
|
||||
},
|
||||
{
|
||||
"id": 2185,
|
||||
"logprob": -1.3125,
|
||||
"special": false,
|
||||
"text": " process"
|
||||
},
|
||||
{
|
||||
"id": 578,
|
||||
"logprob": -1.5,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request forms are an essential part of the lab process and"
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.09375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.9140625,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 651,
|
||||
"logprob": -2.453125,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8984375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.23535156,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.091308594,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.96875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1671,
|
||||
"logprob": -1.6484375,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.43164062,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.09375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.9140625,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 651,
|
||||
"logprob": -2.453125,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8984375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.23535156,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.091308594,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.96875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1671,
|
||||
"logprob": -1.6484375,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.43164062,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.09375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.9140625,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 651,
|
||||
"logprob": -2.453125,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8984375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.23535156,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.091308594,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.96875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1671,
|
||||
"logprob": -1.6484375,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.43164062,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": null,
|
||||
"text": "<bos>"
|
||||
},
|
||||
{
|
||||
"id": 2015,
|
||||
"logprob": -10.0,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -10.875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -2.09375,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 109,
|
||||
"logprob": -1.9140625,
|
||||
"special": false,
|
||||
"text": "\n\n"
|
||||
},
|
||||
{
|
||||
"id": 651,
|
||||
"logprob": -2.453125,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 2121,
|
||||
"logprob": -1.8984375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -0.23535156,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 1736,
|
||||
"logprob": -0.091308594,
|
||||
"special": false,
|
||||
"text": " form"
|
||||
},
|
||||
{
|
||||
"id": 603,
|
||||
"logprob": -0.96875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 1671,
|
||||
"logprob": -1.6484375,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 577,
|
||||
"logprob": -0.43164062,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3853,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " form\n\nThe test request form is used to request"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,61 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_gemma_handle(launcher):
|
||||
with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_gemma(flash_gemma_handle):
|
||||
await flash_gemma_handle.health(300)
|
||||
return flash_gemma_handle.client
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma(flash_gemma, response_snapshot):
|
||||
response = await flash_gemma.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
|
||||
response = await flash_gemma.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
|
||||
responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -52,6 +52,9 @@ try:
|
|||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma import (
|
||||
FlashGemma,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
|
@ -312,6 +315,28 @@ def get_model(
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if model_type == "gemma":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_medusa=use_medusa,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
|
||||
if sharded:
|
||||
|
|
|
@ -0,0 +1,609 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import os
|
||||
from shutil import copyfile
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
from tokenizers import processors
|
||||
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from transformers.utils import logging
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
)
|
||||
|
||||
GemmaTokenizer = None
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "tokenizer.model",
|
||||
"tokenizer_file": "tokenizer.json",
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
|
||||
},
|
||||
"tokenizer_file": {
|
||||
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
|
||||
},
|
||||
}
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||
|
||||
# fmt: off
|
||||
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
|
||||
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
||||
that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
||||
correct. If you don't know the answer to a question, please don't share false information."""
|
||||
# fmt: on
|
||||
|
||||
|
||||
class GemmaTokenizerFast(PreTrainedTokenizerFast):
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
slow_tokenizer_class = GemmaTokenizer
|
||||
padding_side = "left"
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
clean_up_tokenization_spaces=False,
|
||||
unk_token="<unk>",
|
||||
bos_token="<bos>",
|
||||
eos_token="<eos>",
|
||||
pad_token="<pad>",
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
use_default_system_prompt=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
add_bos_token=add_bos_token,
|
||||
add_eos_token=add_eos_token,
|
||||
use_default_system_prompt=use_default_system_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
self._add_bos_token = add_bos_token
|
||||
self._add_eos_token = add_eos_token
|
||||
self.update_post_processor()
|
||||
self.use_default_system_prompt = use_default_system_prompt
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
@property
|
||||
def can_save_slow_tokenizer(self) -> bool:
|
||||
return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
||||
|
||||
def update_post_processor(self):
|
||||
"""
|
||||
Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
||||
"""
|
||||
bos = self.bos_token
|
||||
bos_token_id = self.bos_token_id
|
||||
if bos is None and self.add_bos_token:
|
||||
raise ValueError("add_bos_token = True but bos_token = None")
|
||||
|
||||
eos = self.eos_token
|
||||
eos_token_id = self.eos_token_id
|
||||
if eos is None and self.add_eos_token:
|
||||
raise ValueError("add_eos_token = True but eos_token = None")
|
||||
|
||||
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||
|
||||
special_tokens = []
|
||||
if self.add_bos_token:
|
||||
special_tokens.append((bos, bos_token_id))
|
||||
if self.add_eos_token:
|
||||
special_tokens.append((eos, eos_token_id))
|
||||
self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||
single=single, pair=pair, special_tokens=special_tokens
|
||||
)
|
||||
|
||||
@property
|
||||
def add_eos_token(self):
|
||||
return self._add_eos_token
|
||||
|
||||
@property
|
||||
def add_bos_token(self):
|
||||
return self._add_bos_token
|
||||
|
||||
@add_eos_token.setter
|
||||
def add_eos_token(self, value):
|
||||
self._add_eos_token = value
|
||||
self.update_post_processor()
|
||||
|
||||
@add_bos_token.setter
|
||||
def add_bos_token(self, value):
|
||||
self._add_bos_token = value
|
||||
self.update_post_processor()
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "")
|
||||
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@property
|
||||
def default_chat_template(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = bos_token_id + token_ids_0 + eos_token_id
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + bos_token_id + token_ids_1 + eos_token_id
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256128,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class GemmaFastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix, weights, eps=1e-6):
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
return cls(weight, eps)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.head_dim
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class FlashGemmaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = config.head_dim
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
|
||||
|
||||
class FlashGemmaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashGemmaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = GemmaFastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = GemmaFastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class FlashGemmaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
embed_norm = config.hidden_size**0.5
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_tokens.weight *= embed_norm
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGemmaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = GemmaFastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashGemmaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
|
@ -0,0 +1,104 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
GemmaTokenizerFast,
|
||||
FlashGemmaForCausalLM,
|
||||
GemmaConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashGemma(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_medusa: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGemma is only available on GPU")
|
||||
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=True,
|
||||
from_slow=False,
|
||||
)
|
||||
|
||||
config = GemmaConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashGemmaForCausalLM(config, weights)
|
||||
if use_medusa:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
Loading…
Reference in New Issue