feat: adds phi model (#1442)
This PR adds basic modeling for phi-2 run ```bash text-generation-server \ serve \ microsoft/phi-2 \ --revision 834565c23f9b28b96ccbeabe614dd906b6db551a ``` test ```bash curl -s localhost:3000/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' | jq . # { # "generated_text": "\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from data. These" # } ``` notes - recently (~1 day ago) the Phi weights and model were updated to accommodate adding [GQA/MQA attention to the model.](https://github.com/huggingface/transformers/pull/28163) This impl expects the original model format so a fixed revision is required at the moment. - this PR only includes a basic implementation of the model and can later be extended for support Flash and Sharded versions as well as make use of better optimization
This commit is contained in:
parent
86c8335f1b
commit
7e2a7433d3
|
@ -0,0 +1,84 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1391,
|
||||
"logprob": -0.98779297,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
},
|
||||
{
|
||||
"id": 25927,
|
||||
"logprob": -0.76660156,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -0.7246094,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 4943,
|
||||
"logprob": -0.41333008,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.11785889,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50280,
|
||||
"logprob": -0.97265625,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 26209,
|
||||
"logprob": -1.4414062,
|
||||
"special": false,
|
||||
"text": "response"
|
||||
},
|
||||
{
|
||||
"id": 796,
|
||||
"logprob": -0.0569458,
|
||||
"special": false,
|
||||
"text": " ="
|
||||
},
|
||||
{
|
||||
"id": 2116,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": " self"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": {request}\")\n response = self"
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 6,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.19421387,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3758,
|
||||
"logprob": -0.62597656,
|
||||
"special": false,
|
||||
"text": " send"
|
||||
},
|
||||
{
|
||||
"id": 1366,
|
||||
"logprob": -0.87060547,
|
||||
"special": false,
|
||||
"text": " data"
|
||||
},
|
||||
{
|
||||
"id": 625,
|
||||
"logprob": -0.88427734,
|
||||
"special": false,
|
||||
"text": " over"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -1.0830078,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 3127,
|
||||
"logprob": -1.9462891,
|
||||
"special": false,
|
||||
"text": " network"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request to send data over a network"
|
||||
}
|
|
@ -0,0 +1,338 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1391,
|
||||
"logprob": -0.98779297,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
},
|
||||
{
|
||||
"id": 25927,
|
||||
"logprob": -0.7729492,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -0.7241211,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 4943,
|
||||
"logprob": -0.4091797,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.119018555,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50280,
|
||||
"logprob": -0.9707031,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 26209,
|
||||
"logprob": -1.4414062,
|
||||
"special": false,
|
||||
"text": "response"
|
||||
},
|
||||
{
|
||||
"id": 796,
|
||||
"logprob": -0.056854248,
|
||||
"special": false,
|
||||
"text": " ="
|
||||
},
|
||||
{
|
||||
"id": 2116,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": " self"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": {request}\")\n response = self"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1391,
|
||||
"logprob": -0.98779297,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
},
|
||||
{
|
||||
"id": 25927,
|
||||
"logprob": -0.7729492,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -0.7241211,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 4943,
|
||||
"logprob": -0.4091797,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.119018555,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50280,
|
||||
"logprob": -0.9707031,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 26209,
|
||||
"logprob": -1.4414062,
|
||||
"special": false,
|
||||
"text": "response"
|
||||
},
|
||||
{
|
||||
"id": 796,
|
||||
"logprob": -0.056854248,
|
||||
"special": false,
|
||||
"text": " ="
|
||||
},
|
||||
{
|
||||
"id": 2116,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": " self"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": {request}\")\n response = self"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1391,
|
||||
"logprob": -0.98779297,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
},
|
||||
{
|
||||
"id": 25927,
|
||||
"logprob": -0.7729492,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -0.7241211,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 4943,
|
||||
"logprob": -0.4091797,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.119018555,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50280,
|
||||
"logprob": -0.9707031,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 26209,
|
||||
"logprob": -1.4414062,
|
||||
"special": false,
|
||||
"text": "response"
|
||||
},
|
||||
{
|
||||
"id": 796,
|
||||
"logprob": -0.056854248,
|
||||
"special": false,
|
||||
"text": " ="
|
||||
},
|
||||
{
|
||||
"id": 2116,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": " self"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": {request}\")\n response = self"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 14402,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2581,
|
||||
"logprob": -11.6171875,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1391,
|
||||
"logprob": -0.98779297,
|
||||
"special": false,
|
||||
"text": " {"
|
||||
},
|
||||
{
|
||||
"id": 25927,
|
||||
"logprob": -0.7729492,
|
||||
"special": false,
|
||||
"text": "request"
|
||||
},
|
||||
{
|
||||
"id": 92,
|
||||
"logprob": -0.7241211,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 4943,
|
||||
"logprob": -0.4091797,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.119018555,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50280,
|
||||
"logprob": -0.9707031,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 26209,
|
||||
"logprob": -1.4414062,
|
||||
"special": false,
|
||||
"text": "response"
|
||||
},
|
||||
{
|
||||
"id": 796,
|
||||
"logprob": -0.056854248,
|
||||
"special": false,
|
||||
"text": " ="
|
||||
},
|
||||
{
|
||||
"id": 2116,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": " self"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": ": {request}\")\n response = self"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,65 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_phi_handle(launcher):
|
||||
with launcher("microsoft/phi-2", num_shard=1) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_phi(flash_phi_handle):
|
||||
await flash_phi_handle.health(300)
|
||||
return flash_phi_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_phi(flash_phi, response_snapshot):
|
||||
response = await flash_phi.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == ": {request}\")\n response = self"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||
response = await flash_phi.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["network"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 6
|
||||
assert response.generated_text == "Test request to send data over a network"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_phi, "Test request", max_new_tokens=10, n=4
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all(
|
||||
[r.generated_text == responses[0].generated_text for r in responses]
|
||||
), f"{[r.generated_text for r in responses]}"
|
||||
assert responses[0].generated_text == ": {request}\")\n response = self"
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
|
|||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "^0.15.0"
|
||||
huggingface-hub = "^0.19.3"
|
||||
transformers = "^4.36.1"
|
||||
transformers = "^4.37.1"
|
||||
einops = "^0.6.1"
|
||||
texttable = { version = "^1.6.7", optional = true }
|
||||
datasets = { version = "^2.14.0", optional = true }
|
||||
|
|
|
@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded
|
|||
from text_generation_server.models.santacoder import SantaCoder
|
||||
from text_generation_server.models.t5 import T5Sharded
|
||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
from text_generation_server.models.phi import Phi
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
|
@ -57,6 +58,7 @@ try:
|
|||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
|
||||
|
||||
except ImportError as e:
|
||||
|
@ -72,6 +74,7 @@ if FLASH_ATTENTION:
|
|||
__all__.append(IDEFICSSharded)
|
||||
__all__.append(FlashMistral)
|
||||
__all__.append(FlashMixtral)
|
||||
__all__.append(FlashPhi)
|
||||
|
||||
|
||||
def get_model(
|
||||
|
@ -227,6 +230,37 @@ def get_model(
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "phi":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashPhi(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_medusa=use_medusa,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "phi-msft":
|
||||
if FLASH_ATTENTION:
|
||||
raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention")
|
||||
else:
|
||||
return Phi(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "llama" or model_type == "baichuan":
|
||||
if FLASH_ATTENTION:
|
||||
|
|
|
@ -0,0 +1,400 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=51200,
|
||||
hidden_size=2560,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
hidden_act="gelu_fast", # llama uses silu
|
||||
layer_norm_eps=1e-05, # rms in llama,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
resid_pdrop=0.1, # llama doesn't have this
|
||||
partial_rotary_factor=0.5, # important difference between llama and phi
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.rope_theta = rope_theta
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
# this is the same as llama except for Phi uses bias=True
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=True, quantize=config.quantize)
|
||||
)
|
||||
|
||||
class FlashPhiAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.rotary_dim,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
# in llama the dense layer is called "o_proj" and has bias=False
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.dense",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
# Compute query, key, value and split
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Reshape query and key for rotary embeddings
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
# NOTE: this is the main difference between Llama and Phi
|
||||
# in llama the rotary embeddings are applied to the whole query and key.
|
||||
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
|
||||
#
|
||||
# Apply partial positional embeddings in place
|
||||
self.rotary_emb(
|
||||
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
|
||||
cos, sin
|
||||
)
|
||||
|
||||
# Reshape key and value and cache
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
|
||||
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none",
|
||||
)
|
||||
)
|
||||
|
||||
# llama weights are up_proj and down_proj and bias=False
|
||||
self.up_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# NOTE: Llama requires the gate up states to an intermediate size
|
||||
# Phi does not and we can avoid the `view` operation
|
||||
return self.down_proj(self.act(self.up_proj(hidden_states)))
|
||||
|
||||
|
||||
class FlashPhiLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashPhiAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
|
||||
|
||||
return hidden_states, res
|
||||
|
||||
class FlashPhiModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashPhiLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
self.norm = FastLayerNorm.load(
|
||||
prefix="model.final_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
class FlashPhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashPhiModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
||||
return self.lm_head(hidden_states)
|
|
@ -0,0 +1,308 @@
|
|||
# imlementation of the PhiModel and PhiForCausalLM classes
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import math
|
||||
from torch import nn
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
FastLinear,
|
||||
)
|
||||
|
||||
|
||||
# PhiConfig is the configuration class for the PhiModel.
|
||||
class PhiConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=51200,
|
||||
n_positions=2048,
|
||||
n_embd=2560,
|
||||
n_layer=32,
|
||||
n_inner=None,
|
||||
n_head=32,
|
||||
rotary_dim=32,
|
||||
layer_norm_epsilon=1e-5,
|
||||
tie_word_embeddings=False,
|
||||
pad_vocab_size_multiple=64,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
no_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.rotary_dim = rotary_dim
|
||||
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.no_bias = no_bias
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# RotaryEmbedding is a class that implements the rotary embedding.
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
inv_freq = [
|
||||
1.0 / 10000.0 ** (i / dim)
|
||||
for i in range(0, dim, 2)
|
||||
]
|
||||
inv_freq_len = len(inv_freq)
|
||||
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
|
||||
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
|
||||
freqs = t.matmul(inv_freq)
|
||||
self.sin = freqs.sin()
|
||||
self.cos = freqs.cos()
|
||||
|
||||
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
|
||||
b_size, seqlen, three, _, _headdim = qkv.shape
|
||||
if three != 3:
|
||||
raise Exception("unexpected shape for qkv")
|
||||
_, rotary_dim = self.cos.shape
|
||||
rotary_dim = rotary_dim * 2
|
||||
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
||||
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
||||
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
||||
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
||||
q12 = torch.chunk(q_rot, 2, dim=-1)
|
||||
k12 = torch.chunk(k_rot, 2, dim=-1)
|
||||
q1, q2 = q12[0], q12[1]
|
||||
k1, k2 = k12[0], k12[1]
|
||||
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
|
||||
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
|
||||
q_rot = torch.cat(
|
||||
[
|
||||
q1 * c - q2 * s,
|
||||
q1 * s + q2 * c,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
k_rot = torch.cat(
|
||||
[
|
||||
k1 * c - k2 * s,
|
||||
k1 * s + k2 * c,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
q = torch.cat([q_rot, q_pass], dim=-1)
|
||||
k = torch.cat([k_rot, k_pass], dim=-1)
|
||||
v = qkv[:, :, 2]
|
||||
return q, k, v
|
||||
|
||||
|
||||
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
|
||||
class PhiCausalLMHead(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm.load(
|
||||
prefix="lm_head.ln",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.linear = TensorParallelHead.load(
|
||||
config=config, prefix="lm_head.linear", weights=weights
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.ln(hidden_states)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
|
||||
class PhiMHA(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.Wqkv = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
weights=weights,
|
||||
bias=not config.no_bias,
|
||||
)
|
||||
self.op_size = config.n_embd
|
||||
self.head_dim = int(config.n_embd / config.n_head)
|
||||
self.num_heads = config.n_head
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
config.rotary_dim,
|
||||
config.n_positions,
|
||||
)
|
||||
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
past_kv_cache,
|
||||
attention_mask=None,
|
||||
):
|
||||
b_size, seq_len, _n_embd = hidden_states.shape
|
||||
qkv = self.Wqkv(hidden_states)
|
||||
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
|
||||
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
|
||||
|
||||
# if there is a kv_cache, then we need to concatenate
|
||||
if past_kv_cache is not None:
|
||||
prev_k, prev_v = past_kv_cache
|
||||
k = torch.cat([prev_k, k], dim=1)
|
||||
v = torch.cat([prev_v, v], dim=1)
|
||||
|
||||
past_kv_cache = [k, v]
|
||||
attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
|
||||
|
||||
if attention_mask is not None:
|
||||
seqlen_k = k.shape[1]
|
||||
seqlen_q = q.shape[1]
|
||||
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
|
||||
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
|
||||
attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
|
||||
return self.out_proj(attn_output), past_kv_cache
|
||||
|
||||
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
|
||||
class PhiMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.n_inner = config.n_inner
|
||||
self.fc1 = FastLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.fc2 = FastLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.activation = torch.nn.functional.gelu
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
|
||||
class PhiBlock(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
|
||||
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
kv_cache,
|
||||
attention_mask,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
out = attn_outputs + feed_forward_hidden_states + residual
|
||||
return out, past_kv_cache
|
||||
|
||||
# PhiModel implements the embedding layer and the transformer blocks.
|
||||
class PhiModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.tp_rank = weights.process_group.rank()
|
||||
self.tp_world_size = weights.process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="transformer.embd.wte", weights=weights
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
attention_mask: Optional[torch.ByteTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
seq_len = hidden_states.shape[1]
|
||||
mask = None if seq_len <= 1 else attention_mask
|
||||
|
||||
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
|
||||
|
||||
for index, block in enumerate(self.blocks):
|
||||
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
|
||||
past_key_values[index] = new_key_values
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
|
||||
class PhiForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.model = PhiModel(config, weights)
|
||||
self.lm_head = PhiCausalLMHead(config, weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
||||
attention_mask: Optional[torch.ByteTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
model_output = self.model(
|
||||
input_ids, past_key_values, attention_mask, return_dict, use_cache
|
||||
)
|
||||
logits = self.lm_head(model_output[0])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = nn.CrossEntropyLoss()(
|
||||
logits[:, :-1].view(-1, logits.size(-1)),
|
||||
labels[:, 1:].view(-1)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=model_output[1],
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
PhiConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashPhi(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_medusa: Optional[str] = None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = PhiConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashPhiForCausalLM(config, weights)
|
||||
if use_medusa:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashPhi, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
|
@ -0,0 +1,63 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
class Phi(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, _rank, _world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config = PhiConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = config.bos_token_id
|
||||
tokenizer.eos_token_id = config.eos_token_id
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config.quantize = quantize
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
model = PhiForCausalLM(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
Loading…
Reference in New Issue