feat(server): Add Non flash MPT. (#514)

# What does this PR do?


This adds a non flash version of MPT.
Flash is harder because we need to create a bias ready cuda kernel of
flash attention.

Fixes
https://github.com/huggingface/text-generation-inference/issues/361
Fixes
https://github.com/huggingface/text-generation-inference/issues/491
Fixes
https://github.com/huggingface/text-generation-inference/issues/290
This commit is contained in:
Nicolas Patry 2023-07-03 13:01:46 +02:00 committed by GitHub
parent e28a809004
commit 1da07e85aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 2011 additions and 1 deletions

View File

@ -0,0 +1,140 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5117188,
"text": " is"
},
{
"id": 18147,
"logprob": -8.96875,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.953125,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.94189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5830078,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3105469,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.3215332,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5566406,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.6074219,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.6923828,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5263672,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.8544922,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6118164,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.055877686,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0537109,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.0115737915,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9111328,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4589844,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.4853516,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021636963,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
}

View File

@ -0,0 +1,562 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5117188,
"text": " is"
},
{
"id": 18147,
"logprob": -8.96875,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.953125,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.94189453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5830078,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3183594,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.32617188,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.6015625,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.67822266,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5395508,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.8623047,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6020508,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.0552063,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0742188,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011405945,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9165039,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4501953,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.4960938,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.02116394,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 17,
"prefill": [
{
"id": 1276,
"logprob": null,
"text": "What"
},
{
"id": 310,
"logprob": -1.5,
"text": " is"
},
{
"id": 18147,
"logprob": -8.984375,
"text": " Deep"
},
{
"id": 20727,
"logprob": -1.96875,
"text": " Learning"
},
{
"id": 32,
"logprob": -0.93359375,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 428,
"logprob": -1.5800781,
"special": false,
"text": " -"
},
{
"id": 18147,
"logprob": -3.3242188,
"special": false,
"text": " Deep"
},
{
"id": 20727,
"logprob": -0.31835938,
"special": false,
"text": " Learning"
},
{
"id": 187,
"logprob": -2.5644531,
"special": false,
"text": "\n"
},
{
"id": 30763,
"logprob": -1.5957031,
"special": false,
"text": "Deep"
},
{
"id": 20727,
"logprob": -0.69628906,
"special": false,
"text": " Learning"
},
{
"id": 310,
"logprob": -0.68603516,
"special": false,
"text": " is"
},
{
"id": 247,
"logprob": -0.5258789,
"special": false,
"text": " a"
},
{
"id": 749,
"logprob": -1.859375,
"special": false,
"text": " sub"
},
{
"id": 3423,
"logprob": -0.6166992,
"special": false,
"text": "field"
},
{
"id": 273,
"logprob": -0.056762695,
"special": false,
"text": " of"
},
{
"id": 5145,
"logprob": -1.0703125,
"special": false,
"text": " machine"
},
{
"id": 4715,
"logprob": -0.011428833,
"special": false,
"text": " learning"
},
{
"id": 326,
"logprob": -0.9213867,
"special": false,
"text": " that"
},
{
"id": 4648,
"logprob": -1.4726562,
"special": false,
"text": " uses"
},
{
"id": 13345,
"logprob": -1.5039062,
"special": false,
"text": " artificial"
},
{
"id": 11454,
"logprob": -0.021652222,
"special": false,
"text": " neural"
}
]
},
"generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
}
]

View File

@ -0,0 +1,48 @@
import pytest
@pytest.fixture(scope="module")
def mpt_sharded_handle(launcher):
with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mpt_sharded(mpt_sharded_handle):
await mpt_sharded_handle.health(300)
return mpt_sharded_handle.client
@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
response = await mpt_sharded.generate(
"What is Deep Learning?",
max_new_tokens=17,
decoder_input_details=True,
)
assert response.details.generated_tokens == 17
assert (
response.generated_text
== " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
)
assert response == response_snapshot
@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
responses = await generate_load(
mpt_sharded,
"What is Deep Learning?",
max_new_tokens=17,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert (
responses[0].generated_text
== " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
)
assert responses == response_snapshot

13
server/poetry.lock generated
View File

@ -187,6 +187,17 @@ wrapt = ">=1.10,<2"
[package.extras] [package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "einops"
version = "0.6.1"
description = "A new flavour of deep learning operations"
optional = false
python-versions = ">=3.7"
files = [
{file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"},
{file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"},
]
[[package]] [[package]]
name = "exceptiongroup" name = "exceptiongroup"
version = "1.1.1" version = "1.1.1"
@ -1586,4 +1597,4 @@ bnb = ["bitsandbytes"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "54ecacb32d699cb1298c237c4661c1b707f119cf2c27bd54bad7a1ea2ffb8b10" content-hash = "3174a211d30bed5990ed5f8418416c951bb6c585153fb51b62809baa89ef07d0"

View File

@ -27,6 +27,7 @@ sentencepiece = "^0.1.97"
tokenizers = "0.13.3" tokenizers = "0.13.3"
huggingface-hub = "^0.14.1" huggingface-hub = "^0.14.1"
transformers = "^4.29.2" transformers = "^4.29.2"
einops = "^0.6.1"
[tool.poetry.extras] [tool.poetry.extras]
accelerate = ["accelerate"] accelerate = ["accelerate"]

View File

@ -4,6 +4,7 @@ charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0" filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0" fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0" googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded from text_generation_server.models.opt import OPTSharded
@ -178,6 +179,10 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == "mpt":
return MPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox": elif model_type == "gpt_neox":
if FLASH_ATTENTION: if FLASH_ATTENTION:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,90 @@
import torch
import torch.distributed
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class MPTSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16
else:
raise NotImplementedError("MPTSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
config.quantize = quantize
model = MPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return MPTCausalLMBatch

View File

@ -31,7 +31,19 @@ def load_layer_norm(cls, prefix, weights, eps):
return ln return ln
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = None
return ln
torch.nn.LayerNorm.load = load_layer_norm torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
class FastLinear(nn.Module): class FastLinear(nn.Module):