feat: adds phi model (#1442)

This PR adds basic modeling for phi-2 

run
```bash
text-generation-server \
    serve \
    microsoft/phi-2 \
    --revision 834565c23f9b28b96ccbeabe614dd906b6db551a
```


test
```bash
curl -s localhost:3000/generate \
   -X POST \
   -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
   -H 'Content-Type: application/json' | jq .
# {
#   "generated_text": "\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from data. These"
# }
```



notes 
- recently (~1 day ago) the Phi weights and model were updated to
accommodate adding [GQA/MQA attention to the
model.](https://github.com/huggingface/transformers/pull/28163) This
impl expects the original model format so a fixed revision is required
at the moment.
- this PR only includes a basic implementation of the model and can
later be extended for support Flash and Sharded versions as well as make
use of better optimization
This commit is contained in:
drbh 2024-01-25 09:37:53 -05:00 committed by GitHub
parent 86c8335f1b
commit 7e2a7433d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1455 additions and 1 deletions

View File

@ -0,0 +1,84 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.76660156,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7246094,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.41333008,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.11785889,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.97265625,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.0569458,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}

View File

@ -0,0 +1,60 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "stop_sequence",
"generated_tokens": 6,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 284,
"logprob": -0.19421387,
"special": false,
"text": " to"
},
{
"id": 3758,
"logprob": -0.62597656,
"special": false,
"text": " send"
},
{
"id": 1366,
"logprob": -0.87060547,
"special": false,
"text": " data"
},
{
"id": 625,
"logprob": -0.88427734,
"special": false,
"text": " over"
},
{
"id": 257,
"logprob": -1.0830078,
"special": false,
"text": " a"
},
{
"id": 3127,
"logprob": -1.9462891,
"special": false,
"text": " network"
}
],
"top_tokens": null
},
"generated_text": "Test request to send data over a network"
}

View File

@ -0,0 +1,338 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 14402,
"logprob": null,
"text": "Test"
},
{
"id": 2581,
"logprob": -11.6171875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 25,
"logprob": -2.3203125,
"special": false,
"text": ":"
},
{
"id": 1391,
"logprob": -0.98779297,
"special": false,
"text": " {"
},
{
"id": 25927,
"logprob": -0.7729492,
"special": false,
"text": "request"
},
{
"id": 92,
"logprob": -0.7241211,
"special": false,
"text": "}"
},
{
"id": 4943,
"logprob": -0.4091797,
"special": false,
"text": "\")"
},
{
"id": 198,
"logprob": -0.119018555,
"special": false,
"text": "\n"
},
{
"id": 50280,
"logprob": -0.9707031,
"special": false,
"text": " "
},
{
"id": 26209,
"logprob": -1.4414062,
"special": false,
"text": "response"
},
{
"id": 796,
"logprob": -0.056854248,
"special": false,
"text": " ="
},
{
"id": 2116,
"logprob": -1.1533203,
"special": false,
"text": " self"
}
],
"top_tokens": null
},
"generated_text": ": {request}\")\n response = self"
}
]

View File

@ -0,0 +1,65 @@
import pytest
@pytest.fixture(scope="module")
def flash_phi_handle(launcher):
with launcher("microsoft/phi-2", num_shard=1) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_phi(flash_phi_handle):
await flash_phi_handle.health(300)
return flash_phi_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response.generated_text == ": {request}\")\n response = self"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["network"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 6
assert response.generated_text == "Test request to send data over a network"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(
flash_phi, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": {request}\")\n response = self"
assert responses == response_snapshot

View File

@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3"
transformers = "^4.36.1"
transformers = "^4.37.1"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }

View File

@ -18,6 +18,7 @@ from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
@ -57,6 +58,7 @@ try:
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError as e:
@ -72,6 +74,7 @@ if FLASH_ATTENTION:
__all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashPhi)
def get_model(
@ -227,6 +230,37 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi":
if FLASH_ATTENTION:
return FlashPhi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
)
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "phi-msft":
if FLASH_ATTENTION:
raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention")
else:
return Phi(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "llama" or model_type == "baichuan":
if FLASH_ATTENTION:

View File

@ -0,0 +1,400 @@
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
get_linear,
FastLayerNorm,
)
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
hidden_size=2560,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="gelu_fast", # llama uses silu
layer_norm_eps=1e-05, # rms in llama,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
resid_pdrop=0.1, # llama doesn't have this
partial_rotary_factor=0.5, # important difference between llama and phi
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.layer_norm_eps = layer_norm_eps
self.rope_theta = rope_theta
self.resid_pdrop = resid_pdrop
self.partial_rotary_factor = partial_rotary_factor
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)
if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads
num_heads = config.num_attention_heads // weights.process_group.size()
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
assert list(weight.shape) == [
(num_heads + 2 * num_key_value_heads) * head_size,
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
# this is the same as llama except for Phi uses bias=True
return TensorParallelColumnLinear(
get_linear(weight, bias=True, quantize=config.quantize)
)
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.rotary_dim,
base=config.rope_theta,
device=weights.device,
)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
# in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.dense",
weights=weights,
bias=True,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
# Compute query, key, value and split
qkv = self.query_key_value(hidden_states)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
# Reshape query and key for rotary embeddings
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# NOTE: this is the main difference between Llama and Phi
# in llama the rotary embeddings are applied to the whole query and key.
# Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
cos, sin
)
# Reshape key and value and cache
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate="tanh"
if act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none",
)
)
# llama weights are up_proj and down_proj and bias=False
self.up_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=True,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=True,
)
def forward(self, hidden_states):
# NOTE: Llama requires the gate up states to an intermediate size
# Phi does not and we can avoid the `view` operation
return self.down_proj(self.act(self.up_proj(hidden_states)))
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
return hidden_states, res
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
self.norm = FastLayerNorm.load(
prefix="model.final_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
return self.lm_head(hidden_states)

View File

@ -0,0 +1,308 @@
# imlementation of the PhiModel and PhiForCausalLM classes
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLinear,
)
# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [
1.0 / 10000.0 ** (i / dim)
for i in range(0, dim, 2)
]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
freqs = t.matmul(inv_freq)
self.sin = freqs.sin()
self.cos = freqs.cos()
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
b_size, seqlen, three, _, _headdim = qkv.shape
if three != 3:
raise Exception("unexpected shape for qkv")
_, rotary_dim = self.cos.shape
rotary_dim = rotary_dim * 2
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q12 = torch.chunk(q_rot, 2, dim=-1)
k12 = torch.chunk(k_rot, 2, dim=-1)
q1, q2 = q12[0], q12[1]
k1, k2 = k12[0], k12[1]
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
q_rot = torch.cat(
[
q1 * c - q2 * s,
q1 * s + q2 * c,
],
dim=-1,
)
k_rot = torch.cat(
[
k1 * c - k2 * s,
k1 * s + k2 * c,
],
dim=-1,
)
q = torch.cat([q_rot, q_pass], dim=-1)
k = torch.cat([k_rot, k_pass], dim=-1)
v = qkv[:, :, 2]
return q, k, v
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.ln = nn.LayerNorm.load(
prefix="lm_head.ln",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.linear = TensorParallelHead.load(
config=config, prefix="lm_head.linear", weights=weights
)
def forward(self, hidden_states):
hidden_states = self.ln(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=not config.no_bias,
)
self.op_size = config.n_embd
self.head_dim = int(config.n_embd / config.n_head)
self.num_heads = config.n_head
self.rotary_emb = RotaryEmbedding(
config.rotary_dim,
config.n_positions,
)
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
):
b_size, seq_len, _n_embd = hidden_states.shape
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
# if there is a kv_cache, then we need to concatenate
if past_kv_cache is not None:
prev_k, prev_v = past_kv_cache
k = torch.cat([prev_k, k], dim=1)
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.n_inner = config.n_inner
self.fc1 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=False,
)
self.fc2 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=False,
)
self.activation = torch.nn.functional.gelu
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
)
logits = self.lm_head(model_output[0])
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)),
labels[:, 1:].view(-1)
)
if not return_dict:
return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_output[1],
hidden_states=None,
attentions=None,
)

View File

@ -0,0 +1,102 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
PhiConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
use_medusa: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashPhi is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
model = FlashPhiForCausalLM(config, weights)
if use_medusa:
from text_generation_server.utils.medusa import MedusaModel
from huggingface_hub import hf_hub_download
import json
import os
from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
weights = Weights(
[medusa_sf], device, dtype, process_group=self.process_group
)
lm_head = model.lm_head
model.lm_head = MedusaModel(config, weights, lm_head)
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

View File

@ -0,0 +1,63 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, _rank, _world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PhiConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
tokenizer.bos_token_id = config.bos_token_id
tokenizer.eos_token_id = config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)