feat(server): Add Non flash MPT. (#514)
# What does this PR do?

This adds a non-flash version of MPT. A flash version is harder because it would require a CUDA flash-attention kernel that supports an additive attention bias (MPT uses ALiBi).

Fixes https://github.com/huggingface/text-generation-inference/issues/361
Fixes https://github.com/huggingface/text-generation-inference/issues/491
Fixes https://github.com/huggingface/text-generation-inference/issues/290
This commit is contained in:
parent e28a809004
commit 1da07e85aa
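For readers unfamiliar with the distinction, the sketch below (not code from this PR; names and shapes are illustrative) shows why the non-flash path is straightforward: naive attention materializes the full score matrix, so MPT's ALiBi bias can simply be added before the softmax, whereas a fused flash-attention kernel would have to support that bias internally.

```python
import torch


def attention_with_bias(q, k, v, bias):
    # q, k, v: [batch, heads, seq_len, head_dim]; bias broadcastable to
    # [batch, heads, seq_len, seq_len] (e.g. ALiBi slopes plus the causal mask).
    scores = torch.matmul(q, k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    scores = scores + bias  # the step a fused flash kernel would have to absorb
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```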
@@ -0,0 +1,140 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 17,
    "prefill": [
      { "id": 1276, "logprob": null, "text": "What" },
      { "id": 310, "logprob": -1.5117188, "text": " is" },
      { "id": 18147, "logprob": -8.96875, "text": " Deep" },
      { "id": 20727, "logprob": -1.953125, "text": " Learning" },
      { "id": 32, "logprob": -0.94189453, "text": "?" }
    ],
    "seed": null,
    "tokens": [
      { "id": 428, "logprob": -1.5830078, "special": false, "text": " -" },
      { "id": 18147, "logprob": -3.3105469, "special": false, "text": " Deep" },
      { "id": 20727, "logprob": -0.3215332, "special": false, "text": " Learning" },
      { "id": 187, "logprob": -2.5566406, "special": false, "text": "\n" },
      { "id": 30763, "logprob": -1.6074219, "special": false, "text": "Deep" },
      { "id": 20727, "logprob": -0.69628906, "special": false, "text": " Learning" },
      { "id": 310, "logprob": -0.6923828, "special": false, "text": " is" },
      { "id": 247, "logprob": -0.5263672, "special": false, "text": " a" },
      { "id": 749, "logprob": -1.8544922, "special": false, "text": " sub" },
      { "id": 3423, "logprob": -0.6118164, "special": false, "text": "field" },
      { "id": 273, "logprob": -0.055877686, "special": false, "text": " of" },
      { "id": 5145, "logprob": -1.0537109, "special": false, "text": " machine" },
      { "id": 4715, "logprob": -0.0115737915, "special": false, "text": " learning" },
      { "id": 326, "logprob": -0.9111328, "special": false, "text": " that" },
      { "id": 4648, "logprob": -1.4589844, "special": false, "text": " uses" },
      { "id": 13345, "logprob": -1.4853516, "special": false, "text": " artificial" },
      { "id": 11454, "logprob": -0.021636963, "special": false, "text": " neural" }
    ]
  },
  "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
}
@@ -0,0 +1,562 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 17,
      "prefill": [
        { "id": 1276, "logprob": null, "text": "What" },
        { "id": 310, "logprob": -1.5117188, "text": " is" },
        { "id": 18147, "logprob": -8.96875, "text": " Deep" },
        { "id": 20727, "logprob": -1.953125, "text": " Learning" },
        { "id": 32, "logprob": -0.94189453, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 428, "logprob": -1.5830078, "special": false, "text": " -" },
        { "id": 18147, "logprob": -3.3183594, "special": false, "text": " Deep" },
        { "id": 20727, "logprob": -0.32617188, "special": false, "text": " Learning" },
        { "id": 187, "logprob": -2.5742188, "special": false, "text": "\n" },
        { "id": 30763, "logprob": -1.6015625, "special": false, "text": "Deep" },
        { "id": 20727, "logprob": -0.69628906, "special": false, "text": " Learning" },
        { "id": 310, "logprob": -0.67822266, "special": false, "text": " is" },
        { "id": 247, "logprob": -0.5395508, "special": false, "text": " a" },
        { "id": 749, "logprob": -1.8623047, "special": false, "text": " sub" },
        { "id": 3423, "logprob": -0.6020508, "special": false, "text": "field" },
        { "id": 273, "logprob": -0.0552063, "special": false, "text": " of" },
        { "id": 5145, "logprob": -1.0742188, "special": false, "text": " machine" },
        { "id": 4715, "logprob": -0.011405945, "special": false, "text": " learning" },
        { "id": 326, "logprob": -0.9165039, "special": false, "text": " that" },
        { "id": 4648, "logprob": -1.4501953, "special": false, "text": " uses" },
        { "id": 13345, "logprob": -1.4960938, "special": false, "text": " artificial" },
        { "id": 11454, "logprob": -0.02116394, "special": false, "text": " neural" }
      ]
    },
    "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 17,
      "prefill": [
        { "id": 1276, "logprob": null, "text": "What" },
        { "id": 310, "logprob": -1.5, "text": " is" },
        { "id": 18147, "logprob": -8.984375, "text": " Deep" },
        { "id": 20727, "logprob": -1.96875, "text": " Learning" },
        { "id": 32, "logprob": -0.93359375, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 428, "logprob": -1.5800781, "special": false, "text": " -" },
        { "id": 18147, "logprob": -3.3242188, "special": false, "text": " Deep" },
        { "id": 20727, "logprob": -0.31835938, "special": false, "text": " Learning" },
        { "id": 187, "logprob": -2.5644531, "special": false, "text": "\n" },
        { "id": 30763, "logprob": -1.5957031, "special": false, "text": "Deep" },
        { "id": 20727, "logprob": -0.69628906, "special": false, "text": " Learning" },
        { "id": 310, "logprob": -0.68603516, "special": false, "text": " is" },
        { "id": 247, "logprob": -0.5258789, "special": false, "text": " a" },
        { "id": 749, "logprob": -1.859375, "special": false, "text": " sub" },
        { "id": 3423, "logprob": -0.6166992, "special": false, "text": "field" },
        { "id": 273, "logprob": -0.056762695, "special": false, "text": " of" },
        { "id": 5145, "logprob": -1.0703125, "special": false, "text": " machine" },
        { "id": 4715, "logprob": -0.011428833, "special": false, "text": " learning" },
        { "id": 326, "logprob": -0.9213867, "special": false, "text": " that" },
        { "id": 4648, "logprob": -1.4726562, "special": false, "text": " uses" },
        { "id": 13345, "logprob": -1.5039062, "special": false, "text": " artificial" },
        { "id": 11454, "logprob": -0.021652222, "special": false, "text": " neural" }
      ]
    },
    "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 17,
      "prefill": [
        { "id": 1276, "logprob": null, "text": "What" },
        { "id": 310, "logprob": -1.5, "text": " is" },
        { "id": 18147, "logprob": -8.984375, "text": " Deep" },
        { "id": 20727, "logprob": -1.96875, "text": " Learning" },
        { "id": 32, "logprob": -0.93359375, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 428, "logprob": -1.5800781, "special": false, "text": " -" },
        { "id": 18147, "logprob": -3.3242188, "special": false, "text": " Deep" },
        { "id": 20727, "logprob": -0.31835938, "special": false, "text": " Learning" },
        { "id": 187, "logprob": -2.5644531, "special": false, "text": "\n" },
        { "id": 30763, "logprob": -1.5957031, "special": false, "text": "Deep" },
        { "id": 20727, "logprob": -0.69628906, "special": false, "text": " Learning" },
        { "id": 310, "logprob": -0.68603516, "special": false, "text": " is" },
        { "id": 247, "logprob": -0.5258789, "special": false, "text": " a" },
        { "id": 749, "logprob": -1.859375, "special": false, "text": " sub" },
        { "id": 3423, "logprob": -0.6166992, "special": false, "text": "field" },
        { "id": 273, "logprob": -0.056762695, "special": false, "text": " of" },
        { "id": 5145, "logprob": -1.0703125, "special": false, "text": " machine" },
        { "id": 4715, "logprob": -0.011428833, "special": false, "text": " learning" },
        { "id": 326, "logprob": -0.9213867, "special": false, "text": " that" },
        { "id": 4648, "logprob": -1.4726562, "special": false, "text": " uses" },
        { "id": 13345, "logprob": -1.5039062, "special": false, "text": " artificial" },
        { "id": 11454, "logprob": -0.021652222, "special": false, "text": " neural" }
      ]
    },
    "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 17,
      "prefill": [
        { "id": 1276, "logprob": null, "text": "What" },
        { "id": 310, "logprob": -1.5, "text": " is" },
        { "id": 18147, "logprob": -8.984375, "text": " Deep" },
        { "id": 20727, "logprob": -1.96875, "text": " Learning" },
        { "id": 32, "logprob": -0.93359375, "text": "?" }
      ],
      "seed": null,
      "tokens": [
        { "id": 428, "logprob": -1.5800781, "special": false, "text": " -" },
        { "id": 18147, "logprob": -3.3242188, "special": false, "text": " Deep" },
        { "id": 20727, "logprob": -0.31835938, "special": false, "text": " Learning" },
        { "id": 187, "logprob": -2.5644531, "special": false, "text": "\n" },
        { "id": 30763, "logprob": -1.5957031, "special": false, "text": "Deep" },
        { "id": 20727, "logprob": -0.69628906, "special": false, "text": " Learning" },
        { "id": 310, "logprob": -0.68603516, "special": false, "text": " is" },
        { "id": 247, "logprob": -0.5258789, "special": false, "text": " a" },
        { "id": 749, "logprob": -1.859375, "special": false, "text": " sub" },
        { "id": 3423, "logprob": -0.6166992, "special": false, "text": "field" },
        { "id": 273, "logprob": -0.056762695, "special": false, "text": " of" },
        { "id": 5145, "logprob": -1.0703125, "special": false, "text": " machine" },
        { "id": 4715, "logprob": -0.011428833, "special": false, "text": " learning" },
        { "id": 326, "logprob": -0.9213867, "special": false, "text": " that" },
        { "id": 4648, "logprob": -1.4726562, "special": false, "text": " uses" },
        { "id": 13345, "logprob": -1.5039062, "special": false, "text": " artificial" },
        { "id": 11454, "logprob": -0.021652222, "special": false, "text": " neural" }
      ]
    },
    "generated_text": " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
  }
]
@@ -0,0 +1,48 @@
import pytest


@pytest.fixture(scope="module")
def mpt_sharded_handle(launcher):
    with launcher("mosaicml/mpt-7b", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def mpt_sharded(mpt_sharded_handle):
    await mpt_sharded_handle.health(300)
    return mpt_sharded_handle.client


@pytest.mark.asyncio
async def test_mpt(mpt_sharded, response_snapshot):
    response = await mpt_sharded.generate(
        "What is Deep Learning?",
        max_new_tokens=17,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 17
    assert (
        response.generated_text
        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_mpt_load(mpt_sharded, generate_load, response_snapshot):
    responses = await generate_load(
        mpt_sharded,
        "What is Deep Learning?",
        max_new_tokens=17,
        n=4,
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])
    assert (
        responses[0].generated_text
        == " - Deep Learning\nDeep Learning is a subfield of machine learning that uses artificial neural"
    )

    assert responses == response_snapshot
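The `generate_load` fixture used above is defined in the integration-test conftest rather than in this diff. A minimal sketch of what such a helper plausibly does, assuming an async client whose `generate` is a coroutine:

```python
import asyncio


async def generate_load_sketch(client, prompt, max_new_tokens, n):
    # Fire n identical requests concurrently and return all responses,
    # so the test can check that generation stays identical under load.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)
```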
@@ -187,6 +187,17 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]

[[package]]
name = "einops"
version = "0.6.1"
description = "A new flavour of deep learning operations"
optional = false
python-versions = ">=3.7"
files = [
    {file = "einops-0.6.1-py3-none-any.whl", hash = "sha256:99149e46cc808956b174932fe563d920db4d6e5dadb8c6ecdaa7483b7ef7cfc3"},
    {file = "einops-0.6.1.tar.gz", hash = "sha256:f95f8d00f4ded90dbc4b19b6f98b177332614b0357dde66997f3ae5d474dc8c8"},
]

[[package]]
name = "exceptiongroup"
version = "1.1.1"
@@ -1586,4 +1597,4 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
-content-hash = "54ecacb32d699cb1298c237c4661c1b707f119cf2c27bd54bad7a1ea2ffb8b10"
+content-hash = "3174a211d30bed5990ed5f8418416c951bb6c585153fb51b62809baa89ef07d0"
@@ -27,6 +27,7 @@ sentencepiece = "^0.1.97"
tokenizers = "0.13.3"
huggingface-hub = "^0.14.1"
transformers = "^4.29.2"
einops = "^0.6.1"

[tool.poetry.extras]
accelerate = ["accelerate"]
@@ -4,6 +4,7 @@ charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.3 ; python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.59.1 ; python_version >= "3.9" and python_version < "4.0"
@@ -10,6 +10,7 @@ from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded

@@ -178,6 +179,10 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "mpt":
        return MPTSharded(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    elif model_type == "gpt_neox":
        if FLASH_ATTENTION:
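For context, `model_type` in this dispatch comes from the checkpoint's `config.json`. A small hypothetical illustration (not part of the diff) of where the `"mpt"` value originates:

```python
import json

from huggingface_hub import hf_hub_download

# Illustration only: the server reads model_type from the model's config.json,
# and "mpt" now routes to MPTSharded in get_model.
config_path = hf_hub_download("mosaicml/mpt-7b", filename="config.json")
with open(config_path) as f:
    model_type = json.load(f)["model_type"]  # "mpt"
```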
File diff suppressed because it is too large (most likely the new custom MPT modeling code imported below as text_generation_server.models.custom_modeling.mpt_modeling).
@@ -0,0 +1,90 @@
import torch
import torch.distributed

from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class MPTCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
        # MPT's past key cache puts head_dim before the sequence dimension.
        batch.keys_head_dim_last = False
        return batch


class MPTSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            raise NotImplementedError("MPTSharded is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        filename = hf_hub_download(model_id, revision=revision, filename="config.json")
        with open(filename, "r") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)

        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        # Skip CausalLM.__init__ (which loads a transformers model) and
        # initialize the Model base class directly with the sharded model.
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return MPTCausalLMBatch
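Once a server is running with this model (the integration tests above launch `mosaicml/mpt-7b` with `num_shard=2`), generation can be exercised with the `text-generation` Python client. The URL below assumes a locally running instance; adjust host and port to your deployment.

```python
from text_generation import Client

# Assumes a text-generation-inference server is listening locally.
client = Client("http://127.0.0.1:8080")
response = client.generate("What is Deep Learning?", max_new_tokens=17)
print(response.generated_text)
```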
@@ -31,7 +31,19 @@ def load_layer_norm(cls, prefix, weights, eps):
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLinear(nn.Module):
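As a sanity check on the pattern above, here is a self-contained sketch (illustrative sizes and values, not taken from the PR) showing that a `LayerNorm` whose bias has been set to `None` still runs, which is what `load_no_bias` relies on for MPT's bias-free layer norms:

```python
import torch
from torch import nn

hidden_size = 4096
# Stands in for weights.get_tensor(f"{prefix}.weight") read from safetensors shards.
weight = torch.ones(hidden_size)

ln = nn.LayerNorm(weight.shape, eps=1e-5)
ln.weight = nn.Parameter(weight)
ln.bias = None  # MPT layer norms carry no bias term

x = torch.randn(1, hidden_size)
print(ln(x).shape)  # torch.Size([1, 4096])
```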