This commit is contained in:
OlivierDehaene 2024-01-26 19:04:57 +01:00 committed by GitHub
parent d9758851be
commit c2d4a3b5c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2333 additions and 849 deletions

View File

@ -1,12 +0,0 @@
name: Delete doc comment
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}

420
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
version = "1.3.4"
version = "1.4.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -62,7 +62,7 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
And then you can make requests like
@ -76,7 +76,7 @@ curl 127.0.0.1:8080/generate \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
@ -106,7 +106,7 @@ model=meta-llama/Llama-2-7b-chat-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your cli READ token>
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
### A note on Shared Memory (shm)

File diff suppressed because one or more lines are too long

View File

@ -19,6 +19,6 @@ docker run --gpus all \
--shm-size 1g \
-e HUGGING_FACE_HUB_TOKEN=$token \
-p 8080:80 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 \
-v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
--model-id $model
```

View File

@ -8,7 +8,7 @@ Let's say you want to deploy [Falcon-7B Instruct](https://huggingface.co/tiiuae/
model=tiiuae/falcon-7b-instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3 --model-id $model
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
```
<Tip warning={true}>
@ -20,7 +20,7 @@ To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://d
TGI also supports ROCm-enabled AMD GPUs (only MI210 and MI250 are tested), details are available in the [Supported Hardware section](./supported_models#supported-hardware) and [AMD documentation](https://rocm.docs.amd.com/en/latest/deploy/docker.html). To launch TGI on ROCm GPUs, please use instead:
```bash
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.3-rocm --model-id $model
docker run --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model
```
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
@ -91,7 +91,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:1.3 --help
docker run ghcr.io/huggingface/text-generation-inference:1.4 --help
```
</Tip>

View File

@ -21,7 +21,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == ": {request}\")\n response = self"
assert response.generated_text == ': {request}")\n response = self'
assert response == response_snapshot
@ -52,14 +52,12 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(
flash_phi, "Test request", max_new_tokens=10, n=4
)
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": {request}\")\n response = self"
assert responses[0].generated_text == ': {request}")\n response = self'
assert responses == response_snapshot

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-integration-tests"
version = "1.3.4"
version = "1.4.0"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"]

1193
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "1.3.4"
version = "1.4.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@ -13,11 +13,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -28,18 +28,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -12,11 +12,11 @@ grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -27,18 +27,18 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -3,24 +3,27 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding,
)
class ProcessGroup:
def __init__(self, rank: int, world_size: int):
self._rank = rank
self.world_size = world_size
def size(self)->int:
def size(self) -> int:
return self.world_size
def rank(self)->int:
def rank(self) -> int:
return self._rank
class Weights:
def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
self.weight = torch.arange(vocab_size*hidden_dim).float().view(vocab_size, hidden_dim)
self.weight = (
torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
)
self.process_group = ProcessGroup(rank, world_size)
def get_partial_sharded(self, name:str, dim: int):
def get_partial_sharded(self, name: str, dim: int):
assert dim == 0
rank = self.process_group.rank()
@ -35,10 +38,11 @@ class Weights:
def get_shape(self, name: str):
return self.weight.shape
def test_weight_hub_files_offline_error():
vocab_size= 17
weights = Weights(rank=0, world_size=1, vocab_size = vocab_size,hidden_dim = 256)
vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)
@ -47,18 +51,27 @@ def test_weight_hub_files_offline_error():
assert embeddings.max_id == 17
torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))
weights_0_2 = Weights(rank=0, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size = vocab_size,hidden_dim = 256)
weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
assert embeddings_0_2.min_id == 0
assert embeddings_0_2.max_id == 9
torch.testing.assert_close(embeddings_0_2.weight , torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0).view(10, 256).float())
torch.testing.assert_close(
embeddings_0_2.weight,
torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
.view(10, 256)
.float(),
)
embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
assert embeddings_1_2.min_id == 9
assert embeddings_1_2.max_id == 17
torch.testing.assert_close(embeddings_1_2.weight , torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0).view(9, 256).float())
torch.testing.assert_close(
embeddings_1_2.weight,
torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
.view(9, 256)
.float(),
)
output_tp_0 = embeddings_0_2.forward(input_ids)
output_tp_1 = embeddings_1_2.forward(input_ids)
torch.testing.assert_close(output, output_tp_0 + output_tp_1)

View File

@ -252,7 +252,9 @@ def get_model(
elif model_type == "phi-msft":
if FLASH_ATTENTION:
raise NotImplementedError("Legacy phi-msft is not supported with Flash Attention")
raise NotImplementedError(
"Legacy phi-msft is not supported with Flash Attention"
)
else:
return Phi(
model_id,

View File

@ -17,6 +17,7 @@ from text_generation_server.utils.layers import (
FastLayerNorm,
)
class PhiConfig(PretrainedConfig):
def __init__(
self,
@ -55,6 +56,7 @@ class PhiConfig(PretrainedConfig):
**kwargs,
)
# this is the same as llama except for Phi uses bias=True
def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads:
@ -68,6 +70,7 @@ def load_attention(config, prefix, weights):
bias=True,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
@ -94,6 +97,7 @@ def _load_gqa(config, prefix: str, weights):
get_linear(weight, bias=True, quantize=config.quantize)
)
class FlashPhiAttention(torch.nn.Module):
def __init__(
self,
@ -173,8 +177,7 @@ class FlashPhiAttention(torch.nn.Module):
#
# Apply partial positional embeddings in place
self.rotary_emb(
query[:, :, :self.rotary_dim], kv[:, 0, :, :self.rotary_dim],
cos, sin
query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
)
# Reshape key and value and cache
@ -210,7 +213,8 @@ class FlashPhiAttention(torch.nn.Module):
max_s,
)
return self.dense(attn_output.view(-1, self.num_heads*self.head_size))
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
@ -256,7 +260,9 @@ class FlashPhiLayer(nn.Module):
)
self.mlp = PhiMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)
@ -287,10 +293,13 @@ class FlashPhiLayer(nn.Module):
max_s,
)
hidden_states = self.resid_dropout(attn_output).add(self.resid_dropout(self.mlp(hidden_states)))
hidden_states = self.resid_dropout(attn_output).add(
self.resid_dropout(self.mlp(hidden_states))
)
return hidden_states, res
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
@ -361,6 +370,7 @@ class FlashPhiModel(torch.nn.Module):
return hidden_states
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()

View File

@ -54,9 +54,19 @@ def load_col(config, prefix, weights, bias):
bias_h = bias_h[0]
bias_block_size = bias_h // bias_size
bias_q_part = bias_slice_[bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size]
bias_k_part = bias_slice_[bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size]
bias_v_part = bias_slice_[2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size]
bias_q_part = bias_slice_[
bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size
]
bias_k_part = bias_slice_[
bias_h
+ bias_rank * bias_block_size : bias_h
+ (bias_rank + 1) * bias_block_size
]
bias_v_part = bias_slice_[
2 * bias_h
+ bias_rank * bias_block_size : 2 * bias_h
+ (bias_rank + 1) * bias_block_size
]
bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)
if bias.dtype != torch.int32:
@ -352,8 +362,12 @@ class MultiheadAttention(nn.Module):
hidden_size = config.d_model
head_dim = hidden_size // self.n_heads
self.q_ln = LPLayerNorm(d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights)
self.k_ln = LPLayerNorm(self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights)
self.q_ln = LPLayerNorm(
d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights
)
self.k_ln = LPLayerNorm(
self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights
)
if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
elif self.attn_impl == "triton":
@ -684,7 +698,6 @@ class LPLayerNorm(torch.nn.LayerNorm):
self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0))
self.normalized_shape = self.weight.shape
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)

View File

@ -62,14 +62,12 @@ class PhiConfig(PretrainedConfig):
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [
1.0 / 10000.0 ** (i / dim)
for i in range(0, dim, 2)
]
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
@ -131,6 +129,7 @@ class PhiCausalLMHead(nn.Module):
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
@ -172,19 +171,27 @@ class PhiMHA(nn.Module):
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum('bthd,bshd->bhts', q, k * self.softmax_scale)
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), 1)
causal_mask = torch.triu(
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
1,
)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)).transpose(1, 2).flatten(-2)
attn_output = (
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
.transpose(1, 2)
.flatten(-2)
)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
@ -211,12 +218,15 @@ class PhiMLP(nn.Module):
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon)
self.layer_norm = nn.LayerNorm.load(
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
@ -228,11 +238,14 @@ class PhiBlock(nn.Module):
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(hidden_states, kv_cache, attention_mask)
attn_outputs, past_kv_cache = self.mixer(
hidden_states, kv_cache, attention_mask
)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
@ -243,7 +256,10 @@ class PhiModel(nn.Module):
prefix="transformer.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[PhiBlock(f"transformer.h.{layer_id}", config, weights) for layer_id in range(config.n_layer)]
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
def forward(
@ -258,14 +274,19 @@ class PhiModel(nn.Module):
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = [None] * len(self.blocks) if past_key_values is None else past_key_values
past_key_values = (
[None] * len(self.blocks) if past_key_values is None else past_key_values
)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(hidden_states, past_key_values[index], mask)
hidden_states, new_key_values = block(
hidden_states, past_key_values[index], mask
)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
@ -290,12 +311,15 @@ class PhiForCausalLM(torch.nn.Module):
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)),
labels[:, 1:].view(-1)
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
return ((loss,) + (logits,) + model_output[1:]) if loss is not None else (logits,) + model_output[1:]
return (
((loss,) + (logits,) + model_output[1:])
if loss is not None
else (logits,) + model_output[1:]
)
return CausalLMOutputWithPast(
loss=loss,
@ -304,5 +328,3 @@ class PhiForCausalLM(torch.nn.Module):
hidden_states=None,
attentions=None,
)

View File

@ -74,9 +74,9 @@ class FlashLlama(FlashCausalLM):
import os
from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(

View File

@ -64,9 +64,9 @@ class FlashPhi(FlashCausalLM):
import os
from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(

View File

@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
def __init__(
self,
@ -60,4 +64,3 @@ class Phi(CausalLM):
dtype=dtype,
device=device,
)

View File

@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module):
block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size.
self.null_idx = weight.shape[
0
] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group
self.reduce = reduce