diff --git a/README.md b/README.md index dc074d50..c6db2822 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint. - [MPT](https://huggingface.co/mosaicml/mpt-30b) - [Llama V2](https://huggingface.co/meta-llama) - [Code Llama](https://huggingface.co/codellama) +- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) Other architectures are supported on a best effort basis using: diff --git a/clients/python/README.md b/clients/python/README.md index 4e0e564c..82f3ee0c 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -140,6 +140,8 @@ class Parameters: watermark: bool # Get decoder input token logprobs and ids decoder_input_details: bool + # Return the N most likely tokens at each step + top_n_tokens: Optional[int] # Decoder input tokens class InputToken: @@ -189,6 +191,8 @@ class BestOfSequence: prefill: List[InputToken] # Generated tokens tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] # `generate` details @@ -203,6 +207,8 @@ class Details: prefill: List[InputToken] # Generated tokens tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] # Additional sequences when using the `best_of` parameter best_of_sequences: Optional[List[BestOfSequence]] @@ -229,6 +235,8 @@ class StreamDetails: class StreamResponse: # Generated token token: Token + # Most likely tokens + top_tokens: Optional[List[Token]] # Complete generated text # Only available when the generation is finished generated_text: Optional[str] diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index e038ad9b..2d4e45d2 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -124,6 +124,20 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "annotated-types" +version = "0.5.0" +description = "Reusable constraint types to use with typing.Annotated" +optional = false +python-versions = ">=3.7" +files = [ + {file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"}, + {file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} + [[package]] name = "async-timeout" version = "4.0.3" @@ -693,55 +707,140 @@ files = [ [[package]] name = "pydantic" -version = "1.10.12" -description = "Data validation and settings management using python type hints" +version = "2.4.2" +description = "Data validation using Python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a1fcb59f2f355ec350073af41d927bf83a63b50e640f4dbaa01053a28b7a7718"}, - {file = "pydantic-1.10.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b7ccf02d7eb340b216ec33e53a3a629856afe1c6e0ef91d84a4e6f2fb2ca70fe"}, - {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fb2aa3ab3728d950bcc885a2e9eff6c8fc40bc0b7bb434e555c215491bcf48b"}, - {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:771735dc43cf8383959dc9b90aa281f0b6092321ca98677c5fb6125a6f56d58d"}, - {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca48477862372ac3770969b9d75f1bf66131d386dba79506c46d75e6b48c1e09"}, - {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a5e7add47a5b5a40c49b3036d464e3c7802f8ae0d1e66035ea16aa5b7a3923ed"}, - {file = "pydantic-1.10.12-cp310-cp310-win_amd64.whl", hash = "sha256:e4129b528c6baa99a429f97ce733fff478ec955513630e61b49804b6cf9b224a"}, - {file = "pydantic-1.10.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0d191db0f92dfcb1dec210ca244fdae5cbe918c6050b342d619c09d31eea0cc"}, - {file = "pydantic-1.10.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:795e34e6cc065f8f498c89b894a3c6da294a936ee71e644e4bd44de048af1405"}, - {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69328e15cfda2c392da4e713443c7dbffa1505bc9d566e71e55abe14c97ddc62"}, - {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2031de0967c279df0d8a1c72b4ffc411ecd06bac607a212892757db7462fc494"}, - {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ba5b2e6fe6ca2b7e013398bc7d7b170e21cce322d266ffcd57cca313e54fb246"}, - {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a7bac939fa326db1ab741c9d7f44c565a1d1e80908b3797f7f81a4f86bc8d33"}, - {file = "pydantic-1.10.12-cp311-cp311-win_amd64.whl", hash = "sha256:87afda5539d5140cb8ba9e8b8c8865cb5b1463924d38490d73d3ccfd80896b3f"}, - {file = "pydantic-1.10.12-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:549a8e3d81df0a85226963611950b12d2d334f214436a19537b2efed61b7639a"}, - {file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:598da88dfa127b666852bef6d0d796573a8cf5009ffd62104094a4fe39599565"}, - {file = 
"pydantic-1.10.12-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba5c4a8552bff16c61882db58544116d021d0b31ee7c66958d14cf386a5b5350"}, - {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79e6a11a07da7374f46970410b41d5e266f7f38f6a17a9c4823db80dadf4303"}, - {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab26038b8375581dc832a63c948f261ae0aa21f1d34c1293469f135fa92972a5"}, - {file = "pydantic-1.10.12-cp37-cp37m-win_amd64.whl", hash = "sha256:e0a16d274b588767602b7646fa05af2782576a6cf1022f4ba74cbb4db66f6ca8"}, - {file = "pydantic-1.10.12-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a9dfa722316f4acf4460afdf5d41d5246a80e249c7ff475c43a3a1e9d75cf62"}, - {file = "pydantic-1.10.12-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a73f489aebd0c2121ed974054cb2759af8a9f747de120acd2c3394cf84176ccb"}, - {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b30bcb8cbfccfcf02acb8f1a261143fab622831d9c0989707e0e659f77a18e0"}, - {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fcfb5296d7877af406ba1547dfde9943b1256d8928732267e2653c26938cd9c"}, - {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2f9a6fab5f82ada41d56b0602606a5506aab165ca54e52bc4545028382ef1c5d"}, - {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dea7adcc33d5d105896401a1f37d56b47d443a2b2605ff8a969a0ed5543f7e33"}, - {file = "pydantic-1.10.12-cp38-cp38-win_amd64.whl", hash = "sha256:1eb2085c13bce1612da8537b2d90f549c8cbb05c67e8f22854e201bde5d98a47"}, - {file = "pydantic-1.10.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef6c96b2baa2100ec91a4b428f80d8f28a3c9e53568219b6c298c1125572ebc6"}, - {file = "pydantic-1.10.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c076be61cd0177a8433c0adcb03475baf4ee91edf5a4e550161ad57fc90f523"}, - {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d5a58feb9a39f481eda4d5ca220aa8b9d4f21a41274760b9bc66bfd72595b86"}, - {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5f805d2d5d0a41633651a73fa4ecdd0b3d7a49de4ec3fadf062fe16501ddbf1"}, - {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1289c180abd4bd4555bb927c42ee42abc3aee02b0fb2d1223fb7c6e5bef87dbe"}, - {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5d1197e462e0364906cbc19681605cb7c036f2475c899b6f296104ad42b9f5fb"}, - {file = "pydantic-1.10.12-cp39-cp39-win_amd64.whl", hash = "sha256:fdbdd1d630195689f325c9ef1a12900524dceb503b00a987663ff4f58669b93d"}, - {file = "pydantic-1.10.12-py3-none-any.whl", hash = "sha256:b749a43aa51e32839c9d71dc67eb1e4221bb04af1033a32e3923d46f9effa942"}, - {file = "pydantic-1.10.12.tar.gz", hash = "sha256:0fe8a415cea8f340e7a9af9c54fc71a649b43e8ca3cc732986116b3cb135d303"}, + {file = "pydantic-2.4.2-py3-none-any.whl", hash = "sha256:bc3ddf669d234f4220e6e1c4d96b061abe0998185a8d7855c0126782b7abc8c1"}, + {file = "pydantic-2.4.2.tar.gz", hash = "sha256:94f336138093a5d7f426aac732dcfe7ab4eb4da243c88f891d65deb4a2556ee7"}, ] [package.dependencies] -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +pydantic-core = "2.10.1" +typing-extensions = ">=4.6.1" [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +email = 
["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.10.1" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.10.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:d64728ee14e667ba27c66314b7d880b8eeb050e58ffc5fec3b7a109f8cddbd63"}, + {file = "pydantic_core-2.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:48525933fea744a3e7464c19bfede85df4aba79ce90c60b94d8b6e1eddd67096"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef337945bbd76cce390d1b2496ccf9f90b1c1242a3a7bc242ca4a9fc5993427a"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1392e0638af203cee360495fd2cfdd6054711f2db5175b6e9c3c461b76f5175"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0675ba5d22de54d07bccde38997e780044dcfa9a71aac9fd7d4d7a1d2e3e65f7"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:128552af70a64660f21cb0eb4876cbdadf1a1f9d5de820fed6421fa8de07c893"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f6e6aed5818c264412ac0598b581a002a9f050cb2637a84979859e70197aa9e"}, + {file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ecaac27da855b8d73f92123e5f03612b04c5632fd0a476e469dfc47cd37d6b2e"}, + {file = "pydantic_core-2.10.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b3c01c2fb081fced3bbb3da78510693dc7121bb893a1f0f5f4b48013201f362e"}, + {file = "pydantic_core-2.10.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:92f675fefa977625105708492850bcbc1182bfc3e997f8eecb866d1927c98ae6"}, + {file = "pydantic_core-2.10.1-cp310-none-win32.whl", hash = "sha256:420a692b547736a8d8703c39ea935ab5d8f0d2573f8f123b0a294e49a73f214b"}, + {file = "pydantic_core-2.10.1-cp310-none-win_amd64.whl", hash = "sha256:0880e239827b4b5b3e2ce05e6b766a7414e5f5aedc4523be6b68cfbc7f61c5d0"}, + {file = "pydantic_core-2.10.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:073d4a470b195d2b2245d0343569aac7e979d3a0dcce6c7d2af6d8a920ad0bea"}, + {file = "pydantic_core-2.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:600d04a7b342363058b9190d4e929a8e2e715c5682a70cc37d5ded1e0dd370b4"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39215d809470f4c8d1881758575b2abfb80174a9e8daf8f33b1d4379357e417c"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eeb3d3d6b399ffe55f9a04e09e635554012f1980696d6b0aca3e6cf42a17a03b"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a7902bf75779bc12ccfc508bfb7a4c47063f748ea3de87135d433a4cca7a2f"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3625578b6010c65964d177626fde80cf60d7f2e297d56b925cb5cdeda6e9925a"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caa48fc31fc7243e50188197b5f0c4228956f97b954f76da157aae7f67269ae8"}, + {file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:07ec6d7d929ae9c68f716195ce15e745b3e8fa122fc67698ac6498d802ed0fa4"}, + {file = "pydantic_core-2.10.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:e6f31a17acede6a8cd1ae2d123ce04d8cca74056c9d456075f4f6f85de055607"}, + {file = "pydantic_core-2.10.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d8f1ebca515a03e5654f88411420fea6380fc841d1bea08effb28184e3d4899f"}, + {file = "pydantic_core-2.10.1-cp311-none-win32.whl", hash = "sha256:6db2eb9654a85ada248afa5a6db5ff1cf0f7b16043a6b070adc4a5be68c716d6"}, + {file = "pydantic_core-2.10.1-cp311-none-win_amd64.whl", hash = "sha256:4a5be350f922430997f240d25f8219f93b0c81e15f7b30b868b2fddfc2d05f27"}, + {file = "pydantic_core-2.10.1-cp311-none-win_arm64.whl", hash = "sha256:5fdb39f67c779b183b0c853cd6b45f7db84b84e0571b3ef1c89cdb1dfc367325"}, + {file = "pydantic_core-2.10.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:b1f22a9ab44de5f082216270552aa54259db20189e68fc12484873d926426921"}, + {file = "pydantic_core-2.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8572cadbf4cfa95fb4187775b5ade2eaa93511f07947b38f4cd67cf10783b118"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db9a28c063c7c00844ae42a80203eb6d2d6bbb97070cfa00194dff40e6f545ab"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e2a35baa428181cb2270a15864ec6286822d3576f2ed0f4cd7f0c1708472aff"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05560ab976012bf40f25d5225a58bfa649bb897b87192a36c6fef1ab132540d7"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6495008733c7521a89422d7a68efa0a0122c99a5861f06020ef5b1f51f9ba7c"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14ac492c686defc8e6133e3a2d9eaf5261b3df26b8ae97450c1647286750b901"}, + {file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8282bab177a9a3081fd3d0a0175a07a1e2bfb7fcbbd949519ea0980f8a07144d"}, + {file = "pydantic_core-2.10.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:aafdb89fdeb5fe165043896817eccd6434aee124d5ee9b354f92cd574ba5e78f"}, + {file = "pydantic_core-2.10.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f6defd966ca3b187ec6c366604e9296f585021d922e666b99c47e78738b5666c"}, + {file = "pydantic_core-2.10.1-cp312-none-win32.whl", hash = "sha256:7c4d1894fe112b0864c1fa75dffa045720a194b227bed12f4be7f6045b25209f"}, + {file = "pydantic_core-2.10.1-cp312-none-win_amd64.whl", hash = "sha256:5994985da903d0b8a08e4935c46ed8daf5be1cf217489e673910951dc533d430"}, + {file = "pydantic_core-2.10.1-cp312-none-win_arm64.whl", hash = "sha256:0d8a8adef23d86d8eceed3e32e9cca8879c7481c183f84ed1a8edc7df073af94"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:9badf8d45171d92387410b04639d73811b785b5161ecadabf056ea14d62d4ede"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:ebedb45b9feb7258fac0a268a3f6bec0a2ea4d9558f3d6f813f02ff3a6dc6698"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfe1090245c078720d250d19cb05d67e21a9cd7c257698ef139bc41cf6c27b4f"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e357571bb0efd65fd55f18db0a2fb0ed89d0bb1d41d906b138f088933ae618bb"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b3dcd587b69bbf54fc04ca157c2323b8911033e827fffaecf0cafa5a892a0904"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c120c9ce3b163b985a3b966bb701114beb1da4b0468b9b236fc754783d85aa3"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15d6bca84ffc966cc9976b09a18cf9543ed4d4ecbd97e7086f9ce9327ea48891"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5cabb9710f09d5d2e9e2748c3e3e20d991a4c5f96ed8f1132518f54ab2967221"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:82f55187a5bebae7d81d35b1e9aaea5e169d44819789837cdd4720d768c55d15"}, + {file = "pydantic_core-2.10.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1d40f55222b233e98e3921df7811c27567f0e1a4411b93d4c5c0f4ce131bc42f"}, + {file = "pydantic_core-2.10.1-cp37-none-win32.whl", hash = "sha256:14e09ff0b8fe6e46b93d36a878f6e4a3a98ba5303c76bb8e716f4878a3bee92c"}, + {file = "pydantic_core-2.10.1-cp37-none-win_amd64.whl", hash = "sha256:1396e81b83516b9d5c9e26a924fa69164156c148c717131f54f586485ac3c15e"}, + {file = "pydantic_core-2.10.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:6835451b57c1b467b95ffb03a38bb75b52fb4dc2762bb1d9dbed8de31ea7d0fc"}, + {file = "pydantic_core-2.10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b00bc4619f60c853556b35f83731bd817f989cba3e97dc792bb8c97941b8053a"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fa467fd300a6f046bdb248d40cd015b21b7576c168a6bb20aa22e595c8ffcdd"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d99277877daf2efe074eae6338453a4ed54a2d93fb4678ddfe1209a0c93a2468"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa7db7558607afeccb33c0e4bf1c9a9a835e26599e76af6fe2fcea45904083a6"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aad7bd686363d1ce4ee930ad39f14e1673248373f4a9d74d2b9554f06199fb58"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:443fed67d33aa85357464f297e3d26e570267d1af6fef1c21ca50921d2976302"}, + {file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:042462d8d6ba707fd3ce9649e7bf268633a41018d6a998fb5fbacb7e928a183e"}, + {file = "pydantic_core-2.10.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ecdbde46235f3d560b18be0cb706c8e8ad1b965e5c13bbba7450c86064e96561"}, + {file = "pydantic_core-2.10.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ed550ed05540c03f0e69e6d74ad58d026de61b9eaebebbaaf8873e585cbb18de"}, + {file = "pydantic_core-2.10.1-cp38-none-win32.whl", hash = "sha256:8cdbbd92154db2fec4ec973d45c565e767ddc20aa6dbaf50142676484cbff8ee"}, + {file = "pydantic_core-2.10.1-cp38-none-win_amd64.whl", hash = "sha256:9f6f3e2598604956480f6c8aa24a3384dbf6509fe995d97f6ca6103bb8c2534e"}, + {file = "pydantic_core-2.10.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:655f8f4c8d6a5963c9a0687793da37b9b681d9ad06f29438a3b2326d4e6b7970"}, + {file = "pydantic_core-2.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e570ffeb2170e116a5b17e83f19911020ac79d19c96f320cbfa1fa96b470185b"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:64322bfa13e44c6c30c518729ef08fda6026b96d5c0be724b3c4ae4da939f875"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:485a91abe3a07c3a8d1e082ba29254eea3e2bb13cbbd4351ea4e5a21912cc9b0"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7c2b8eb9fc872e68b46eeaf835e86bccc3a58ba57d0eedc109cbb14177be531"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a5cb87bdc2e5f620693148b5f8f842d293cae46c5f15a1b1bf7ceeed324a740c"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25bd966103890ccfa028841a8f30cebcf5875eeac8c4bde4fe221364c92f0c9a"}, + {file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f323306d0556351735b54acbf82904fe30a27b6a7147153cbe6e19aaaa2aa429"}, + {file = "pydantic_core-2.10.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0c27f38dc4fbf07b358b2bc90edf35e82d1703e22ff2efa4af4ad5de1b3833e7"}, + {file = "pydantic_core-2.10.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f1365e032a477c1430cfe0cf2856679529a2331426f8081172c4a74186f1d595"}, + {file = "pydantic_core-2.10.1-cp39-none-win32.whl", hash = "sha256:a1c311fd06ab3b10805abb72109f01a134019739bd3286b8ae1bc2fc4e50c07a"}, + {file = "pydantic_core-2.10.1-cp39-none-win_amd64.whl", hash = "sha256:ae8a8843b11dc0b03b57b52793e391f0122e740de3df1474814c700d2622950a"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d43002441932f9a9ea5d6f9efaa2e21458221a3a4b417a14027a1d530201ef1b"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fcb83175cc4936a5425dde3356f079ae03c0802bbdf8ff82c035f8a54b333521"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:962ed72424bf1f72334e2f1e61b68f16c0e596f024ca7ac5daf229f7c26e4208"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cf5bb4dd67f20f3bbc1209ef572a259027c49e5ff694fa56bed62959b41e1f9"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e544246b859f17373bed915182ab841b80849ed9cf23f1f07b73b7c58baee5fb"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c0877239307b7e69d025b73774e88e86ce82f6ba6adf98f41069d5b0b78bd1bf"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53df009d1e1ba40f696f8995683e067e3967101d4bb4ea6f667931b7d4a01357"}, + {file = "pydantic_core-2.10.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a1254357f7e4c82e77c348dabf2d55f1d14d19d91ff025004775e70a6ef40ada"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:524ff0ca3baea164d6d93a32c58ac79eca9f6cf713586fdc0adb66a8cdeab96a"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f0ac9fb8608dbc6eaf17956bf623c9119b4db7dbb511650910a82e261e6600f"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:320f14bd4542a04ab23747ff2c8a778bde727158b606e2661349557f0770711e"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63974d168b6233b4ed6a0046296803cb13c56637a7b8106564ab575926572a55"}, + 
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:417243bf599ba1f1fef2bb8c543ceb918676954734e2dcb82bf162ae9d7bd514"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dda81e5ec82485155a19d9624cfcca9be88a405e2857354e5b089c2a982144b2"}, + {file = "pydantic_core-2.10.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:14cfbb00959259e15d684505263d5a21732b31248a5dd4941f73a3be233865b9"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:631cb7415225954fdcc2a024119101946793e5923f6c4d73a5914d27eb3d3a05"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:bec7dd208a4182e99c5b6c501ce0b1f49de2802448d4056091f8e630b28e9a52"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:149b8a07712f45b332faee1a2258d8ef1fb4a36f88c0c17cb687f205c5dc6e7d"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d966c47f9dd73c2d32a809d2be529112d509321c5310ebf54076812e6ecd884"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7eb037106f5c6b3b0b864ad226b0b7ab58157124161d48e4b30c4a43fef8bc4b"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:154ea7c52e32dce13065dbb20a4a6f0cc012b4f667ac90d648d36b12007fa9f7"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e562617a45b5a9da5be4abe72b971d4f00bf8555eb29bb91ec2ef2be348cd132"}, + {file = "pydantic_core-2.10.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:f23b55eb5464468f9e0e9a9935ce3ed2a870608d5f534025cd5536bca25b1402"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:e9121b4009339b0f751955baf4543a0bfd6bc3f8188f8056b1a25a2d45099934"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:0523aeb76e03f753b58be33b26540880bac5aa54422e4462404c432230543f33"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e0e2959ef5d5b8dc9ef21e1a305a21a36e254e6a34432d00c72a92fdc5ecda5"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da01bec0a26befab4898ed83b362993c844b9a607a86add78604186297eb047e"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f2e9072d71c1f6cfc79a36d4484c82823c560e6f5599c43c1ca6b5cdbd54f881"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f36a3489d9e28fe4b67be9992a23029c3cec0babc3bd9afb39f49844a8c721c5"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f64f82cc3443149292b32387086d02a6c7fb39b8781563e0ca7b8d7d9cf72bd7"}, + {file = "pydantic_core-2.10.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b4a6db486ac8e99ae696e09efc8b2b9fea67b63c8f88ba7a1a16c24a057a0776"}, + {file = "pydantic_core-2.10.1.tar.gz", hash = "sha256:0f8682dbdd2f67f8e1edddcbffcc29f60a6182b4901c367fc8c1c40d30bb0a82"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pytest" @@ -816,6 +915,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = 
"PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -823,8 +923,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -841,6 +948,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = 
"PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -848,6 +956,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -929,13 +1038,13 @@ files = [ [[package]] name = "urllib3" -version = "2.0.4" +version = "2.0.5" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.7" files = [ - {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, - {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, + {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"}, + {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"}, ] [package.extras] @@ -1050,4 +1159,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "0db2f97d52c557dd7f90c55b4ad5bbe308c957c5f7f99fec53c57e0a13822cb4" +content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 915ac7aa..4fe6e8b0 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "text-generation" -version = "0.6.0" +version = "0.6.1" description = "Hugging Face Text Generation Python Client" license = "Apache-2.0" authors = ["Olivier Dehaene "] diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index ff7f66a3..0bf80f8c 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -482,7 +482,6 @@ class AsyncClient: headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: async with session.post(self.base_url, json=request.dict()) as resp: - if resp.status != 200: raise parse_error(resp.status, await resp.json()) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 6d6a0536..aa02d8d8 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -40,7 +40,7 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False # Return the N most likely tokens at each step - top_n_tokens: Optional[int] + top_n_tokens: Optional[int] = None @validator("best_of") def valid_best_of(cls, field_value, values): @@ -188,7 +188,7 @@ class BestOfSequence(BaseModel): # Generated tokens tokens: List[Token] # Most likely tokens - top_tokens: Optional[List[List[Token]]] + top_tokens: Optional[List[List[Token]]] = None # `generate` details @@ -204,7 +204,7 @@ class Details(BaseModel): # Generated tokens tokens: List[Token] # Most likely tokens - top_tokens: Optional[List[List[Token]]] + top_tokens: Optional[List[List[Token]]] = None # Additional sequences when using the `best_of` parameter best_of_sequences: Optional[List[BestOfSequence]] = None @@ -232,7 +232,7 @@ class StreamResponse(BaseModel): # Generated token token: Token # Most likely tokens - top_tokens: Optional[List[Token]] + top_tokens: Optional[List[Token]] = None # Complete generated text # Only available when the generation is finished generated_text: Optional[str] = None diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index c689b550..eb34c1f6 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -34,10 +34,17 @@ Options: [env: NUM_SHARD=] --quantize - Whether you want the model to be quantized. This will use `bitsandbytes` for quantization on the fly, or `gptq`. 
4bit quantization is available through `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options + Whether you want the model to be quantized [env: QUANTIZE=] - [possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq, awq] + + Possible values: + - awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of its better latency + - eetq: 8 bit quantization, doesn't require a specific model. Should be a drop-in replacement for bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git + - gptq: 4 bit quantization. Requires a specific GPTQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use the exllama (faster) kernels wherever possible, and the triton kernel (wider support) when it is not. AWQ has faster kernels + - bitsandbytes: Bitsandbytes 8bit. Can be applied to any model and will cut the memory requirement in half, but the model will run much slower than in native f16 + - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied to any model and will cut the memory requirement by 4x, but the model will run much slower than in native f16 + - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases, but this one may give better perplexity for your model --dtype The dtype to be forced upon the model. This option cannot be used with `--quantize` diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index c5768d9a..5d645759 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom - [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b) - [MPT](https://huggingface.co/mosaicml/mpt-30b) - [Llama V2](https://huggingface.co/meta-llama) +- [Code Llama](https://huggingface.co/codellama) +- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyway to see how well it performs, but performance isn't guaranteed for non-optimized models: diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json new file mode 100644 index 00000000..4e7de9a6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.54785156, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4091797, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94433594, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.81347656, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2958984, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": 
-2.0644531, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9580078, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5073242, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1816406, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json new file mode 100644 index 00000000..c0dc6471 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 28747, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -0.1307373, + "special": false, + "text": " Let" + }, + { + "id": 332, + "logprob": -2.3359375, + "special": false, + "text": " u" + }, + { + "id": 347, + "logprob": 0.0, + "special": false, + "text": " be" + }, + { + "id": 325, + "logprob": -1.0234375, + "special": false, + "text": " (" + }, + { + "id": 28734, + "logprob": -2.0292969, + "special": false, + "text": "0" + }, + { + "id": 648, + "logprob": -1.0439453, + "special": false, + "text": " +" + }, + { + "id": 28705, + "logprob": -0.24499512, + "special": false, + "text": " " + }, + { + "id": 28770, + "logprob": -0.5073242, + "special": false, + "text": "3" + }, + { + "id": 387, + "logprob": -1.5507812, + "special": false, + "text": " -" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Let u be (0 + 3 -" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json new file mode 100644 index 00000000..9d133077 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + 
"generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.54785156, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4111328, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0292969, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94433594, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8178711, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2939453, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0644531, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9550781, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1796875, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + "logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -12.9140625, + "text": "Test" + }, + { + "id": 2159, + "logprob": -10.7578125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 28747, + "logprob": -0.55078125, + "special": false, + "text": ":" + }, + { + "id": 3169, + "logprob": -1.4140625, + "special": false, + "text": " Let" + }, + { + "id": 307, + "logprob": -3.0273438, + "special": false, + "text": " n" + }, + { + "id": 327, + "logprob": -0.94140625, + "special": false, + "text": " =" + }, + { + "id": 28705, + "logprob": -0.8173828, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.2978516, + "special": false, + "text": "1" + }, + { + "id": 28734, + "logprob": -2.0664062, + "special": false, + "text": "0" + }, + { + "id": 387, + "logprob": -1.9560547, + "special": false, + "text": " -" + }, + { + "id": 28705, + 
"logprob": -0.5078125, + "special": false, + "text": " " + }, + { + "id": 28740, + "logprob": -1.1787109, + "special": false, + "text": "1" + } + ], + "top_tokens": null + }, + "generated_text": ": Let n = 10 - 1" + } +] diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py new file mode 100644 index 00000000..63cb09b5 --- /dev/null +++ b/integration-tests/models/test_flash_mistral.py @@ -0,0 +1,60 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_mistral_handle(launcher): + with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_mistral(flash_mistral_handle): + await flash_mistral_handle.health(300) + return flash_mistral_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_mistral(flash_mistral, response_snapshot): + response = await flash_mistral.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_mistral_all_params(flash_mistral, response_snapshot): + response = await flash_mistral.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot): + responses = await generate_load( + flash_mistral, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/proto/generate.proto b/proto/generate.proto index 3f607dc5..c873e661 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -31,6 +31,7 @@ message InfoResponse { bool requires_padding = 1; string dtype = 2; string device_type = 3; + optional uint32 window_size = 4; } /// Empty request diff --git a/router/src/infer.rs b/router/src/infer.rs index 67b5bde2..787ccfcf 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -50,10 +50,11 @@ impl Infer { max_waiting_tokens: usize, max_concurrent_requests: usize, requires_padding: bool, + window_size: Option, generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding, 16); + let queue = Queue::new(requires_padding, 16, window_size); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/queue.rs b/router/src/queue.rs index e97a168e..1ab9eb11 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -2,6 +2,7 @@ use crate::infer::InferError; use crate::infer::InferStreamResponse; use crate::validation::ValidGenerateRequest; use nohash_hasher::{BuildNoHashHasher, IntMap}; +use std::cmp::min; use std::collections::VecDeque; use text_generation_client::{Batch, Request}; use tokio::sync::oneshot; @@ -33,12 +34,17 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self { + pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option) -> Self { // Create channel let (queue_sender, queue_receiver) = 
flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(requires_padding, block_size, queue_receiver)); + tokio::spawn(queue_task( + requires_padding, + block_size, + window_size, + queue_receiver, + )); Self { queue_sender } } @@ -84,9 +90,10 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + window_size: Option, receiver: flume::Receiver, ) { - let mut state = State::new(requires_padding, block_size); + let mut state = State::new(requires_padding, block_size, window_size); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -126,16 +133,20 @@ struct State { /// Paged Attention block size block_size: u32, + + /// Sliding window + window_size: Option, } impl State { - fn new(requires_padding: bool, block_size: u32) -> Self { + fn new(requires_padding: bool, block_size: u32, window_size: Option) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, block_size, + window_size, } } @@ -204,11 +215,17 @@ impl State { if self.requires_padding { decode_tokens += entry.request.stopping_parameters.max_new_tokens; } else { + let max_new_tokens = match self.window_size { + None => entry.request.stopping_parameters.max_new_tokens, + Some(window_size) => min( + window_size.saturating_sub(entry.request.input_length), + entry.request.stopping_parameters.max_new_tokens, + ), + }; + // pad to block size decode_tokens += - ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1) - / self.block_size) - * self.block_size; + ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size; } if prefill_tokens > prefill_token_budget @@ -342,7 +359,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false, 1); + let mut state = State::new(false, 1, None); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -358,7 +375,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false, 1); + let mut state = State::new(false, 1, None); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -366,7 +383,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false, 1); + let mut state = State::new(false, 1, None); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -398,7 +415,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false, 1); + let mut state = State::new(false, 1, None); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -431,14 +448,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1); + let queue = Queue::new(false, 1, None); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1); + let queue = Queue::new(false, 1, None); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -446,7 +463,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1); + let queue = Queue::new(false, 1, None); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -479,7 +496,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1); + let 
queue = Queue::new(false, 1, None); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -504,7 +521,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1); + let queue = Queue::new(false, 1, None); let (entry, _) = default_entry(); queue.append(entry); diff --git a/router/src/server.rs b/router/src/server.rs index fbc444fc..f254afd8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -595,6 +595,7 @@ pub async fn run( max_waiting_tokens, max_concurrent_requests, shard_info.requires_padding, + shard_info.window_size, generation_health, ); diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 index a7d63356..cdea8431 100644 --- a/server/Makefile-flash-att-v2 +++ b/server/Makefile-flash-att-v2 @@ -1,4 +1,4 @@ -flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc +flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c flash-attention-v2: # Clone flash attention diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 96bfc108..2e965da0 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,4 +1,4 @@ -vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365 +vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78 vllm: # Clone vllm diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 96dd1ed1..dca3612f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -67,6 +67,16 @@ if FLASH_ATTENTION: __all__.append(FlashLlama) __all__.append(IDEFICSSharded) +MISTRAL = True +try: + from text_generation_server.models.flash_mistral import FlashMistral +except ImportError as e: + logger.warning(f"Could not import Mistral model: {e}") + MISTRAL = False + +if MISTRAL: + __all__.append(FlashMistral) + def get_model( model_id: str, @@ -237,7 +247,18 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "opt": + if model_type == "mistral": + if MISTRAL: + return FlashMistral( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError("Mistral model requires flash attention v2") + + if model_type == "opt": return OPTSharded( model_id, revision, @@ -246,7 +267,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "t5": + if model_type == "t5": return T5Sharded( model_id, revision, @@ -254,7 +275,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "idefics": + if model_type == "idefics": if FLASH_ATTENTION: return IDEFICSSharded( model_id, diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py new file mode 100644 index 00000000..2e6ae086 --- /dev/null +++ b/server/text_generation_server/models/cache_manager.py @@ -0,0 +1,135 @@ +import math +import torch + +from typing import Optional, List, Tuple + +BLOCK_SIZE: int = 16 +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.repeat_slots = repeat_slots + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // 
element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate( + self, + needed_blocks_slots: List[Tuple[int, int]], + blocks: int, + max_blocks: int, + device: torch.device, + ): + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= blocks + ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[:blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks : cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + all_slots = self.slots[allocated_blocks].flatten() + + # Repeat slots in the case of context sliding window + if needed_slots > len(all_slots) and self.repeat_slots: + repeats = math.ceil(needed_slots / len(all_slots)) + all_slots = all_slots.repeat(repeats) + + allocated_slots = all_slots[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + block_tables = block_tables + block_tables_tensor = block_tables_tensor.to(device) + slots = torch.concat(slots).to(device) + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor, slots + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def set_cache_manager( + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, +) -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is not None: + del CACHE_MANAGER + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device + ) + return CACHE_MANAGER + + +def get_cache_manager() -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is None: + raise RuntimeError("cache manager was not initialized") + + return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py new file mode 100644 index 00000000..77b7f230 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -0,0 +1,532 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm + +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + +from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + +if not HAS_FLASH_ATTN_V2: + raise ImportError("Mistral model requires flash attn v2") + + +class MistralConfig(PretrainedConfig): + model_type = "mistral" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MistralRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + 
self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class MistralAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else 0 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + 
self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + vllm_cache_ops.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + window_size_left=self.max_past, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class MistralMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class MistralLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = MistralAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = MistralRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class 
MistralModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + MistralLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = MistralRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashMistralForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = MistralModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + self.max_past = config.sliding_window + if self.max_past is None: + raise ValueError("max_past cannot be None") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + else: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(self.max_past, max_s) + input_lengths = torch.clamp(input_lengths, max=self.max_past) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 34c7f633..1fe40c0c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -19,99 +19,17 @@ from text_generation_server.models.types import ( GeneratedText, 
TopTokens, ) +from text_generation_server.models.cache_manager import ( + get_cache_manager, + set_cache_manager, + BLOCK_SIZE, +) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) -BLOCK_SIZE = 16 -# Will be set in warmup -CACHE_MANAGER: Optional["CacheManager"] = None - - -class CacheManager: - def __init__( - self, - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - ): - self.block_size = BLOCK_SIZE - self.num_blocks = num_blocks - - element_size = torch.tensor([], dtype=dtype).element_size() - x = self.block_size // element_size - - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, self.block_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") - self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int32 - ).view(num_blocks, self.block_size) - - def allocate(self, batch: "FlashCausalLMBatch"): - # Get free blocks indices by finding values in mask that are not set to 0 - free_block_indices = self.free_block_mask.nonzero() - assert ( - len(free_block_indices) >= batch.blocks - ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks" - - # Slice by the number of required blocks - block_indices = free_block_indices[: batch.blocks] - block_indices = block_indices.flatten() - - # Padded block tables - block_tables_tensor = torch.zeros( - (len(batch), batch.max_blocks), dtype=torch.int32 - ) - - # Allocate paged attention blocks - cumulative_blocks = 0 - slots = [] - block_tables = [] - for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots): - # Get allocated blocks for this sequence - allocated_blocks = block_indices[ - cumulative_blocks : cumulative_blocks + needed_blocks - ] - # Get slots for the allocated blocks - allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] - - slots.append(allocated_slots) - block_tables.append(allocated_blocks.tolist()) - block_tables_tensor[i, :needed_blocks] = allocated_blocks - cumulative_blocks += needed_blocks - - batch.needed_blocks_slots = None - batch.block_tables = block_tables - batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device) - batch.slots = torch.concat(slots).to(batch.input_ids.device) - - # Allocate the required number of blocks by setting the mask to 0 - self.free_block_mask[block_indices] = 0 - - def free(self, block_indices: Optional[List[int]]): - if block_indices is not None and block_indices: - # Reset mask - self.free_block_mask[block_indices] = 1 - @dataclass class FlashCausalLMBatch(Batch): @@ -481,7 +399,6 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) - global CACHE_MANAGER block_indices_to_free = [] # Iterate on all requests for i, r in enumerate(self.requests): @@ -489,7 +406,7 @@ class FlashCausalLMBatch(Batch): if r.id not in requests_idx_mapping.keys(): block_indices_to_free.extend(self.block_tables[i]) # Free blocks - CACHE_MANAGER.free(block_indices_to_free) + get_cache_manager().free(block_indices_to_free) # Needed to avoid dropping blocks when the batches 
will go out of scope self.block_tables = None @@ -508,7 +425,7 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) - return FlashCausalLMBatch( + return type(self)( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -665,7 +582,7 @@ class FlashCausalLMBatch(Batch): b.block_tables = None del b - return FlashCausalLMBatch( + return cls( batch_id=batches[0].batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -698,9 +615,10 @@ class FlashCausalLMBatch(Batch): def __del__(self): if self.block_tables is not None and self.block_tables: - global CACHE_MANAGER # Free blocks - CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + get_cache_manager().free( + list(itertools.chain.from_iterable(self.block_tables)) + ) def __len__(self): return len(self.requests) @@ -718,6 +636,7 @@ class FlashCausalLM(Model): device: torch.device, rank: int = 0, world_size: int = 1, + sliding_window: Optional[int] = None, ): self.num_layers = num_layers self.num_kv_heads = num_kv_heads @@ -731,6 +650,7 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, + sliding_window=sliding_window, ) @property @@ -738,15 +658,14 @@ class FlashCausalLM(Model): return FlashCausalLMBatch def warmup(self, batch: FlashCausalLMBatch): - global CACHE_MANAGER - torch.cuda.empty_cache() try: - CACHE_MANAGER = CacheManager( + cache_manager = set_cache_manager( batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, + self.sliding_window is not None, self.dtype, self.device, ) @@ -775,48 +694,36 @@ class FlashCausalLM(Model): num_blocks = ( int(free_memory // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. 
- + CACHE_MANAGER.num_blocks + + cache_manager.num_blocks ) - del CACHE_MANAGER del batch - torch.cuda.empty_cache() + del cache_manager - CACHE_MANAGER = CacheManager( + set_cache_manager( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, + self.sliding_window is not None, self.dtype, self.device, ) return int(num_blocks * BLOCK_SIZE) - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - lm_head_indices: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - global CACHE_MANAGER - + def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=CACHE_MANAGER.kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - lm_head_indices=lm_head_indices, + input_ids=batch.input_ids, + position_ids=batch.position_ids, + cu_seqlen_prefill=batch.cu_seqlen_prefill, + kv_cache=get_cache_manager().kv_cache, + block_tables=batch.block_tables_tensor, + slots=batch.slots[batch.slot_indices], + input_lengths=batch.input_lengths_tensor, + max_s=batch.max_seqlen, + lm_head_indices=batch.prefill_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -828,19 +735,19 @@ class FlashCausalLM(Model): if batch.needed_blocks_slots: # Allocate blocks to this batch - CACHE_MANAGER.allocate(batch) + block_tables, block_tables_tensor, slots = get_cache_manager().allocate( + batch.needed_blocks_slots, + batch.blocks, + batch.max_blocks, + batch.input_ids.device, + ) + batch.needed_blocks_slots = None + batch.block_tables = block_tables + batch.block_tables_tensor = block_tables_tensor + batch.slots = slots try: - out = self.forward( - batch.input_ids, - batch.position_ids, - batch.cu_seqlen_prefill, - batch.block_tables_tensor, - batch.slots[batch.slot_indices], - batch.input_lengths_tensor, - batch.max_seqlen, - batch.prefill_head_indices, - ) + out = self.forward(batch) except Exception as e: del batch raise e diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py new file mode 100644 index 00000000..919e4625 --- /dev/null +++ b/server/text_generation_server/models/flash_mistral.py @@ -0,0 +1,357 @@ +import math +import torch +import torch.distributed + +import numpy as np + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import PreTrainedTokenizerBase +from transformers.models.llama import LlamaTokenizerFast +from typing import Optional, Tuple, Type + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE +from text_generation_server.models.cache_manager import ( + get_cache_manager, + set_cache_manager, +) +from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + MistralConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + HeterogeneousNextTokenChooser, + StoppingCriteria, +) + +tracer = trace.get_tracer(__name__) + +# Will be set in init +SLIDING_WINDOW: Optional[int] = None +SLIDING_WINDOW_BLOCKS: Optional[int] = None + + +# Adds 
windowing logic to FlashCausalLMBatch +@dataclass +class FlashMistralBatch(FlashCausalLMBatch): + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] = None + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + + batch_inputs = [] + max_truncation = 0 + for r in pb.requests: + batch_inputs.append(r.inputs) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, truncation=True, max_length=max_truncation + )["input_ids"] + + position_ids = [] + cu_seqlen_prefill = [0] + needed_blocks_slots = [] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + requests_idx_mapping = {} + + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + next_token_chooser_parameters = [] + stopping_criterias = [] + top_n_tokens = [] + + # Cumulative length + cumulative_length = 0 + cumulative_max_length = 0 + prefill_out_cumulative_length = 0 + + blocks = 0 + max_seqlen = 0 + max_length = 0 + max_blocks = 0 + + # Parse batch + for i, (r, tokenized_input) in enumerate( + zip(pb.requests, batch_tokenized_inputs) + ): + # request id -> idx in list mapping + requests_idx_mapping[r.id] = i + + tokenized_input = tokenized_input[-r.truncate :] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + + prefix_offsets.append(input_length - 5) + read_offsets.append(input_length) + + all_input_ids.append(tokenized_input) + + # Position ids + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + input_length) + + next_token_chooser_parameters.append(r.parameters) + + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + max_new_tokens = stopping_criteria.max_new_tokens + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + + # Paged attention + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + + # Needed blocks can not go over SLIDING_WINDOW_BLOCKS + needed_blocks = min( + math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS + ) + blocks += needed_blocks + + needed_blocks_slots.append((needed_blocks, total_tokens)) + start_slots.append(cumulative_max_length) + + request_slot_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + + # Create tensor to slice into the kv tensor in prefill + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - SLIDING_WINDOW), + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + 
prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length + cumulative_max_length += total_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, needed_blocks) + max_length = max(max_length, input_length + max_new_tokens) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype, device + ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Padded all_input_ids_tensor + all_input_ids_tensor = np.zeros( + (len(all_input_ids), max_length), dtype=np.int64 + ) + for i, input_ids in enumerate(all_input_ids): + all_input_ids_tensor[i, : len(input_ids)] = input_ids + + # Create tensors on device + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) + + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + slot_indices = slot_indices[0] + prefill_cache_indices = prefill_cache_indices[0] + + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) + prefill_cache_indices = prefill_cache_indices.to(device) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=needed_blocks_slots, + block_tables=None, + block_tables_tensor=None, + slots=None, + max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_chooser=next_token_chooser, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + 
top_n_tokens_tensor=top_n_tokens_tensor, + blocks=blocks, + max_blocks=max_blocks, + prefill_cache_indices=prefill_cache_indices, + ) + + +class FlashMistral(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + tokenizer = LlamaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = MistralConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + # Set context windows + SLIDING_WINDOW = config.sliding_window + SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id) + + model = FlashMistralForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(FlashMistral, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + sliding_window=config.sliding_window, + ) + + @property + def batch_type(self) -> Type[FlashMistralBatch]: + return FlashMistralBatch + + def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + logits = self.model.forward( + input_ids=batch.input_ids, + position_ids=batch.position_ids, + cu_seqlen_prefill=batch.cu_seqlen_prefill, + kv_cache=get_cache_manager().kv_cache, + block_tables=batch.block_tables_tensor, + slots=batch.slots[batch.slot_indices], + input_lengths=batch.input_lengths_tensor, + max_s=batch.max_seqlen, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=batch.prefill_head_indices, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index f6e66d30..17d2ea9b 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -21,6 +21,7 @@ class Model(ABC): device: torch.device, rank: int = 0, world_size: int = 1, + sliding_window: Optional[int] = None, ): self.model = model.eval() self.tokenizer = tokenizer @@ -30,6 +31,7 @@ class Model(ABC): self.device = device self.rank = rank self.world_size = world_size + self.sliding_window = sliding_window self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) @@ -40,10 +42,14 @@ class Model(ABC): @property def info(self) -> InfoResponse: + if self.requires_padding and self.sliding_window is not None: + raise NotImplementedError("sliding_window is not implemented with padding") + return InfoResponse( 
requires_padding=self.requires_padding, dtype=str(self.dtype), device_type=self.device.type, + window_size=self.sliding_window, ) @property diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index c472d1fc..caf072b7 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -57,6 +57,7 @@ def attention( cu_seqlens, max_s, softmax_scale, + window_size_left=-1, ): if HAS_FLASH_ATTN_V2: return flash_attn_2_cuda.varlen_fwd( @@ -72,11 +73,18 @@ def attention( softmax_scale, False, True, + window_size_left, + 0, False, None, ) if HAS_FLASH_ATTN: + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" + ) + # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: # MQA expand diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 8be2463f..cf61e47b 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -53,6 +53,7 @@ try: except ImportError: pass + # Monkey patching @classmethod def load_layer_norm(cls, prefix, weights, eps): diff --git a/update_doc.py b/update_doc.py index a4c95743..7e8fb769 100644 --- a/update_doc.py +++ b/update_doc.py @@ -8,7 +8,9 @@ def main(): args = parser.parse_args() - output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8") + output = subprocess.check_output(["text-generation-launcher", "--help"]).decode( + "utf-8" + ) final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```" filename = "docs/source/basic_tutorials/launcher.md" @@ -16,16 +18,20 @@ def main(): with open(filename, "r") as f: doc = f.read() if doc != final_doc: - tmp = "launcher.md" with open(tmp, "w") as g: g.write(final_doc) - diff = subprocess.run(["diff",tmp, filename], capture_output=True).stdout.decode("utf-8") + diff = subprocess.run( + ["diff", tmp, filename], capture_output=True + ).stdout.decode("utf-8") print(diff) - raise Exception("Doc is not up-to-date, run `python update_doc.py` in order to update it") + raise Exception( + "Doc is not up-to-date, run `python update_doc.py` in order to update it" + ) else: with open(filename, "w") as f: f.write(final_doc) + if __name__ == "__main__": main()
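
A few notes on the sliding-window machinery introduced above, with small hedged sketches in Python (illustrative names only, not part of the patch).

First, the router's decode-token budgeting: with a sliding window, a request can never occupy more than `window_size` KV slots, so the budget is clamped before being padded up to whole paged-attention blocks. The sketch below restates the Rust logic from `router/src/queue.rs`; the standalone function `padded_decode_tokens` is an illustrative name, not an API in the patch.

```python
from typing import Optional


def padded_decode_tokens(
    max_new_tokens: int,
    input_length: int,
    block_size: int,
    window_size: Optional[int],
) -> int:
    """Clamp the decode budget to the sliding window, then pad it up to a
    whole number of paged-attention blocks (mirrors the router logic)."""
    if window_size is not None:
        max_new_tokens = min(max(window_size - input_length, 0), max_new_tokens)
    return ((max_new_tokens + block_size - 1) // block_size) * block_size


# 4096-token window, 4000 prompt tokens, 512 requested tokens, 16-token blocks:
# only 96 decode slots are budgeted instead of 512.
assert padded_decode_tokens(512, 4000, 16, 4096) == 96
assert padded_decode_tokens(512, 4000, 16, None) == 512
```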
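
Second, the new `CacheManager.allocate`: blocks come from a CPU-side free mask, and when `repeat_slots` is set (i.e. the model uses a sliding window) a sequence that needs more slots than its blocks can hold wraps around and overwrites its oldest cache entries instead of receiving more blocks. A minimal sketch under those assumptions; `take_blocks` and `assign_slots` are hypothetical helpers used only for illustration.

```python
import math

import torch


def take_blocks(free_block_mask: torch.Tensor, blocks: int) -> torch.Tensor:
    """Pick `blocks` free block indices and mark them as used."""
    free = free_block_mask.nonzero().flatten()
    assert len(free) >= blocks, f"Out of cache blocks: asked {blocks}, have {len(free)}"
    chosen = free[:blocks]
    free_block_mask[chosen] = 0
    return chosen


def assign_slots(all_slots: torch.Tensor, needed_slots: int, repeat_slots: bool) -> torch.Tensor:
    """Repeat the slot list so a sliding window overwrites old entries when needed."""
    if needed_slots > len(all_slots) and repeat_slots:
        repeats = math.ceil(needed_slots / len(all_slots))
        all_slots = all_slots.repeat(repeats)
    return all_slots[:needed_slots]


block_size = 4
free_block_mask = torch.ones(8, dtype=torch.int32)
blocks = take_blocks(free_block_mask, 2)               # tensor([0, 1])
slots = torch.arange(8 * block_size).view(8, block_size)
all_slots = slots[blocks].flatten()                    # slots 0..7
print(assign_slots(all_slots, 10, repeat_slots=True))  # tensor([0, 1, ..., 7, 0, 1])
```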
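
Third, `prefill_cache_indices`: during prefill, flash attention still attends over the full prompt, but only the last `sliding_window` positions of each sequence are written into the paged KV cache. A small sketch with made-up lengths, mirroring the index construction in `FlashMistralBatch.from_pb`:

```python
import torch


def prefill_cache_indices(input_lengths, sliding_window):
    """For each sequence, keep the indices of its last `sliding_window` tokens
    inside the flattened (concatenated) prefill batch."""
    indices = []
    cumulative_length = 0
    for input_length in input_lengths:
        indices.append(
            torch.arange(
                cumulative_length + max(0, input_length - sliding_window),
                cumulative_length + input_length,
                dtype=torch.int64,
            )
        )
        cumulative_length += input_length
    return torch.cat(indices)


# Two prompts of 3 and 6 tokens with a window of 4: the second prompt only
# caches its last 4 positions (indices 5..8 in the flattened batch).
print(prefill_cache_indices([3, 6], sliding_window=4))  # tensor([0, 1, 2, 5, 6, 7, 8])
```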