From 299217c95ca314e8dbbeca26ce8cdceb440ec53b Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 11 Apr 2023 16:38:22 +0200 Subject: [PATCH] feat(server): add flash attention llama (#144) --- README.md | 10 +- launcher/tests/mt0_base.json | 2 +- server/poetry.lock | 104 ++- server/pyproject.toml | 2 + server/tests/models/test_seq2seq_lm.py | 2 +- .../text_generation_server/models/__init__.py | 27 +- .../models/causal_lm.py | 30 +- .../custom_modeling/flash_llama_modeling.py | 619 ++++++++++++++++++ .../models/flash_causal_lm.py | 30 +- .../models/flash_llama.py | 303 +++++++++ .../models/galactica.py | 8 +- server/text_generation_server/models/model.py | 48 +- .../models/seq2seq_lm.py | 30 +- 13 files changed, 1175 insertions(+), 40 deletions(-) create mode 100644 server/text_generation_server/models/custom_modeling/flash_llama_modeling.py create mode 100644 server/text_generation_server/models/flash_llama.py diff --git a/README.md b/README.md index ad938bd..bc77fd4 100644 --- a/README.md +++ b/README.md @@ -51,16 +51,14 @@ to power LLMs api-inference widgets. - Log probabilities - Production ready (distributed tracing with Open Telemetry, Prometheus metrics) -## Officially supported architectures +## Optimized architectures - [BLOOM](https://huggingface.co/bigscience/bloom) -- [BLOOMZ](https://huggingface.co/bigscience/bloomz) -- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl) - [Galactica](https://huggingface.co/facebook/galactica-120b) - [SantaCoder](https://huggingface.co/bigcode/santacoder) -- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b) -- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) -- [FLAN-UL2](https://huggingface.co/google/flan-ul2) +- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) +- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl) +- [Llama](https://github.com/facebookresearch/llama) Other architectures are supported on a best effort basis using: diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json index 22c9499..f5be63f 100644 --- a/launcher/tests/mt0_base.json +++ b/launcher/tests/mt0_base.json @@ -14,7 +14,7 @@ "tokens": [ { "id": 259, - "text": " ", + "text": "", "logprob": -1.3656927, "special": false }, diff --git a/server/poetry.lock b/server/poetry.lock index 89ad31e..49b9d39 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -517,6 +517,14 @@ tensorflow = ["tensorflow"] testing = ["h5py", "huggingface-hub", "numpy", "pytest", "pytest-benchmark", "setuptools-rust"] torch = ["torch"] +[[package]] +name = "sentencepiece" +version = "0.1.97" +description = "SentencePiece python wrapper" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "setuptools" version = "67.4.0" @@ -530,6 +538,19 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "tokenizers" +version = "0.13.3" 
+description = "Fast and Customizable Tokenizers" +category = "main" +optional = false +python-versions = "*" + +[package.extras] +dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] + [[package]] name = "tomli" version = "2.0.1" @@ -630,7 +651,7 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d" +content-hash = "1c57379c7b9349d2a860b50b3ab125737a0f6f94f4303d7cb55248cb86ff8b8e" [metadata.files] accelerate = [ @@ -1116,10 +1137,91 @@ safetensors = [ {file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"}, {file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"}, ] +sentencepiece = [ + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6f249c8f1852893be86eae66b19d522c5fb30bbad4fe2d1b07f06fdc86e1907e"}, + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09e1bc53178de70c557a9ba4fece07364b4416ce3d36570726b3372b68aea135"}, + {file = "sentencepiece-0.1.97-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:667193c57fb48b238be7e3d7636cfc8da56cb5bac5559d8f0b647334e1175be8"}, + {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2780531985af79c6163f63d4f200fec8a28b70b6768d2c19f70d01568a4524e8"}, + {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205050670c53ef9015e2a98cce3934bfbcf0aafaa14caa0c618dd5667bc217ee"}, + {file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28b183dadef8e8b6b4645c1c20692d7be0a13ecc3ec1a07b3885c8905516675f"}, + {file = "sentencepiece-0.1.97-cp310-cp310-win32.whl", hash = "sha256:ee3c9dbd558d8d85bb1617087b86df6ea2b856a528669630ce6cedeb4353b823"}, + {file = "sentencepiece-0.1.97-cp310-cp310-win_amd64.whl", hash = "sha256:f7dc55379e2f7dee86537180283db2e5f8418c6825fdd2fe436c724eb5604c05"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ba1b4154f9144c5a7528b00aff5cffaa1a896a1c6ca53ca78b6e74cd2dae5244"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3d90aee5581e55d029d124ac11b6ae2fbae0817863b664b2f2302e966ababb"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c27400f1ac46518a01c87cb7703650e4e48728649feb115d2e3f1102a946a42"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6e12a166eba75994ca749aadc4a5056b91b31405f805d6de6e8914cc9741c60"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-win32.whl", hash = "sha256:ed85dff5c0a9b3dd1a414c7e1119f2a19e863fc3f81da525bf7f885ebc883de0"}, + {file = "sentencepiece-0.1.97-cp36-cp36m-win_amd64.whl", hash = "sha256:91a19ab6f40ffbae6d6127119953d2c6a85e93d734953dbc8629fde0d21ace66"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bae580e4a35a9314ff49561ac7c06574fe6afc71b821ed6bb00534e571458156"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad7262e7530c683b186672b5dd0082f82719a50a500a8cfbc4bbd7cde5bff8c"}, + {file = 
"sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:620cee35279720016735a7c7103cddbd9b84fe5e2f098bd5e673834d69fee2b8"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93b921b59914c0ec6697e8c6d5e6b44d99d1298fb1a0af56980a79ade0540c19"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-win32.whl", hash = "sha256:9b9a4c44a31d5f47616e9568dcf31e029b0bfa776e0a252c0b59247881598b09"}, + {file = "sentencepiece-0.1.97-cp37-cp37m-win_amd64.whl", hash = "sha256:f31533cdacced56219e239d3459a003ece35116920dd64b2309d4ad047b77644"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7d643c01d1cad13b9206a276bbe5bc1a468e3d7cf6a26bde7783f945277f859d"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:542f1985b1ee279a92bef7740ec0781452372028ce01e15aa88df3228b197ba3"}, + {file = "sentencepiece-0.1.97-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93701da21fea906dd244bf88cdbe640385a89c45d3c1812b76dbadf8782cdbcd"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51514047b964047b7fadb480d88a5e0f72c02f6ca1ba96258fbbc6e79274a94"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ae2e9b7a5b6f2aa64ec9240b0c185dabe597d0e787dc4344acfbaef1ffe0b2"}, + {file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:923ee4af16dbae1f2ab358ed09f8a0eb89e40a8198a8b343bf54181482342721"}, + {file = "sentencepiece-0.1.97-cp38-cp38-win32.whl", hash = "sha256:fa6f2b88850b5fae3a05053658824cf9f147c8e3c3b40eb64539a976c83d8a24"}, + {file = "sentencepiece-0.1.97-cp38-cp38-win_amd64.whl", hash = "sha256:5137ff0d0b1cc574751d178650ef800ff8d90bf21eb9f71e9567d4a0548940a5"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f92876271a10494671431ad955bff2d6f8ea59baaf957f5ae5946aff56dfcb90"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:35c227b6d55e473033db7e0ecc51b1e99e6ed7607cc08602fb5768132543c81d"}, + {file = "sentencepiece-0.1.97-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1706a8a8188f7b3d4b7922db9bb00c64c4e16ee68ab4caaae79f55b3e18748c7"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce61efc1862ccb18856c4aabbd930e13d5bfbb4b09b4f111081ac53a9dc62275"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a78c03800ef9f02d320e0159f5768b15357f3e9ebea545c9c4ba7928ba8ba254"}, + {file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753b8088fd685ee787d9f54c84275ab347de558c7c4ebc6accb4c35bf7776f20"}, + {file = "sentencepiece-0.1.97-cp39-cp39-win32.whl", hash = "sha256:24306fd86031c17a1a6ae92671e76a350390a3140a65620bc2843dad7db24e2a"}, + {file = "sentencepiece-0.1.97-cp39-cp39-win_amd64.whl", hash = "sha256:c6641d0b7acec61fde5881ea6ebe098c169557ac9aa3bdabdf124eab5a5592bb"}, + {file = "sentencepiece-0.1.97.tar.gz", hash = "sha256:c901305e0a710bbcd296f66d79e96f744e6e175b29812bd5178318437d4e1f6c"}, +] setuptools = [ {file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"}, {file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"}, ] +tokenizers = [ + {file = 
"tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"}, + {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"}, + {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"}, + {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"}, + {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"}, + {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"}, + {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"}, + {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"}, + {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"}, + {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"}, + {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"}, + {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"}, + {file = 
"tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"}, + {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"}, + {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"}, + {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"}, + {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"}, + {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"}, + {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"}, + {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"}, + {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"}, + {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"}, +] tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index e9dc624..38f7fb0 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -23,6 +23,8 @@ opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" hf-transfer = "^0.1.2" +sentencepiece = "^0.1.97" +tokenizers = "0.13.3" [tool.poetry.extras] bnb = ["bitsandbytes"] diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index baf4457..7943578 100644 --- 
a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch) assert all([generation.generated_text is None for generation in generations]) assert all([len(generation.prefill_tokens) == 1 for generation in generations]) assert all([generation.token_id.item() == 259 for generation in generations]) - assert all([generation.token_text == " " for generation in generations]) + assert all([generation.token_text == "" for generation in generations]) assert generations[0].request_id == 0 diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index bcaf6ec..1e06b6d 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -19,13 +19,11 @@ from text_generation_server.models.t5 import T5Sharded try: from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_santacoder import FlashSantacoder + from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded - FLASH_ATTENTION = ( - torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1 - ) + FLASH_ATTENTION = torch.cuda.is_available() except ImportError: - if int(os.environ.get("FLASH_ATTENTION", 0)) == 1: - logger.exception("Could not import Flash Attention models") + logger.exception("Could not import Flash Attention enabled models") FLASH_ATTENTION = False __all__ = [ @@ -47,6 +45,12 @@ if FLASH_ATTENTION: __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) __all__.append(FlashSantacoder) + __all__.append(FlashLlama) + __all__.append(FlashLlamaSharded) + +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \ + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \ + "or install flash attention with `cd server && make install install-flash-attention`" # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. 
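The hunk above turns `FLASH_ATTENTION` into a pure capability flag: it is true whenever the flash models import cleanly and CUDA is available, instead of requiring an opt-in environment variable. The sketch below is illustrative only and not part of the patch; it assumes the `text_generation_server` package (and, for the flash path, its CUDA kernels) is importable, and it mirrors the `get_model` dispatch shown in the next hunk.

    import torch

    try:
        # Importing the flash Llama model pulls in the custom CUDA kernels
        # (flash_attn_cuda, rotary_emb, dropout_layer_norm); an ImportError
        # means they are not installed in this environment.
        from text_generation_server.models.flash_llama import FlashLlama

        FLASH_ATTENTION = torch.cuda.is_available()
    except ImportError:
        FLASH_ATTENTION = False

    from text_generation_server.models.causal_lm import CausalLM

    # Mirrors get_model: use the flash implementation when the kernels and a GPU
    # are available, otherwise fall back to the generic CausalLM path. The
    # conditional is lazy, so FlashLlama is never touched when the import failed.
    llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
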
@@ -60,7 +64,7 @@ torch.set_grad_enabled(False) def get_model( - model_id: str, revision: Optional[str], sharded: bool, quantize: bool + model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: if "facebook/galactica" in model_id: if sharded: @@ -92,6 +96,17 @@ def get_model( neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM return neox_cls(model_id, revision, quantize=quantize) + if model_type == "llama": + if sharded: + if FLASH_ATTENTION: + return FlashLlamaSharded(model_id, revision, quantize=quantize) + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama") + ) + else: + llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM + return llama_cls(model_id, revision, quantize=quantize) + if model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cb7bbfd..8c092d6 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -34,6 +34,8 @@ class CausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] + offsets: List[Optional[int]] + token_offsets: List[Optional[int]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -64,12 +66,16 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] + offsets = [] + token_offsets = [] # Parse batch max_truncation = 0 padding_right_offset = 0 for r in pb.requests: inputs.append(r.inputs) + offsets.append(None) + token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -113,6 +119,8 @@ class CausalLMBatch(Batch): past_key_values=None, all_input_ids=all_input_ids, input_lengths=input_lengths.tolist(), + offsets=offsets, + token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=pb.size, @@ -135,6 +143,8 @@ class CausalLMBatch(Batch): # Batch attributes requests = [] input_lengths = [] + offsets = [] + token_offsets = [] all_input_ids = [] next_token_choosers = [] stopping_criterias = [] @@ -151,6 +161,8 @@ class CausalLMBatch(Batch): for i, batch in enumerate(batches): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) + offsets.extend(batch.offsets) + token_offsets.extend(batch.token_offsets) all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -264,6 +276,8 @@ class CausalLMBatch(Batch): past_key_values=past_key_values, all_input_ids=all_input_ids, input_lengths=input_lengths, + offsets=offsets, + token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -289,7 +303,7 @@ class CausalLM(Model): dtype = torch.float32 tokenizer = AutoTokenizer.from_pretrained( - model_id, revision=revision, padding_side="left" + model_id, revision=revision, padding_side="left", truncation_side="left" ) self.model = AutoModelForCausalLM.from_pretrained( model_id, @@ -350,6 +364,8 @@ class CausalLM(Model): # New values for next forward next_batch_input_lengths = [] + next_batch_offsets = [] + next_batch_token_offsets = [] next_batch_input_ids = [] next_batch_all_input_ids = [] @@ -364,6 +380,8 @@ class CausalLM(Model): iterator = zip( 
batch.requests, batch.input_lengths, + batch.offsets, + batch.token_offsets, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -374,6 +392,8 @@ class CausalLM(Model): for i, ( request, input_length, + offset, + token_offset, logits, next_token_chooser, stopping_criteria, @@ -391,8 +411,8 @@ class CausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.decode_token( - next_token_id_squeezed, + next_token_text, offset, token_offset = self.decode_token( + all_input_ids[:, 0], offset, token_offset ) # Evaluate stopping criteria @@ -423,6 +443,8 @@ class CausalLM(Model): next_batch_all_input_ids.append(all_input_ids) next_batch_size += 1 next_batch_input_lengths.append(new_input_length) + next_batch_offsets.append(offset) + next_batch_token_offsets.append(token_offset) next_batch_max_input_length = max( next_batch_max_input_length, new_input_length ) @@ -506,6 +528,8 @@ class CausalLM(Model): past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, input_lengths=next_batch_input_lengths, + offsets=next_batch_offsets, + token_offsets=next_batch_token_offsets, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py new file mode 100644 index 0000000..228529c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -0,0 +1,619 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
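+
+# Overview of the layout below: the q/k/v projections are fused into a single
+# query_key_value matrix and the MLP gate/up projections into a single
+# gate_up_proj matrix; rotary position embeddings are applied in place through
+# the flash-attn rotary kernel; RMSNorm goes through the fused
+# dropout_layer_norm kernel whenever the hidden size allows it; and the
+# TensorParallel* modules shard the corresponding weights across ranks for the
+# sharded variant.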
+ +import torch +import torch.distributed + +from torch.nn import functional as F + +from torch import nn +from transformers.activations import ACT2FN + +# Flash attention imports +import rotary_emb +import flash_attn_cuda +import dropout_layer_norm + +from flash_attn.layers.rotary import RotaryEmbedding + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +class FastLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) + + +class TensorParallelColumnLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + +class TensorParallelRowLinear(FastLinear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + reduce=True, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + self.reduce = reduce + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = super(TensorParallelRowLinear, self).forward(input) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + 
self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 + # translate for [0, self.max_id - self.min_id[ + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) + out = super().forward(input) + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + +class FlashLlamaAttention(torch.nn.Module): + def __init__( + self, + num_heads, + hidden_size, + process_group=None, + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) + self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) + else: + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + 3 * hidden_size, + bias=False, + process_group=process_group, + ) + self.o_proj = TensorParallelRowLinear( + hidden_size, + hidden_size, + 
bias=False, + process_group=process_group, + ) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(qkv_rot[:, 0]) + # flash attention + flash_attn_cuda.fwd( + qkv_rot[:, 0], + qkv_rot[:, 1], + qkv_rot[:, 2], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + query = qkv_rot[:, 0] + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class LlamaMLP(nn.Module): + def __init__(self, act, hidden_size, intermediate_size, process_group=None): + super().__init__() + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else None, + ) + ) + + if process_group is None: + # Fuse gate and up proj + self.gate_up_proj = FastLinear( + hidden_size, 2 * intermediate_size, bias=False + ) + self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) + self.intermediate_size = intermediate_size + else: + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + process_group=process_group, + ) + self.down_proj = TensorParallelRowLinear( + intermediate_size, + hidden_size, + bias=False, + process_group=process_group, + reduce=True, + ) + self.intermediate_size = self.down_proj.in_features + + self.process_group = process_group + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class FlashLlamaLayer(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + rms_norm_eps, + process_group=None, + ): + super().__init__() + + self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) + self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) + + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = 
self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class FlashLlamaModel(torch.nn.Module): + def __init__(self, config, process_group=None): + super(FlashLlamaModel, self).__init__() + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.embed_tokens = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = nn.ModuleList( + [ + FlashLlamaLayer( + config.num_attention_heads, + config.hidden_act, + config.hidden_size, + config.intermediate_size, + config.rms_norm_eps, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + + def post_load_weights(self): + if isinstance(self.embed_tokens, TensorParallelEmbedding): + self.embed_tokens.add_null_idx() + for layer in self.layers: + layer: FlashLlamaLayer + layer.self_attn.query_key_value.transpose_weight() + layer.self_attn.o_proj.transpose_weight() + layer.mlp.gate_up_proj.transpose_weight() + layer.mlp.down_proj.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.embed_tokens(input_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.layers), + len(hidden_states), + 2, + self.num_heads, + self.head_size, + ) + ) + layer_past_present_indices = None + cu_seqlens_q = None + # Decode + else: + # Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + cu_seqlens_q = torch.arange( + cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device + ) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + past_key_values[i], + layer_past_present_indices, + cu_seqlens_q, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states, past_key_values + + +class FlashLlamaForCausalLM(torch.nn.Module): + def __init__(self, config, process_group=None): + super().__init__() + + self.process_group = process_group + if self.process_group is not None: + self.world_size = self.process_group.size() + self.rank = self.process_group.rank() + else: + self.world_size = 1 + self.rank = 0 + + self.model = FlashLlamaModel(config, process_group) + + if self.model.tp_embeddings: + self.lm_head = FastLinear( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) + + def post_load_weights(self): + self.model.post_load_weights() + self.lm_head.transpose_weight() + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states, present = self.model( + 
input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + logits = self.lm_head(hidden_states) + + if self.model.tp_embeddings: + # Logits are sharded, so we need to gather them + world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + torch.distributed.all_gather(world_logits, logits, group=self.process_group) + world_logits = torch.cat(world_logits, dim=1) + + return world_logits, present + return logits, present diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bc1ac06..3801ed2 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -44,6 +44,8 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] + offsets: List[Optional[int]] + token_offsets: List[Optional[int]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -67,6 +69,8 @@ class FlashCausalLMBatch(Batch): max_seqlen = 0 input_lengths = [] + offsets = [] + token_offsets = [] all_input_ids = [] all_input_ids_tensor = [] @@ -84,6 +88,8 @@ class FlashCausalLMBatch(Batch): input_length = len(tokenized_input) max_seqlen = max(max_seqlen, input_length) input_lengths.append(input_length) + offsets.append(None) + token_offsets.append(None) all_input_ids.append(tokenized_input) tokenized_input = torch.tensor(tokenized_input, device=device) @@ -120,6 +126,8 @@ class FlashCausalLMBatch(Batch): max_seqlen=max_seqlen, past_key_values=None, input_lengths=input_lengths, + offsets=offsets, + token_offsets=token_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, @@ -132,6 +140,8 @@ class FlashCausalLMBatch(Batch): # Batch attributes requests = [] input_lengths = [] + offsets = [] + token_offsets = [] all_input_ids = [] all_input_ids_tensor = [] next_token_choosers = [] @@ -150,6 +160,8 @@ class FlashCausalLMBatch(Batch): for i, batch in enumerate(batches): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) + offsets.extend(batch.offsets) + token_offsets.extend(batch.token_offsets) all_input_ids.extend(batch.all_input_ids) all_input_ids_tensor.extend(batch.all_input_ids_tensor) next_token_choosers.extend(batch.next_token_choosers) @@ -182,6 +194,8 @@ class FlashCausalLMBatch(Batch): max_seqlen=max_seqlen, past_key_values=past_key_values, input_lengths=input_lengths, + offsets=offsets, + token_offsets=token_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_choosers=next_token_choosers, @@ -279,6 +293,8 @@ class FlashCausalLM(Model): next_batch_max_seqlen = 0 next_batch_past_key_values = [] next_batch_input_lengths = [] + next_batch_offsets = [] + next_batch_token_offsets = [] next_batch_all_input_ids = [] next_batch_all_input_ids_tensor = [] @@ -292,6 +308,8 @@ class FlashCausalLM(Model): iterator = zip( batch.requests, batch.input_lengths, + batch.offsets, + batch.token_offsets, batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, @@ -302,6 +320,8 @@ class FlashCausalLM(Model): for i, ( request, input_length, + offset, + token_offset, next_token_chooser, stopping_criteria, all_input_ids, @@ -334,8 +354,10 @@ class FlashCausalLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id_item] - next_token_text = self.decode_token( - next_token_id_item, + next_token_text, offset, token_offset = self.decode_token( + 
all_input_ids, + offset, + token_offset, ) # Evaluate stopping criteria @@ -376,6 +398,8 @@ class FlashCausalLM(Model): next_batch_cu_seqlens[-1] + new_input_length ) next_batch_input_lengths.append(new_input_length) + next_batch_offsets.append(offset) + next_batch_token_offsets.append(token_offset) next_batch_all_input_ids.append(all_input_ids) next_batch_all_input_ids_tensor.append(all_input_ids_tensor) next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) @@ -452,6 +476,8 @@ class FlashCausalLM(Model): max_seqlen=next_batch_max_seqlen, past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, + offsets=next_batch_offsets, + token_offsets=next_batch_token_offsets, all_input_ids=next_batch_all_input_ids, all_input_ids_tensor=next_batch_all_input_ids_tensor, next_token_choosers=next_batch_next_token_choosers, diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py new file mode 100644 index 0000000..063910f --- /dev/null +++ b/server/text_generation_server/models/flash_llama.py @@ -0,0 +1,303 @@ +import torch +import torch.distributed + +from accelerate import init_empty_weights +from opentelemetry import trace +from pathlib import Path +from safetensors import safe_open +from transformers import AutoConfig +from transformers.models.llama import LlamaTokenizer +from typing import Optional, List + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + download_weights, + weight_hub_files, + LocalEntryNotFoundError, +) + +tracer = trace.get_tracer(__name__) + + +class FlashLlama(FlashCausalLM): + def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + if quantize: + raise NotImplementedError("FlashLlama does not support quantization") + + tokenizer = LlamaTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + ) + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + ) + + # We do not use from_pretrained as we modified the model internal module layout + try: + filenames = weight_files(model_id, revision, ".bin") + # Local files not found + except LocalEntryNotFoundError: + hub_files = weight_hub_files(model_id, revision, ".bin") + filenames = download_weights(hub_files, model_id, revision) + + with init_empty_weights(): + model = FlashLlamaForCausalLM(config) + + self.load_weights(model, filenames, device, dtype) + self.model = model.eval() + + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[Path], + device: torch.device, + dtype: torch.dtype, + ): + for filename in filenames: + state_dict = torch.load(filename, map_location="cpu") + for key, value in state_dict.items(): + value = value.to(device).to(dtype) + + layer_name = ".".join(key.split(".")[:4]) + + # Fused qkv + if "q_proj" in key or "k_proj" in key or "v_proj" in key: + final_key = layer_name + ".query_key_value.weight" + + # Fused 
gate and up projs + elif "gate_proj" in key or "up_proj" in key: + final_key = layer_name + ".gate_up_proj.weight" + else: + final_key = key + + module_name, param_name = final_key.rsplit(".", 1) + module = model.get_submodule(module_name) + + try: + current_parameter_tensor = module._parameters[param_name] + except KeyError: + current_parameter_tensor = None + + if current_parameter_tensor is not None: + if current_parameter_tensor.device == torch.device("meta"): + # Init qkv + if "query_key_value" in final_key: + module._parameters[param_name] = value.new_empty( + (value.shape[0] * 3, value.shape[1]) + ) + # Init gate and up proj + elif "gate_up_proj" in final_key: + module._parameters[param_name] = value.new_empty( + (value.shape[0] * 2, value.shape[1]) + ) + + # Copy to correct slice + if "q_proj" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "k_proj" in key: + module._parameters[param_name][ + value.shape[0] : value.shape[0] * 2 + ] = value + elif "v_proj" in key: + module._parameters[param_name][value.shape[0] * 2 :] = value + elif "gate_proj" in key: + module._parameters[param_name][: value.shape[0]] = value + elif "up_proj" in key: + module._parameters[param_name][value.shape[0] :] = value + else: + if current_parameter_tensor.shape != value.shape: + raise ValueError( + f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" + ) + module._parameters[param_name] = value + else: + module._buffers[param_name] = value + + del value + + torch.cuda.empty_cache() + model.post_load_weights() + + +class FlashLlamaSharded(FlashLlama): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + if quantize: + raise NotImplementedError("FlashLlama does not support quantization") + + tokenizer = LlamaTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + ) + + config = AutoConfig.from_pretrained( + model_id, + revision=revision, + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashLlamaForCausalLM(config, process_group=self.process_group) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + dtype=dtype, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval() + torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + dtype: torch.dtype, + rank: int, + world_size: int, + ): + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + slice_ = f.get_slice(name) + + layer_name = ".".join(name.split(".")[:4]) + + # Fused qkv + if "q_proj" in name or "k_proj" in name or "v_proj" in name: + final_name = layer_name + ".query_key_value.weight" + + # Fused gate and up projs + elif "gate_proj" in 
name or "up_proj" in name: + final_name = layer_name + ".gate_up_proj.weight" + else: + final_name = name + + module_name, param_name = final_name.rsplit(".", 1) + module = model.get_submodule(module_name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "lm_head.weight" and model.model.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + tensor = tensor.contiguous().to(dtype) + + try: + current_parameter_tensor = module._parameters[param_name] + except KeyError: + current_parameter_tensor = None + + if current_parameter_tensor is not None: + if current_parameter_tensor.device == torch.device("meta"): + # Init qkv + if "query_key_value" in final_name: + module._parameters[param_name] = tensor.new_empty( + (tensor.shape[0] * 3, tensor.shape[1]) + ) + # Init gate and up proj + elif "gate_up_proj" in final_name: + module._parameters[param_name] = tensor.new_empty( + (tensor.shape[0] * 2, tensor.shape[1]) + ) + + # Init gate and up proj + if "q_proj" in name: + module._parameters[param_name][: tensor.shape[0]] = tensor + elif "k_proj" in name: + module._parameters[param_name][ + tensor.shape[0] : tensor.shape[0] * 2 + ] = tensor + elif "v_proj" in name: + module._parameters[param_name][ + tensor.shape[0] * 2 : + ] = tensor + elif "gate_proj" in name: + module._parameters[param_name][: tensor.shape[0]] = tensor + elif "up_proj" in name: + module._parameters[param_name][tensor.shape[0] :] = tensor + else: + if current_parameter_tensor.shape != tensor.shape: + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + module._parameters[param_name] = tensor + + else: + module._buffers[param_name] = tensor + torch.cuda.empty_cache() + model.post_load_weights() diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 3c89b4a..f7fbb2a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -93,7 +93,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): inputs = [] next_token_choosers = [] stopping_criterias = [] - input_lengths = [] + offsets = [] + token_offsets = [] # Parse batch max_truncation = 0 @@ -101,7 +102,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): for r in pb.requests: # Add escape_custom_split_sequence to the CausalLMBatch logic inputs.append(escape_custom_split_sequence(r.inputs)) - input_lengths.append(r.input_length) + offsets.append(None) + token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -146,6 +148,8 @@ class GalacticaCausalLMBatch(CausalLMBatch): past_key_values=None, 
all_input_ids=all_input_ids, input_lengths=input_lengths, + offsets=offsets, + token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=pb.size, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index e0ce668..5b82872 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,15 +15,6 @@ class Model(ABC): self.all_special_ids = set(tokenizer.all_special_ids) self.device = device - # see `decode_token` method - self.tokenizer.add_special_tokens( - {"additional_special_tokens": [""]} - ) - self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids( - "" - ) - self.special_decode_token_length = len("") - @property @abstractmethod def batch_type(self) -> Type[B]: @@ -33,11 +24,38 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError - def decode_token(self, token_id: int) -> str: + def decode_token( + self, + all_input_ids: List[int], + offset: Optional[int] = None, + token_offset: Optional[int] = None, + ) -> Tuple[str, Optional[int], Optional[int]]: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" - # append token to special decode token and decode both - result = self.tokenizer.decode( - [self.special_decode_token_id, token_id], skip_special_tokens=False + if all_input_ids[-1] in self.all_special_ids: + return ( + self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False), + None, + None, + ) + + if token_offset is None: + token_offset = len(all_input_ids) - 3 + + # Decode token_offset token minus last one and token_offset tokens + results = self.tokenizer.batch_decode( + [all_input_ids[token_offset:-1], all_input_ids[token_offset:]], + skip_special_tokens=False, ) - # slice to remove special decode token - return result[self.special_decode_token_length :] + + # default offset is only the last token + if offset is None: + offset = len(results[0]) + + # get text + text = results[1][offset:] + + # if text is utf-8 + if text and text[-1] != "�": + return text, None, None + else: + return "", offset, token_offset diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 7cf9712..13eafd6 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -38,6 +38,8 @@ class Seq2SeqLMBatch(Batch): # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] + offsets: List[Optional[int]] + token_offsets: List[Optional[int]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -71,6 +73,8 @@ class Seq2SeqLMBatch(Batch): decoder_input_ids = [] decoder_input_lengths = [] + offsets = [] + token_offsets = [] # Parse batch max_truncation = 0 @@ -80,6 +84,8 @@ class Seq2SeqLMBatch(Batch): # Decoder sequence only contains the bos_token decoder_input_ids.append(tokenizer.bos_token_id) decoder_input_lengths.append(1) + offsets.append(None) + token_offsets.append(None) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer @@ -117,6 +123,8 @@ class Seq2SeqLMBatch(Batch): past_key_values=None, input_lengths=input_lengths.tolist(), decoder_input_lengths=decoder_input_lengths, + offsets=offsets, + token_offsets=token_offsets, 
next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -147,6 +155,8 @@ class Seq2SeqLMBatch(Batch): requests = [] input_lengths = [] decoder_input_lengths = [] + offsets = [] + token_offsets = [] next_token_choosers = [] stopping_criterias = [] @@ -166,6 +176,8 @@ class Seq2SeqLMBatch(Batch): requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) + offsets.extend(batch.offsets) + token_offsets.extend(batch.token_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -303,6 +315,8 @@ class Seq2SeqLMBatch(Batch): past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + offsets=offsets, + token_offsets=token_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=total_batch_size, @@ -335,7 +349,7 @@ class Seq2SeqLM(Model): load_in_8bit=quantize, ).eval() tokenizer = AutoTokenizer.from_pretrained( - model_id, revision=revision, padding_side="left" + model_id, revision=revision, padding_side="left", truncation_side="left" ) tokenizer.bos_token_id = self.model.config.decoder_start_token_id @@ -422,6 +436,8 @@ class Seq2SeqLM(Model): # New values for next forward next_batch_input_lengths = [] + next_batch_offsets = [] + next_batch_token_offsets = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] @@ -437,6 +453,8 @@ class Seq2SeqLM(Model): iterator = zip( batch.requests, batch.input_lengths, + batch.offsets, + batch.token_offsets, batch.decoder_input_lengths, logits, batch.next_token_choosers, @@ -448,6 +466,8 @@ class Seq2SeqLM(Model): for i, ( request, input_length, + offset, + token_offset, decoder_input_length, logits, next_token_chooser, @@ -466,8 +486,8 @@ class Seq2SeqLM(Model): # Generated token next_token_logprob = logprobs[-1, next_token_id] next_token_id_squeezed = next_token_id.squeeze() - next_token_text = self.decode_token( - next_token_id_squeezed, + next_token_text, offset, token_offset = self.decode_token( + decoder_input_ids, offset, token_offset ) # Evaluate stopping criteria @@ -495,6 +515,8 @@ class Seq2SeqLM(Model): next_batch_size += 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_offsets.append(offset) + next_batch_token_offsets.append(token_offset) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -580,6 +602,8 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, + offsets=next_batch_offsets, + token_offsets=next_batch_token_offsets, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size,
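
Throughout the patch, `decode_token` now receives the full list of input ids plus the `offset`/`token_offset` pair that every batch carries, instead of a single token id. The standalone sketch below is illustrative only: it uses the `gpt2` tokenizer purely as a stand-in and re-implements the method in simplified form (the real version also short-circuits special tokens and lives on `Model`). It shows the underlying trick: decode a short trailing window of ids with and without the newest id, return only the textual difference, and hold the output back while it still ends in a replacement character, i.e. an incomplete UTF-8 sequence.

    from typing import List, Optional, Tuple

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works for the demo


    def decode_token(
        all_input_ids: List[int],
        offset: Optional[int] = None,
        token_offset: Optional[int] = None,
    ) -> Tuple[str, Optional[int], Optional[int]]:
        if token_offset is None:
            token_offset = max(len(all_input_ids) - 3, 0)
        # Decode the trailing window without and with the newest token
        without_new, with_new = tokenizer.batch_decode(
            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
            skip_special_tokens=False,
        )
        if offset is None:
            offset = len(without_new)
        text = with_new[offset:]
        if text and text[-1] != "�":
            # The new text is complete: emit it and reset the offsets
            return text, None, None
        # Incomplete byte sequence: emit nothing and carry the offsets forward
        return "", offset, token_offset


    ids = tokenizer("flash attention makes Llama decoding fast").input_ids
    offset = token_offset = None
    streamed = ""
    # The first id stands in for the prompt; only newly "generated" ids are streamed.
    for i in range(2, len(ids) + 1):
        piece, offset, token_offset = decode_token(ids[:i], offset, token_offset)
        streamed += piece
    print(streamed)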