feat: add mistral model (#1071)
This commit is contained in:
parent
259a230028
commit
3b56d7669b
|
@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
|
||||||
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
|
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
|
||||||
- [Llama V2](https://huggingface.co/meta-llama)
|
- [Llama V2](https://huggingface.co/meta-llama)
|
||||||
- [Code Llama](https://huggingface.co/codellama)
|
- [Code Llama](https://huggingface.co/codellama)
|
||||||
|
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||||
|
|
||||||
Other architectures are supported on a best effort basis using:
|
Other architectures are supported on a best effort basis using:
|
||||||
|
|
||||||
|
|
|
@ -140,6 +140,8 @@ class Parameters:
|
||||||
watermark: bool
|
watermark: bool
|
||||||
# Get decoder input token logprobs and ids
|
# Get decoder input token logprobs and ids
|
||||||
decoder_input_details: bool
|
decoder_input_details: bool
|
||||||
|
# Return the N most likely tokens at each step
|
||||||
|
top_n_tokens: Optional[int]
|
||||||
|
|
||||||
# Decoder input tokens
|
# Decoder input tokens
|
||||||
class InputToken:
|
class InputToken:
|
||||||
|
@ -189,6 +191,8 @@ class BestOfSequence:
|
||||||
prefill: List[InputToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
|
# Most likely tokens
|
||||||
|
top_tokens: Optional[List[List[Token]]]
|
||||||
|
|
||||||
|
|
||||||
# `generate` details
|
# `generate` details
|
||||||
|
@ -203,6 +207,8 @@ class Details:
|
||||||
prefill: List[InputToken]
|
prefill: List[InputToken]
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
|
# Most likely tokens
|
||||||
|
top_tokens: Optional[List[List[Token]]]
|
||||||
# Additional sequences when using the `best_of` parameter
|
# Additional sequences when using the `best_of` parameter
|
||||||
best_of_sequences: Optional[List[BestOfSequence]]
|
best_of_sequences: Optional[List[BestOfSequence]]
|
||||||
|
|
||||||
|
@ -229,6 +235,8 @@ class StreamDetails:
|
||||||
class StreamResponse:
|
class StreamResponse:
|
||||||
# Generated token
|
# Generated token
|
||||||
token: Token
|
token: Token
|
||||||
|
# Most likely tokens
|
||||||
|
top_tokens: Optional[List[Token]]
|
||||||
# Complete generated text
|
# Complete generated text
|
||||||
# Only available when the generation is finished
|
# Only available when the generation is finished
|
||||||
generated_text: Optional[str]
|
generated_text: Optional[str]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohttp"
|
name = "aiohttp"
|
||||||
|
@ -124,6 +124,20 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
frozenlist = ">=1.1.0"
|
frozenlist = ">=1.1.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "annotated-types"
|
||||||
|
version = "0.5.0"
|
||||||
|
description = "Reusable constraint types to use with typing.Annotated"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "annotated_types-0.5.0-py3-none-any.whl", hash = "sha256:58da39888f92c276ad970249761ebea80ba544b77acddaa1a4d6cf78287d45fd"},
|
||||||
|
{file = "annotated_types-0.5.0.tar.gz", hash = "sha256:47cdc3490d9ac1506ce92c7aaa76c579dc3509ff11e098fc867e5130ab7be802"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-timeout"
|
name = "async-timeout"
|
||||||
version = "4.0.3"
|
version = "4.0.3"
|
||||||
|
@ -693,55 +707,140 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pydantic"
|
name = "pydantic"
|
||||||
version = "1.10.12"
|
version = "2.4.2"
|
||||||
description = "Data validation and settings management using python type hints"
|
description = "Data validation using Python type hints"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a1fcb59f2f355ec350073af41d927bf83a63b50e640f4dbaa01053a28b7a7718"},
|
{file = "pydantic-2.4.2-py3-none-any.whl", hash = "sha256:bc3ddf669d234f4220e6e1c4d96b061abe0998185a8d7855c0126782b7abc8c1"},
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b7ccf02d7eb340b216ec33e53a3a629856afe1c6e0ef91d84a4e6f2fb2ca70fe"},
|
{file = "pydantic-2.4.2.tar.gz", hash = "sha256:94f336138093a5d7f426aac732dcfe7ab4eb4da243c88f891d65deb4a2556ee7"},
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fb2aa3ab3728d950bcc885a2e9eff6c8fc40bc0b7bb434e555c215491bcf48b"},
|
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:771735dc43cf8383959dc9b90aa281f0b6092321ca98677c5fb6125a6f56d58d"},
|
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca48477862372ac3770969b9d75f1bf66131d386dba79506c46d75e6b48c1e09"},
|
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a5e7add47a5b5a40c49b3036d464e3c7802f8ae0d1e66035ea16aa5b7a3923ed"},
|
|
||||||
{file = "pydantic-1.10.12-cp310-cp310-win_amd64.whl", hash = "sha256:e4129b528c6baa99a429f97ce733fff478ec955513630e61b49804b6cf9b224a"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0d191db0f92dfcb1dec210ca244fdae5cbe918c6050b342d619c09d31eea0cc"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:795e34e6cc065f8f498c89b894a3c6da294a936ee71e644e4bd44de048af1405"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69328e15cfda2c392da4e713443c7dbffa1505bc9d566e71e55abe14c97ddc62"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2031de0967c279df0d8a1c72b4ffc411ecd06bac607a212892757db7462fc494"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ba5b2e6fe6ca2b7e013398bc7d7b170e21cce322d266ffcd57cca313e54fb246"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a7bac939fa326db1ab741c9d7f44c565a1d1e80908b3797f7f81a4f86bc8d33"},
|
|
||||||
{file = "pydantic-1.10.12-cp311-cp311-win_amd64.whl", hash = "sha256:87afda5539d5140cb8ba9e8b8c8865cb5b1463924d38490d73d3ccfd80896b3f"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:549a8e3d81df0a85226963611950b12d2d334f214436a19537b2efed61b7639a"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:598da88dfa127b666852bef6d0d796573a8cf5009ffd62104094a4fe39599565"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba5c4a8552bff16c61882db58544116d021d0b31ee7c66958d14cf386a5b5350"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79e6a11a07da7374f46970410b41d5e266f7f38f6a17a9c4823db80dadf4303"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab26038b8375581dc832a63c948f261ae0aa21f1d34c1293469f135fa92972a5"},
|
|
||||||
{file = "pydantic-1.10.12-cp37-cp37m-win_amd64.whl", hash = "sha256:e0a16d274b588767602b7646fa05af2782576a6cf1022f4ba74cbb4db66f6ca8"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a9dfa722316f4acf4460afdf5d41d5246a80e249c7ff475c43a3a1e9d75cf62"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a73f489aebd0c2121ed974054cb2759af8a9f747de120acd2c3394cf84176ccb"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b30bcb8cbfccfcf02acb8f1a261143fab622831d9c0989707e0e659f77a18e0"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fcfb5296d7877af406ba1547dfde9943b1256d8928732267e2653c26938cd9c"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2f9a6fab5f82ada41d56b0602606a5506aab165ca54e52bc4545028382ef1c5d"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dea7adcc33d5d105896401a1f37d56b47d443a2b2605ff8a969a0ed5543f7e33"},
|
|
||||||
{file = "pydantic-1.10.12-cp38-cp38-win_amd64.whl", hash = "sha256:1eb2085c13bce1612da8537b2d90f549c8cbb05c67e8f22854e201bde5d98a47"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef6c96b2baa2100ec91a4b428f80d8f28a3c9e53568219b6c298c1125572ebc6"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c076be61cd0177a8433c0adcb03475baf4ee91edf5a4e550161ad57fc90f523"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d5a58feb9a39f481eda4d5ca220aa8b9d4f21a41274760b9bc66bfd72595b86"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5f805d2d5d0a41633651a73fa4ecdd0b3d7a49de4ec3fadf062fe16501ddbf1"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1289c180abd4bd4555bb927c42ee42abc3aee02b0fb2d1223fb7c6e5bef87dbe"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5d1197e462e0364906cbc19681605cb7c036f2475c899b6f296104ad42b9f5fb"},
|
|
||||||
{file = "pydantic-1.10.12-cp39-cp39-win_amd64.whl", hash = "sha256:fdbdd1d630195689f325c9ef1a12900524dceb503b00a987663ff4f58669b93d"},
|
|
||||||
{file = "pydantic-1.10.12-py3-none-any.whl", hash = "sha256:b749a43aa51e32839c9d71dc67eb1e4221bb04af1033a32e3923d46f9effa942"},
|
|
||||||
{file = "pydantic-1.10.12.tar.gz", hash = "sha256:0fe8a415cea8f340e7a9af9c54fc71a649b43e8ca3cc732986116b3cb135d303"},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
typing-extensions = ">=4.2.0"
|
annotated-types = ">=0.4.0"
|
||||||
|
pydantic-core = "2.10.1"
|
||||||
|
typing-extensions = ">=4.6.1"
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dotenv = ["python-dotenv (>=0.10.4)"]
|
email = ["email-validator (>=2.0.0)"]
|
||||||
email = ["email-validator (>=1.0.3)"]
|
|
||||||
|
[[package]]
|
||||||
|
name = "pydantic-core"
|
||||||
|
version = "2.10.1"
|
||||||
|
description = ""
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:d64728ee14e667ba27c66314b7d880b8eeb050e58ffc5fec3b7a109f8cddbd63"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:48525933fea744a3e7464c19bfede85df4aba79ce90c60b94d8b6e1eddd67096"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef337945bbd76cce390d1b2496ccf9f90b1c1242a3a7bc242ca4a9fc5993427a"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1392e0638af203cee360495fd2cfdd6054711f2db5175b6e9c3c461b76f5175"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0675ba5d22de54d07bccde38997e780044dcfa9a71aac9fd7d4d7a1d2e3e65f7"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:128552af70a64660f21cb0eb4876cbdadf1a1f9d5de820fed6421fa8de07c893"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f6e6aed5818c264412ac0598b581a002a9f050cb2637a84979859e70197aa9e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ecaac27da855b8d73f92123e5f03612b04c5632fd0a476e469dfc47cd37d6b2e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b3c01c2fb081fced3bbb3da78510693dc7121bb893a1f0f5f4b48013201f362e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:92f675fefa977625105708492850bcbc1182bfc3e997f8eecb866d1927c98ae6"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-none-win32.whl", hash = "sha256:420a692b547736a8d8703c39ea935ab5d8f0d2573f8f123b0a294e49a73f214b"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp310-none-win_amd64.whl", hash = "sha256:0880e239827b4b5b3e2ce05e6b766a7414e5f5aedc4523be6b68cfbc7f61c5d0"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:073d4a470b195d2b2245d0343569aac7e979d3a0dcce6c7d2af6d8a920ad0bea"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:600d04a7b342363058b9190d4e929a8e2e715c5682a70cc37d5ded1e0dd370b4"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39215d809470f4c8d1881758575b2abfb80174a9e8daf8f33b1d4379357e417c"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eeb3d3d6b399ffe55f9a04e09e635554012f1980696d6b0aca3e6cf42a17a03b"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a7a7902bf75779bc12ccfc508bfb7a4c47063f748ea3de87135d433a4cca7a2f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3625578b6010c65964d177626fde80cf60d7f2e297d56b925cb5cdeda6e9925a"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caa48fc31fc7243e50188197b5f0c4228956f97b954f76da157aae7f67269ae8"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:07ec6d7d929ae9c68f716195ce15e745b3e8fa122fc67698ac6498d802ed0fa4"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6f31a17acede6a8cd1ae2d123ce04d8cca74056c9d456075f4f6f85de055607"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d8f1ebca515a03e5654f88411420fea6380fc841d1bea08effb28184e3d4899f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-none-win32.whl", hash = "sha256:6db2eb9654a85ada248afa5a6db5ff1cf0f7b16043a6b070adc4a5be68c716d6"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-none-win_amd64.whl", hash = "sha256:4a5be350f922430997f240d25f8219f93b0c81e15f7b30b868b2fddfc2d05f27"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp311-none-win_arm64.whl", hash = "sha256:5fdb39f67c779b183b0c853cd6b45f7db84b84e0571b3ef1c89cdb1dfc367325"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:b1f22a9ab44de5f082216270552aa54259db20189e68fc12484873d926426921"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8572cadbf4cfa95fb4187775b5ade2eaa93511f07947b38f4cd67cf10783b118"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db9a28c063c7c00844ae42a80203eb6d2d6bbb97070cfa00194dff40e6f545ab"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e2a35baa428181cb2270a15864ec6286822d3576f2ed0f4cd7f0c1708472aff"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05560ab976012bf40f25d5225a58bfa649bb897b87192a36c6fef1ab132540d7"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6495008733c7521a89422d7a68efa0a0122c99a5861f06020ef5b1f51f9ba7c"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14ac492c686defc8e6133e3a2d9eaf5261b3df26b8ae97450c1647286750b901"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8282bab177a9a3081fd3d0a0175a07a1e2bfb7fcbbd949519ea0980f8a07144d"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:aafdb89fdeb5fe165043896817eccd6434aee124d5ee9b354f92cd574ba5e78f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f6defd966ca3b187ec6c366604e9296f585021d922e666b99c47e78738b5666c"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-none-win32.whl", hash = "sha256:7c4d1894fe112b0864c1fa75dffa045720a194b227bed12f4be7f6045b25209f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-none-win_amd64.whl", hash = "sha256:5994985da903d0b8a08e4935c46ed8daf5be1cf217489e673910951dc533d430"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp312-none-win_arm64.whl", hash = "sha256:0d8a8adef23d86d8eceed3e32e9cca8879c7481c183f84ed1a8edc7df073af94"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:9badf8d45171d92387410b04639d73811b785b5161ecadabf056ea14d62d4ede"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:ebedb45b9feb7258fac0a268a3f6bec0a2ea4d9558f3d6f813f02ff3a6dc6698"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfe1090245c078720d250d19cb05d67e21a9cd7c257698ef139bc41cf6c27b4f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e357571bb0efd65fd55f18db0a2fb0ed89d0bb1d41d906b138f088933ae618bb"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b3dcd587b69bbf54fc04ca157c2323b8911033e827fffaecf0cafa5a892a0904"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c120c9ce3b163b985a3b966bb701114beb1da4b0468b9b236fc754783d85aa3"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15d6bca84ffc966cc9976b09a18cf9543ed4d4ecbd97e7086f9ce9327ea48891"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5cabb9710f09d5d2e9e2748c3e3e20d991a4c5f96ed8f1132518f54ab2967221"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:82f55187a5bebae7d81d35b1e9aaea5e169d44819789837cdd4720d768c55d15"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1d40f55222b233e98e3921df7811c27567f0e1a4411b93d4c5c0f4ce131bc42f"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-none-win32.whl", hash = "sha256:14e09ff0b8fe6e46b93d36a878f6e4a3a98ba5303c76bb8e716f4878a3bee92c"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp37-none-win_amd64.whl", hash = "sha256:1396e81b83516b9d5c9e26a924fa69164156c148c717131f54f586485ac3c15e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:6835451b57c1b467b95ffb03a38bb75b52fb4dc2762bb1d9dbed8de31ea7d0fc"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b00bc4619f60c853556b35f83731bd817f989cba3e97dc792bb8c97941b8053a"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fa467fd300a6f046bdb248d40cd015b21b7576c168a6bb20aa22e595c8ffcdd"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d99277877daf2efe074eae6338453a4ed54a2d93fb4678ddfe1209a0c93a2468"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa7db7558607afeccb33c0e4bf1c9a9a835e26599e76af6fe2fcea45904083a6"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aad7bd686363d1ce4ee930ad39f14e1673248373f4a9d74d2b9554f06199fb58"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:443fed67d33aa85357464f297e3d26e570267d1af6fef1c21ca50921d2976302"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:042462d8d6ba707fd3ce9649e7bf268633a41018d6a998fb5fbacb7e928a183e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ecdbde46235f3d560b18be0cb706c8e8ad1b965e5c13bbba7450c86064e96561"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ed550ed05540c03f0e69e6d74ad58d026de61b9eaebebbaaf8873e585cbb18de"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-none-win32.whl", hash = "sha256:8cdbbd92154db2fec4ec973d45c565e767ddc20aa6dbaf50142676484cbff8ee"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp38-none-win_amd64.whl", hash = "sha256:9f6f3e2598604956480f6c8aa24a3384dbf6509fe995d97f6ca6103bb8c2534e"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:655f8f4c8d6a5963c9a0687793da37b9b681d9ad06f29438a3b2326d4e6b7970"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e570ffeb2170e116a5b17e83f19911020ac79d19c96f320cbfa1fa96b470185b"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64322bfa13e44c6c30c518729ef08fda6026b96d5c0be724b3c4ae4da939f875"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:485a91abe3a07c3a8d1e082ba29254eea3e2bb13cbbd4351ea4e5a21912cc9b0"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7c2b8eb9fc872e68b46eeaf835e86bccc3a58ba57d0eedc109cbb14177be531"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a5cb87bdc2e5f620693148b5f8f842d293cae46c5f15a1b1bf7ceeed324a740c"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25bd966103890ccfa028841a8f30cebcf5875eeac8c4bde4fe221364c92f0c9a"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f323306d0556351735b54acbf82904fe30a27b6a7147153cbe6e19aaaa2aa429"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0c27f38dc4fbf07b358b2bc90edf35e82d1703e22ff2efa4af4ad5de1b3833e7"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f1365e032a477c1430cfe0cf2856679529a2331426f8081172c4a74186f1d595"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-none-win32.whl", hash = "sha256:a1c311fd06ab3b10805abb72109f01a134019739bd3286b8ae1bc2fc4e50c07a"},
|
||||||
|
{file = "pydantic_core-2.10.1-cp39-none-win_amd64.whl", hash = "sha256:ae8a8843b11dc0b03b57b52793e391f0122e740de3df1474814c700d2622950a"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d43002441932f9a9ea5d6f9efaa2e21458221a3a4b417a14027a1d530201ef1b"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fcb83175cc4936a5425dde3356f079ae03c0802bbdf8ff82c035f8a54b333521"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:962ed72424bf1f72334e2f1e61b68f16c0e596f024ca7ac5daf229f7c26e4208"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cf5bb4dd67f20f3bbc1209ef572a259027c49e5ff694fa56bed62959b41e1f9"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e544246b859f17373bed915182ab841b80849ed9cf23f1f07b73b7c58baee5fb"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:c0877239307b7e69d025b73774e88e86ce82f6ba6adf98f41069d5b0b78bd1bf"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53df009d1e1ba40f696f8995683e067e3967101d4bb4ea6f667931b7d4a01357"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a1254357f7e4c82e77c348dabf2d55f1d14d19d91ff025004775e70a6ef40ada"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:524ff0ca3baea164d6d93a32c58ac79eca9f6cf713586fdc0adb66a8cdeab96a"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f0ac9fb8608dbc6eaf17956bf623c9119b4db7dbb511650910a82e261e6600f"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:320f14bd4542a04ab23747ff2c8a778bde727158b606e2661349557f0770711e"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63974d168b6233b4ed6a0046296803cb13c56637a7b8106564ab575926572a55"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:417243bf599ba1f1fef2bb8c543ceb918676954734e2dcb82bf162ae9d7bd514"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:dda81e5ec82485155a19d9624cfcca9be88a405e2857354e5b089c2a982144b2"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:14cfbb00959259e15d684505263d5a21732b31248a5dd4941f73a3be233865b9"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:631cb7415225954fdcc2a024119101946793e5923f6c4d73a5914d27eb3d3a05"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:bec7dd208a4182e99c5b6c501ce0b1f49de2802448d4056091f8e630b28e9a52"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:149b8a07712f45b332faee1a2258d8ef1fb4a36f88c0c17cb687f205c5dc6e7d"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d966c47f9dd73c2d32a809d2be529112d509321c5310ebf54076812e6ecd884"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7eb037106f5c6b3b0b864ad226b0b7ab58157124161d48e4b30c4a43fef8bc4b"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:154ea7c52e32dce13065dbb20a4a6f0cc012b4f667ac90d648d36b12007fa9f7"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e562617a45b5a9da5be4abe72b971d4f00bf8555eb29bb91ec2ef2be348cd132"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:f23b55eb5464468f9e0e9a9935ce3ed2a870608d5f534025cd5536bca25b1402"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:e9121b4009339b0f751955baf4543a0bfd6bc3f8188f8056b1a25a2d45099934"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:0523aeb76e03f753b58be33b26540880bac5aa54422e4462404c432230543f33"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e0e2959ef5d5b8dc9ef21e1a305a21a36e254e6a34432d00c72a92fdc5ecda5"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da01bec0a26befab4898ed83b362993c844b9a607a86add78604186297eb047e"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f2e9072d71c1f6cfc79a36d4484c82823c560e6f5599c43c1ca6b5cdbd54f881"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f36a3489d9e28fe4b67be9992a23029c3cec0babc3bd9afb39f49844a8c721c5"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f64f82cc3443149292b32387086d02a6c7fb39b8781563e0ca7b8d7d9cf72bd7"},
|
||||||
|
{file = "pydantic_core-2.10.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b4a6db486ac8e99ae696e09efc8b2b9fea67b63c8f88ba7a1a16c24a057a0776"},
|
||||||
|
{file = "pydantic_core-2.10.1.tar.gz", hash = "sha256:0f8682dbdd2f67f8e1edddcbffcc29f60a6182b4901c367fc8c1c40d30bb0a82"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
|
@ -816,6 +915,7 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||||
|
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||||
|
@ -823,8 +923,15 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||||
|
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
|
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||||
|
@ -841,6 +948,7 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||||
|
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||||
|
@ -848,6 +956,7 @@ files = [
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||||
|
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||||
|
@ -929,13 +1038,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "urllib3"
|
name = "urllib3"
|
||||||
version = "2.0.4"
|
version = "2.0.5"
|
||||||
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"},
|
{file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"},
|
||||||
{file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"},
|
{file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
|
@ -1050,4 +1159,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.7"
|
python-versions = "^3.7"
|
||||||
content-hash = "0db2f97d52c557dd7f90c55b4ad5bbe308c957c5f7f99fec53c57e0a13822cb4"
|
content-hash = "b7fab8703967f2616ea59a98a437cd30f97f0c8d2a06e399d688814a2a2c64f8"
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "text-generation"
|
name = "text-generation"
|
||||||
version = "0.6.0"
|
version = "0.6.1"
|
||||||
description = "Hugging Face Text Generation Python Client"
|
description = "Hugging Face Text Generation Python Client"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||||
|
|
|
@ -482,7 +482,6 @@ class AsyncClient:
|
||||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||||
) as session:
|
) as session:
|
||||||
async with session.post(self.base_url, json=request.dict()) as resp:
|
async with session.post(self.base_url, json=request.dict()) as resp:
|
||||||
|
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
raise parse_error(resp.status, await resp.json())
|
raise parse_error(resp.status, await resp.json())
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class Parameters(BaseModel):
|
||||||
# Get decoder input token logprobs and ids
|
# Get decoder input token logprobs and ids
|
||||||
decoder_input_details: bool = False
|
decoder_input_details: bool = False
|
||||||
# Return the N most likely tokens at each step
|
# Return the N most likely tokens at each step
|
||||||
top_n_tokens: Optional[int]
|
top_n_tokens: Optional[int] = None
|
||||||
|
|
||||||
@validator("best_of")
|
@validator("best_of")
|
||||||
def valid_best_of(cls, field_value, values):
|
def valid_best_of(cls, field_value, values):
|
||||||
|
@ -188,7 +188,7 @@ class BestOfSequence(BaseModel):
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
# Most likely tokens
|
# Most likely tokens
|
||||||
top_tokens: Optional[List[List[Token]]]
|
top_tokens: Optional[List[List[Token]]] = None
|
||||||
|
|
||||||
|
|
||||||
# `generate` details
|
# `generate` details
|
||||||
|
@ -204,7 +204,7 @@ class Details(BaseModel):
|
||||||
# Generated tokens
|
# Generated tokens
|
||||||
tokens: List[Token]
|
tokens: List[Token]
|
||||||
# Most likely tokens
|
# Most likely tokens
|
||||||
top_tokens: Optional[List[List[Token]]]
|
top_tokens: Optional[List[List[Token]]] = None
|
||||||
# Additional sequences when using the `best_of` parameter
|
# Additional sequences when using the `best_of` parameter
|
||||||
best_of_sequences: Optional[List[BestOfSequence]] = None
|
best_of_sequences: Optional[List[BestOfSequence]] = None
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ class StreamResponse(BaseModel):
|
||||||
# Generated token
|
# Generated token
|
||||||
token: Token
|
token: Token
|
||||||
# Most likely tokens
|
# Most likely tokens
|
||||||
top_tokens: Optional[List[Token]]
|
top_tokens: Optional[List[Token]] = None
|
||||||
# Complete generated text
|
# Complete generated text
|
||||||
# Only available when the generation is finished
|
# Only available when the generation is finished
|
||||||
generated_text: Optional[str] = None
|
generated_text: Optional[str] = None
|
||||||
|
|
|
@ -34,10 +34,17 @@ Options:
|
||||||
[env: NUM_SHARD=]
|
[env: NUM_SHARD=]
|
||||||
|
|
||||||
--quantize <QUANTIZE>
|
--quantize <QUANTIZE>
|
||||||
Whether you want the model to be quantized. This will use `bitsandbytes` for quantization on the fly, or `gptq`. 4bit quantization is available through `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options
|
Whether you want the model to be quantized
|
||||||
|
|
||||||
[env: QUANTIZE=]
|
[env: QUANTIZE=]
|
||||||
[possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq, awq]
|
|
||||||
|
Possible values:
|
||||||
|
- awq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models whereever possible because of the better latency
|
||||||
|
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
|
||||||
|
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels whereever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
|
||||||
|
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
|
||||||
|
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
|
||||||
|
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
|
||||||
|
|
||||||
--dtype <DTYPE>
|
--dtype <DTYPE>
|
||||||
The dtype to be forced upon the model. This option cannot be used with `--quantize`
|
The dtype to be forced upon the model. This option cannot be used with `--quantize`
|
||||||
|
|
|
@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom
|
||||||
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
|
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
|
||||||
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
|
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
|
||||||
- [Llama V2](https://huggingface.co/meta-llama)
|
- [Llama V2](https://huggingface.co/meta-llama)
|
||||||
|
- [Code Llama](https://huggingface.co/codellama)
|
||||||
|
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||||
|
|
||||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": -0.54785156,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -1.4091797,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 307,
|
||||||
|
"logprob": -3.0273438,
|
||||||
|
"special": false,
|
||||||
|
"text": " n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 327,
|
||||||
|
"logprob": -0.94433594,
|
||||||
|
"special": false,
|
||||||
|
"text": " ="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.81347656,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.2958984,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0644531,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.9580078,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.5073242,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.1816406,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": Let n = 10 - 1"
|
||||||
|
}
|
|
@ -0,0 +1,89 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -0.1307373,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 332,
|
||||||
|
"logprob": -2.3359375,
|
||||||
|
"special": false,
|
||||||
|
"text": " u"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 347,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " be"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 325,
|
||||||
|
"logprob": -1.0234375,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0292969,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 648,
|
||||||
|
"logprob": -1.0439453,
|
||||||
|
"special": false,
|
||||||
|
"text": " +"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.24499512,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28770,
|
||||||
|
"logprob": -0.5073242,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.5507812,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request: Let u be (0 + 3 -"
|
||||||
|
}
|
|
@ -0,0 +1,358 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": -0.55078125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -1.4140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 307,
|
||||||
|
"logprob": -3.0273438,
|
||||||
|
"special": false,
|
||||||
|
"text": " n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 327,
|
||||||
|
"logprob": -0.94140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " ="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.8173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.2978516,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0664062,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.9560547,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.5078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.1787109,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": Let n = 10 - 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": -0.54785156,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -1.4111328,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 307,
|
||||||
|
"logprob": -3.0292969,
|
||||||
|
"special": false,
|
||||||
|
"text": " n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 327,
|
||||||
|
"logprob": -0.94433594,
|
||||||
|
"special": false,
|
||||||
|
"text": " ="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.8178711,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.2939453,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0644531,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.9550781,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.5078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": Let n = 10 - 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": -0.55078125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -1.4140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 307,
|
||||||
|
"logprob": -3.0273438,
|
||||||
|
"special": false,
|
||||||
|
"text": " n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 327,
|
||||||
|
"logprob": -0.94140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " ="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.8173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.2978516,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0664062,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.9560547,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.5078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.1787109,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": Let n = 10 - 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -12.9140625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -10.7578125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 28747,
|
||||||
|
"logprob": -0.55078125,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3169,
|
||||||
|
"logprob": -1.4140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 307,
|
||||||
|
"logprob": -3.0273438,
|
||||||
|
"special": false,
|
||||||
|
"text": " n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 327,
|
||||||
|
"logprob": -0.94140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " ="
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.8173828,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.2978516,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28734,
|
||||||
|
"logprob": -2.0664062,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 387,
|
||||||
|
"logprob": -1.9560547,
|
||||||
|
"special": false,
|
||||||
|
"text": " -"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28705,
|
||||||
|
"logprob": -0.5078125,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28740,
|
||||||
|
"logprob": -1.1787109,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": ": Let n = 10 - 1"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,60 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_mistral_handle(launcher):
|
||||||
|
with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_mistral(flash_mistral_handle):
|
||||||
|
await flash_mistral_handle.health(300)
|
||||||
|
return flash_mistral_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_mistral(flash_mistral, response_snapshot):
|
||||||
|
response = await flash_mistral.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
|
||||||
|
response = await flash_mistral.generate(
|
||||||
|
"Test request",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_mistral, "Test request", max_new_tokens=10, n=4
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -31,6 +31,7 @@ message InfoResponse {
|
||||||
bool requires_padding = 1;
|
bool requires_padding = 1;
|
||||||
string dtype = 2;
|
string dtype = 2;
|
||||||
string device_type = 3;
|
string device_type = 3;
|
||||||
|
optional uint32 window_size = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
|
|
|
@ -50,10 +50,11 @@ impl Infer {
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
|
window_size: Option<u32>,
|
||||||
generation_health: Arc<AtomicBool>,
|
generation_health: Arc<AtomicBool>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let queue = Queue::new(requires_padding, 16);
|
let queue = Queue::new(requires_padding, 16, window_size);
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
});
|
});
|
||||||
|
|
|
@ -2,6 +2,7 @@ use crate::infer::InferError;
|
||||||
use crate::infer::InferStreamResponse;
|
use crate::infer::InferStreamResponse;
|
||||||
use crate::validation::ValidGenerateRequest;
|
use crate::validation::ValidGenerateRequest;
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
|
use std::cmp::min;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_client::{Batch, Request};
|
use text_generation_client::{Batch, Request};
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
@ -33,12 +34,17 @@ pub(crate) struct Queue {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Queue {
|
impl Queue {
|
||||||
pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
|
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = flume::unbounded();
|
let (queue_sender, queue_receiver) = flume::unbounded();
|
||||||
|
|
||||||
// Launch background queue task
|
// Launch background queue task
|
||||||
tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
|
tokio::spawn(queue_task(
|
||||||
|
requires_padding,
|
||||||
|
block_size,
|
||||||
|
window_size,
|
||||||
|
queue_receiver,
|
||||||
|
));
|
||||||
|
|
||||||
Self { queue_sender }
|
Self { queue_sender }
|
||||||
}
|
}
|
||||||
|
@ -84,9 +90,10 @@ impl Queue {
|
||||||
async fn queue_task(
|
async fn queue_task(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
window_size: Option<u32>,
|
||||||
receiver: flume::Receiver<QueueCommand>,
|
receiver: flume::Receiver<QueueCommand>,
|
||||||
) {
|
) {
|
||||||
let mut state = State::new(requires_padding, block_size);
|
let mut state = State::new(requires_padding, block_size, window_size);
|
||||||
|
|
||||||
while let Ok(cmd) = receiver.recv_async().await {
|
while let Ok(cmd) = receiver.recv_async().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
|
@ -126,16 +133,20 @@ struct State {
|
||||||
|
|
||||||
/// Paged Attention block size
|
/// Paged Attention block size
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
|
||||||
|
/// Sliding window
|
||||||
|
window_size: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
fn new(requires_padding: bool, block_size: u32) -> Self {
|
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
entries: VecDeque::with_capacity(128),
|
entries: VecDeque::with_capacity(128),
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
next_batch_id: 0,
|
next_batch_id: 0,
|
||||||
requires_padding,
|
requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
|
window_size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,11 +215,17 @@ impl State {
|
||||||
if self.requires_padding {
|
if self.requires_padding {
|
||||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||||
} else {
|
} else {
|
||||||
|
let max_new_tokens = match self.window_size {
|
||||||
|
None => entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
Some(window_size) => min(
|
||||||
|
window_size.saturating_sub(entry.request.input_length),
|
||||||
|
entry.request.stopping_parameters.max_new_tokens,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
// pad to block size
|
// pad to block size
|
||||||
decode_tokens +=
|
decode_tokens +=
|
||||||
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
|
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
|
||||||
/ self.block_size)
|
|
||||||
* self.block_size;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget
|
if prefill_tokens > prefill_token_budget
|
||||||
|
@ -342,7 +359,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_append() {
|
fn test_append() {
|
||||||
let mut state = State::new(false, 1);
|
let mut state = State::new(false, 1, None);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
|
@ -358,7 +375,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_empty() {
|
fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1);
|
let mut state = State::new(false, 1, None);
|
||||||
|
|
||||||
assert!(state.next_batch(None, 1, 1).is_none());
|
assert!(state.next_batch(None, 1, 1).is_none());
|
||||||
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
assert!(state.next_batch(Some(1), 1, 1).is_none());
|
||||||
|
@ -366,7 +383,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_min_size() {
|
fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1);
|
let mut state = State::new(false, 1, None);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -398,7 +415,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_next_batch_token_budget() {
|
fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1);
|
let mut state = State::new(false, 1, None);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -431,14 +448,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1);
|
let queue = Queue::new(false, 1, None);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1);
|
let queue = Queue::new(false, 1, None);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
|
||||||
|
@ -446,7 +463,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1);
|
let queue = Queue::new(false, 1, None);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -479,7 +496,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1);
|
let queue = Queue::new(false, 1, None);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -504,7 +521,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1);
|
let queue = Queue::new(false, 1, None);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
|
|
@ -595,6 +595,7 @@ pub async fn run(
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
shard_info.requires_padding,
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
generation_health,
|
generation_health,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
|
flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
|
||||||
|
|
||||||
flash-attention-v2:
|
flash-attention-v2:
|
||||||
# Clone flash attention
|
# Clone flash attention
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365
|
vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
|
||||||
|
|
||||||
vllm:
|
vllm:
|
||||||
# Clone vllm
|
# Clone vllm
|
||||||
|
|
|
@ -67,6 +67,16 @@ if FLASH_ATTENTION:
|
||||||
__all__.append(FlashLlama)
|
__all__.append(FlashLlama)
|
||||||
__all__.append(IDEFICSSharded)
|
__all__.append(IDEFICSSharded)
|
||||||
|
|
||||||
|
MISTRAL = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.flash_mistral import FlashMistral
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import Mistral model: {e}")
|
||||||
|
MISTRAL = False
|
||||||
|
|
||||||
|
if MISTRAL:
|
||||||
|
__all__.append(FlashMistral)
|
||||||
|
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -237,7 +247,18 @@ def get_model(
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == "opt":
|
if model_type == "mistral":
|
||||||
|
if MISTRAL:
|
||||||
|
return FlashMistral(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
raise NotImplementedError("Mistral model requires flash attention v2")
|
||||||
|
|
||||||
|
if model_type == "opt":
|
||||||
return OPTSharded(
|
return OPTSharded(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -246,7 +267,7 @@ def get_model(
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == "t5":
|
if model_type == "t5":
|
||||||
return T5Sharded(
|
return T5Sharded(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -254,7 +275,7 @@ def get_model(
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
elif model_type == "idefics":
|
if model_type == "idefics":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return IDEFICSSharded(
|
return IDEFICSSharded(
|
||||||
model_id,
|
model_id,
|
||||||
|
|
|
@ -0,0 +1,135 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
BLOCK_SIZE: int = 16
|
||||||
|
# Will be set in warmup
|
||||||
|
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CacheManager:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_blocks: int,
|
||||||
|
num_layers: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
repeat_slots: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
self.block_size = BLOCK_SIZE
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.repeat_slots = repeat_slots
|
||||||
|
|
||||||
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
x = self.block_size // element_size
|
||||||
|
|
||||||
|
self.kv_cache = [
|
||||||
|
(
|
||||||
|
torch.empty(
|
||||||
|
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
torch.empty(
|
||||||
|
(num_blocks, num_heads, head_size, self.block_size),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||||
|
self.slots = torch.arange(
|
||||||
|
0, num_blocks * self.block_size, dtype=torch.int32
|
||||||
|
).view(num_blocks, self.block_size)
|
||||||
|
|
||||||
|
def allocate(
|
||||||
|
self,
|
||||||
|
needed_blocks_slots: List[Tuple[int, int]],
|
||||||
|
blocks: int,
|
||||||
|
max_blocks: int,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
# Get free blocks indices by finding values in mask that are not set to 0
|
||||||
|
free_block_indices = self.free_block_mask.nonzero()
|
||||||
|
assert (
|
||||||
|
len(free_block_indices) >= blocks
|
||||||
|
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||||
|
|
||||||
|
# Slice by the number of required blocks
|
||||||
|
block_indices = free_block_indices[:blocks]
|
||||||
|
block_indices = block_indices.flatten()
|
||||||
|
|
||||||
|
# Padded block tables
|
||||||
|
block_tables_tensor = torch.zeros(
|
||||||
|
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allocate paged attention blocks
|
||||||
|
cumulative_blocks = 0
|
||||||
|
slots = []
|
||||||
|
block_tables = []
|
||||||
|
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
|
||||||
|
# Get allocated blocks for this sequence
|
||||||
|
allocated_blocks = block_indices[
|
||||||
|
cumulative_blocks : cumulative_blocks + needed_blocks
|
||||||
|
]
|
||||||
|
# Get slots for the allocated blocks
|
||||||
|
all_slots = self.slots[allocated_blocks].flatten()
|
||||||
|
|
||||||
|
# Repeat slots in the case of context sliding window
|
||||||
|
if needed_slots > len(all_slots) and self.repeat_slots:
|
||||||
|
repeats = math.ceil(needed_slots / len(all_slots))
|
||||||
|
all_slots = all_slots.repeat(repeats)
|
||||||
|
|
||||||
|
allocated_slots = all_slots[:needed_slots]
|
||||||
|
|
||||||
|
slots.append(allocated_slots)
|
||||||
|
block_tables.append(allocated_blocks.tolist())
|
||||||
|
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||||
|
cumulative_blocks += needed_blocks
|
||||||
|
|
||||||
|
block_tables = block_tables
|
||||||
|
block_tables_tensor = block_tables_tensor.to(device)
|
||||||
|
slots = torch.concat(slots).to(device)
|
||||||
|
|
||||||
|
# Allocate the required number of blocks by setting the mask to 0
|
||||||
|
self.free_block_mask[block_indices] = 0
|
||||||
|
|
||||||
|
return block_tables, block_tables_tensor, slots
|
||||||
|
|
||||||
|
def free(self, block_indices: Optional[List[int]]):
|
||||||
|
if block_indices is not None and block_indices:
|
||||||
|
# Reset mask
|
||||||
|
self.free_block_mask[block_indices] = 1
|
||||||
|
|
||||||
|
|
||||||
|
def set_cache_manager(
|
||||||
|
num_blocks: int,
|
||||||
|
num_layers: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
repeat_slots: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> CacheManager:
|
||||||
|
global CACHE_MANAGER
|
||||||
|
if CACHE_MANAGER is not None:
|
||||||
|
del CACHE_MANAGER
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
CACHE_MANAGER = CacheManager(
|
||||||
|
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
||||||
|
)
|
||||||
|
return CACHE_MANAGER
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_manager() -> CacheManager:
|
||||||
|
global CACHE_MANAGER
|
||||||
|
if CACHE_MANAGER is None:
|
||||||
|
raise RuntimeError("cache manager was not initialized")
|
||||||
|
|
||||||
|
return CACHE_MANAGER
|
|
@ -0,0 +1,532 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
# Flash attention imports
|
||||||
|
import dropout_layer_norm
|
||||||
|
|
||||||
|
# vllm imports
|
||||||
|
import vllm_cache_ops
|
||||||
|
import vllm_attention_ops
|
||||||
|
|
||||||
|
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
PositionRotaryEmbedding,
|
||||||
|
TensorParallelHead,
|
||||||
|
get_linear,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not HAS_FLASH_ATTN_V2:
|
||||||
|
raise ImportError("Mistral model requires flash attn v2")
|
||||||
|
|
||||||
|
|
||||||
|
class MistralConfig(PretrainedConfig):
|
||||||
|
model_type = "mistral"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=14336,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096 * 32,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
sliding_window=4096,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralRMSNorm(nn.Module):
|
||||||
|
def __init__(self, prefix, weights, eps=1e-6):
|
||||||
|
"""
|
||||||
|
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
self.weight = nn.Parameter(weight)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states, residual=None):
|
||||||
|
if hidden_states.shape[-1] > 8192:
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(
|
||||||
|
variance + self.variance_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert into half-precision if necessary
|
||||||
|
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
|
||||||
|
return self.weight * hidden_states, residual
|
||||||
|
else:
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
self.weight,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
self.variance_epsilon,
|
||||||
|
1.0,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
True, # Activate RMSNorm
|
||||||
|
)
|
||||||
|
if res is None:
|
||||||
|
res = hidden_states
|
||||||
|
|
||||||
|
return normed_hidden_states, res
|
||||||
|
|
||||||
|
|
||||||
|
def load_attention(config, prefix, weights):
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
return _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
return TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
|
||||||
|
weight = weights.get_multi_weights_col(
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
quantize=config.quantize,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.quantize not in ["gptq", "awq"]:
|
||||||
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
|
||||||
|
return TensorParallelColumnLinear(
|
||||||
|
get_linear(weight, bias=None, quantize=config.quantize)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAttention(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.max_past = (
|
||||||
|
config.sliding_window if config.sliding_window is not None else 0
|
||||||
|
)
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
|
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||||
|
config=config,
|
||||||
|
dim=self.head_size,
|
||||||
|
base=config.rope_theta,
|
||||||
|
device=weights.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
if self.num_heads % weights.process_group.size() != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||||
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
|
)
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query_key_value = load_attention(config, prefix, weights)
|
||||||
|
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.kv_head_mapping = torch.arange(
|
||||||
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
):
|
||||||
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
|
self.rotary_emb(query, cos, sin)
|
||||||
|
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
kv_to_cache = kv
|
||||||
|
|
||||||
|
vllm_cache_ops.reshape_and_cache(
|
||||||
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
|
)
|
||||||
|
|
||||||
|
# output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
# flash attention
|
||||||
|
attention(
|
||||||
|
query,
|
||||||
|
torch.select(kv, dim=1, index=0),
|
||||||
|
torch.select(kv, dim=1, index=1),
|
||||||
|
attn_output,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
max_s,
|
||||||
|
self.softmax_scale,
|
||||||
|
window_size_left=self.max_past,
|
||||||
|
)
|
||||||
|
# Decode
|
||||||
|
else:
|
||||||
|
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
|
||||||
|
block_size = kv_cache[1].shape[3]
|
||||||
|
vllm_attention_ops.single_query_cached_kv_attention(
|
||||||
|
attn_output,
|
||||||
|
query,
|
||||||
|
kv_cache[0],
|
||||||
|
kv_cache[1],
|
||||||
|
self.kv_head_mapping,
|
||||||
|
self.softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
|
||||||
|
class MistralMLP(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
act = config.hidden_act
|
||||||
|
self.act = (
|
||||||
|
ACT2FN[act]
|
||||||
|
if "gelu" not in act
|
||||||
|
else lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate="tanh"
|
||||||
|
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Fuse gate and up proj
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||||
|
|
||||||
|
|
||||||
|
class MistralLayer(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = f"model.layers.{layer_id}"
|
||||||
|
self.self_attn = MistralAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
|
||||||
|
self.input_layernorm = MistralRMSNorm(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = MistralRMSNorm(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
):
|
||||||
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
normed_hidden_states,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
# faster post attention rms norm
|
||||||
|
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||||
|
attn_output, res
|
||||||
|
)
|
||||||
|
|
||||||
|
mlp_output = self.mlp(normed_attn_res_output)
|
||||||
|
|
||||||
|
return mlp_output, attn_res
|
||||||
|
|
||||||
|
|
||||||
|
class MistralModel(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
process_group = weights.process_group
|
||||||
|
self.tp_rank = process_group.rank()
|
||||||
|
self.tp_world_size = process_group.size()
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(
|
||||||
|
prefix="model.embed_tokens", weights=weights
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MistralLayer(
|
||||||
|
layer_id,
|
||||||
|
config,
|
||||||
|
weights,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm = MistralRMSNorm(
|
||||||
|
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# Get rotary cos and sin for this forward
|
||||||
|
# Avoid to index in each layer
|
||||||
|
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||||
|
position_ids, max_s, hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
residual = None
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache[i],
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMistralForCausalLM(torch.nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = MistralModel(config, weights)
|
||||||
|
self.lm_head = TensorParallelHead.load(
|
||||||
|
config,
|
||||||
|
prefix="lm_head",
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.max_past = config.sliding_window
|
||||||
|
if self.max_past is None:
|
||||||
|
raise ValueError("max_past cannot be None")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if prefill_cache_indices is not None:
|
||||||
|
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||||
|
slots = slots[prefill_cache_indices]
|
||||||
|
else:
|
||||||
|
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||||
|
# kernel requires the true values
|
||||||
|
max_s = min(self.max_past, max_s)
|
||||||
|
input_lengths = torch.clamp(input_lengths, max=self.max_past)
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
input_lengths,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
)
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
return logits
|
|
@ -19,99 +19,17 @@ from text_generation_server.models.types import (
|
||||||
GeneratedText,
|
GeneratedText,
|
||||||
TopTokens,
|
TopTokens,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.models.cache_manager import (
|
||||||
|
get_cache_manager,
|
||||||
|
set_cache_manager,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
BLOCK_SIZE = 16
|
|
||||||
# Will be set in warmup
|
|
||||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CacheManager:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_blocks: int,
|
|
||||||
num_layers: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.block_size = BLOCK_SIZE
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
|
|
||||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
|
||||||
x = self.block_size // element_size
|
|
||||||
|
|
||||||
self.kv_cache = [
|
|
||||||
(
|
|
||||||
torch.empty(
|
|
||||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
torch.empty(
|
|
||||||
(num_blocks, num_heads, head_size, self.block_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for _ in range(num_layers)
|
|
||||||
]
|
|
||||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
|
||||||
self.slots = torch.arange(
|
|
||||||
0, num_blocks * self.block_size, dtype=torch.int32
|
|
||||||
).view(num_blocks, self.block_size)
|
|
||||||
|
|
||||||
def allocate(self, batch: "FlashCausalLMBatch"):
|
|
||||||
# Get free blocks indices by finding values in mask that are not set to 0
|
|
||||||
free_block_indices = self.free_block_mask.nonzero()
|
|
||||||
assert (
|
|
||||||
len(free_block_indices) >= batch.blocks
|
|
||||||
), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
|
|
||||||
|
|
||||||
# Slice by the number of required blocks
|
|
||||||
block_indices = free_block_indices[: batch.blocks]
|
|
||||||
block_indices = block_indices.flatten()
|
|
||||||
|
|
||||||
# Padded block tables
|
|
||||||
block_tables_tensor = torch.zeros(
|
|
||||||
(len(batch), batch.max_blocks), dtype=torch.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allocate paged attention blocks
|
|
||||||
cumulative_blocks = 0
|
|
||||||
slots = []
|
|
||||||
block_tables = []
|
|
||||||
for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
|
|
||||||
# Get allocated blocks for this sequence
|
|
||||||
allocated_blocks = block_indices[
|
|
||||||
cumulative_blocks : cumulative_blocks + needed_blocks
|
|
||||||
]
|
|
||||||
# Get slots for the allocated blocks
|
|
||||||
allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
|
|
||||||
|
|
||||||
slots.append(allocated_slots)
|
|
||||||
block_tables.append(allocated_blocks.tolist())
|
|
||||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
|
||||||
cumulative_blocks += needed_blocks
|
|
||||||
|
|
||||||
batch.needed_blocks_slots = None
|
|
||||||
batch.block_tables = block_tables
|
|
||||||
batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
|
|
||||||
batch.slots = torch.concat(slots).to(batch.input_ids.device)
|
|
||||||
|
|
||||||
# Allocate the required number of blocks by setting the mask to 0
|
|
||||||
self.free_block_mask[block_indices] = 0
|
|
||||||
|
|
||||||
def free(self, block_indices: Optional[List[int]]):
|
|
||||||
if block_indices is not None and block_indices:
|
|
||||||
# Reset mask
|
|
||||||
self.free_block_mask[block_indices] = 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashCausalLMBatch(Batch):
|
class FlashCausalLMBatch(Batch):
|
||||||
|
@ -481,7 +399,6 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
max_blocks = max(max_blocks, len(request_block_table))
|
max_blocks = max(max_blocks, len(request_block_table))
|
||||||
|
|
||||||
global CACHE_MANAGER
|
|
||||||
block_indices_to_free = []
|
block_indices_to_free = []
|
||||||
# Iterate on all requests
|
# Iterate on all requests
|
||||||
for i, r in enumerate(self.requests):
|
for i, r in enumerate(self.requests):
|
||||||
|
@ -489,7 +406,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
if r.id not in requests_idx_mapping.keys():
|
if r.id not in requests_idx_mapping.keys():
|
||||||
block_indices_to_free.extend(self.block_tables[i])
|
block_indices_to_free.extend(self.block_tables[i])
|
||||||
# Free blocks
|
# Free blocks
|
||||||
CACHE_MANAGER.free(block_indices_to_free)
|
get_cache_manager().free(block_indices_to_free)
|
||||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||||
self.block_tables = None
|
self.block_tables = None
|
||||||
|
|
||||||
|
@ -508,7 +425,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
# Move to GPU now that we have the whole tensor
|
# Move to GPU now that we have the whole tensor
|
||||||
slot_indices = slot_indices.to(device)
|
slot_indices = slot_indices.to(device)
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return type(self)(
|
||||||
batch_id=self.batch_id,
|
batch_id=self.batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
@ -665,7 +582,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
b.block_tables = None
|
b.block_tables = None
|
||||||
del b
|
del b
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return cls(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
@ -698,9 +615,10 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.block_tables is not None and self.block_tables:
|
if self.block_tables is not None and self.block_tables:
|
||||||
global CACHE_MANAGER
|
|
||||||
# Free blocks
|
# Free blocks
|
||||||
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
|
get_cache_manager().free(
|
||||||
|
list(itertools.chain.from_iterable(self.block_tables))
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.requests)
|
return len(self.requests)
|
||||||
|
@ -718,6 +636,7 @@ class FlashCausalLM(Model):
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
|
@ -731,6 +650,7 @@ class FlashCausalLM(Model):
|
||||||
device=device,
|
device=device,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
|
sliding_window=sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -738,15 +658,14 @@ class FlashCausalLM(Model):
|
||||||
return FlashCausalLMBatch
|
return FlashCausalLMBatch
|
||||||
|
|
||||||
def warmup(self, batch: FlashCausalLMBatch):
|
def warmup(self, batch: FlashCausalLMBatch):
|
||||||
global CACHE_MANAGER
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
try:
|
try:
|
||||||
CACHE_MANAGER = CacheManager(
|
cache_manager = set_cache_manager(
|
||||||
batch.blocks,
|
batch.blocks,
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
|
self.sliding_window is not None,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
@ -775,48 +694,36 @@ class FlashCausalLM(Model):
|
||||||
num_blocks = (
|
num_blocks = (
|
||||||
int(free_memory // total_cache_size)
|
int(free_memory // total_cache_size)
|
||||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||||
+ CACHE_MANAGER.num_blocks
|
+ cache_manager.num_blocks
|
||||||
)
|
)
|
||||||
|
|
||||||
del CACHE_MANAGER
|
|
||||||
del batch
|
del batch
|
||||||
torch.cuda.empty_cache()
|
del cache_manager
|
||||||
|
|
||||||
CACHE_MANAGER = CacheManager(
|
set_cache_manager(
|
||||||
num_blocks,
|
num_blocks,
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
|
self.sliding_window is not None,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return int(num_blocks * BLOCK_SIZE)
|
return int(num_blocks * BLOCK_SIZE)
|
||||||
|
|
||||||
def forward(
|
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
position_ids: torch.Tensor,
|
|
||||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
input_lengths: torch.Tensor,
|
|
||||||
max_s: int,
|
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
global CACHE_MANAGER
|
|
||||||
|
|
||||||
# Model Forward
|
# Model Forward
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=batch.input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=batch.position_ids,
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||||
kv_cache=CACHE_MANAGER.kv_cache,
|
kv_cache=get_cache_manager().kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=batch.block_tables_tensor,
|
||||||
slots=slots,
|
slots=batch.slots[batch.slot_indices],
|
||||||
input_lengths=input_lengths,
|
input_lengths=batch.input_lengths_tensor,
|
||||||
max_s=max_s,
|
max_s=batch.max_seqlen,
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
|
@ -828,19 +735,19 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
if batch.needed_blocks_slots:
|
if batch.needed_blocks_slots:
|
||||||
# Allocate blocks to this batch
|
# Allocate blocks to this batch
|
||||||
CACHE_MANAGER.allocate(batch)
|
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
|
||||||
|
batch.needed_blocks_slots,
|
||||||
|
batch.blocks,
|
||||||
|
batch.max_blocks,
|
||||||
|
batch.input_ids.device,
|
||||||
|
)
|
||||||
|
batch.needed_blocks_slots = None
|
||||||
|
batch.block_tables = block_tables
|
||||||
|
batch.block_tables_tensor = block_tables_tensor
|
||||||
|
batch.slots = slots
|
||||||
|
|
||||||
try:
|
try:
|
||||||
out = self.forward(
|
out = self.forward(batch)
|
||||||
batch.input_ids,
|
|
||||||
batch.position_ids,
|
|
||||||
batch.cu_seqlen_prefill,
|
|
||||||
batch.block_tables_tensor,
|
|
||||||
batch.slots[batch.slot_indices],
|
|
||||||
batch.input_lengths_tensor,
|
|
||||||
batch.max_seqlen,
|
|
||||||
batch.prefill_head_indices,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
del batch
|
del batch
|
||||||
raise e
|
raise e
|
||||||
|
|
|
@ -0,0 +1,357 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from opentelemetry import trace
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from transformers.models.llama import LlamaTokenizerFast
|
||||||
|
from typing import Optional, Tuple, Type
|
||||||
|
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.models import FlashCausalLM
|
||||||
|
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||||
|
from text_generation_server.models.cache_manager import (
|
||||||
|
get_cache_manager,
|
||||||
|
set_cache_manager,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||||
|
FlashMistralForCausalLM,
|
||||||
|
MistralConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
HeterogeneousNextTokenChooser,
|
||||||
|
StoppingCriteria,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
# Will be set in init
|
||||||
|
SLIDING_WINDOW: Optional[int] = None
|
||||||
|
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Adds windowing logic to FlashCausalLMBatch
|
||||||
|
@dataclass
|
||||||
|
class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
|
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||||
|
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pb(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "FlashCausalLMBatch":
|
||||||
|
global SLIDING_WINDOW
|
||||||
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
|
||||||
|
batch_inputs = []
|
||||||
|
max_truncation = 0
|
||||||
|
for r in pb.requests:
|
||||||
|
batch_inputs.append(r.inputs)
|
||||||
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
|
||||||
|
batch_tokenized_inputs = tokenizer(
|
||||||
|
batch_inputs, truncation=True, max_length=max_truncation
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
|
position_ids = []
|
||||||
|
cu_seqlen_prefill = [0]
|
||||||
|
needed_blocks_slots = []
|
||||||
|
start_slots = []
|
||||||
|
slot_indices = []
|
||||||
|
prefill_cache_indices = []
|
||||||
|
|
||||||
|
input_lengths = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
all_input_ids = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
all_prefill_logprobs = True
|
||||||
|
no_prefill_logprobs = True
|
||||||
|
prefill_head_indices = []
|
||||||
|
prefill_next_token_indices = []
|
||||||
|
prefill_cu_outlens = [0]
|
||||||
|
|
||||||
|
next_token_chooser_parameters = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
|
||||||
|
# Cumulative length
|
||||||
|
cumulative_length = 0
|
||||||
|
cumulative_max_length = 0
|
||||||
|
prefill_out_cumulative_length = 0
|
||||||
|
|
||||||
|
blocks = 0
|
||||||
|
max_seqlen = 0
|
||||||
|
max_length = 0
|
||||||
|
max_blocks = 0
|
||||||
|
|
||||||
|
# Parse batch
|
||||||
|
for i, (r, tokenized_input) in enumerate(
|
||||||
|
zip(pb.requests, batch_tokenized_inputs)
|
||||||
|
):
|
||||||
|
# request id -> idx in list mapping
|
||||||
|
requests_idx_mapping[r.id] = i
|
||||||
|
|
||||||
|
tokenized_input = tokenized_input[-r.truncate :]
|
||||||
|
|
||||||
|
input_length = len(tokenized_input)
|
||||||
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
|
prefix_offsets.append(input_length - 5)
|
||||||
|
read_offsets.append(input_length)
|
||||||
|
|
||||||
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
|
# Position ids
|
||||||
|
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||||
|
position_ids.append(request_position_ids)
|
||||||
|
|
||||||
|
# Add cumulative lengths of all previous inputs
|
||||||
|
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||||
|
|
||||||
|
next_token_chooser_parameters.append(r.parameters)
|
||||||
|
|
||||||
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
|
r.stopping_parameters, tokenizer
|
||||||
|
)
|
||||||
|
max_new_tokens = stopping_criteria.max_new_tokens
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
|
|
||||||
|
# Paged attention
|
||||||
|
# Remove one as the first token des not have a past
|
||||||
|
total_tokens = input_length + max_new_tokens - 1
|
||||||
|
|
||||||
|
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||||
|
needed_blocks = min(
|
||||||
|
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
|
||||||
|
)
|
||||||
|
blocks += needed_blocks
|
||||||
|
|
||||||
|
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||||
|
start_slots.append(cumulative_max_length)
|
||||||
|
|
||||||
|
request_slot_indices = torch.arange(
|
||||||
|
cumulative_max_length,
|
||||||
|
cumulative_max_length + input_length,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
slot_indices.append(request_slot_indices)
|
||||||
|
|
||||||
|
# Create tensor to slice into the kv tensor in prefill
|
||||||
|
request_prefill_cache_indices = torch.arange(
|
||||||
|
cumulative_length + max(0, input_length - SLIDING_WINDOW),
|
||||||
|
cumulative_length + input_length,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||||
|
|
||||||
|
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||||
|
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||||
|
|
||||||
|
if r.prefill_logprobs:
|
||||||
|
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||||
|
prefill_next_token_indices.append(
|
||||||
|
prefill_out_cumulative_length + input_length - 1
|
||||||
|
)
|
||||||
|
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||||
|
prefill_out_cumulative_length += input_length
|
||||||
|
else:
|
||||||
|
prefill_head_indices.append(
|
||||||
|
torch.tensor(
|
||||||
|
[cumulative_length + input_length - 1], dtype=torch.int32
|
||||||
|
)
|
||||||
|
)
|
||||||
|
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||||
|
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||||
|
prefill_out_cumulative_length += 1
|
||||||
|
|
||||||
|
# Update
|
||||||
|
cumulative_length += input_length
|
||||||
|
cumulative_max_length += total_tokens
|
||||||
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
|
max_blocks = max(max_blocks, needed_blocks)
|
||||||
|
max_length = max(max_length, input_length + max_new_tokens)
|
||||||
|
|
||||||
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
|
next_token_chooser_parameters, dtype, device
|
||||||
|
)
|
||||||
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
# Padded all_input_ids_tensor
|
||||||
|
all_input_ids_tensor = np.zeros(
|
||||||
|
(len(all_input_ids), max_length), dtype=np.int64
|
||||||
|
)
|
||||||
|
for i, input_ids in enumerate(all_input_ids):
|
||||||
|
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||||
|
|
||||||
|
# Create tensors on device
|
||||||
|
all_input_ids_tensor = torch.tensor(
|
||||||
|
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(pb.requests) > 1:
|
||||||
|
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||||
|
position_ids = torch.cat(position_ids)
|
||||||
|
slot_indices = torch.cat(slot_indices)
|
||||||
|
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||||
|
else:
|
||||||
|
input_ids = all_input_ids[0]
|
||||||
|
position_ids = position_ids[0]
|
||||||
|
slot_indices = slot_indices[0]
|
||||||
|
prefill_cache_indices = prefill_cache_indices[0]
|
||||||
|
|
||||||
|
cu_seqlen_prefill = torch.tensor(
|
||||||
|
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids = position_ids.to(device)
|
||||||
|
slot_indices = slot_indices.to(device)
|
||||||
|
prefill_cache_indices = prefill_cache_indices.to(device)
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||||
|
input_lengths_tensor = torch.tensor(
|
||||||
|
input_lengths, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
if all_prefill_logprobs:
|
||||||
|
prefill_head_indices = None
|
||||||
|
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||||
|
elif no_prefill_logprobs:
|
||||||
|
prefill_head_indices = cu_seqlen_prefill[1:] - 1
|
||||||
|
prefill_next_token_indices = None
|
||||||
|
else:
|
||||||
|
prefill_head_indices = torch.tensor(
|
||||||
|
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
prefill_next_token_indices = torch.tensor(
|
||||||
|
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
top_n_tokens_tensor = torch.tensor(
|
||||||
|
top_n_tokens, device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
batch_id=pb.id,
|
||||||
|
requests=pb.requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
start_slots=start_slots,
|
||||||
|
slot_indices=slot_indices,
|
||||||
|
needed_blocks_slots=needed_blocks_slots,
|
||||||
|
block_tables=None,
|
||||||
|
block_tables_tensor=None,
|
||||||
|
slots=None,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
prefill_head_indices=prefill_head_indices,
|
||||||
|
prefill_next_token_indices=prefill_next_token_indices,
|
||||||
|
prefill_cu_outlens=prefill_cu_outlens,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
all_input_ids=all_input_ids,
|
||||||
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
|
next_token_chooser=next_token_chooser,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
blocks=blocks,
|
||||||
|
max_blocks=max_blocks,
|
||||||
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlashMistral(FlashCausalLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
global SLIDING_WINDOW
|
||||||
|
global SLIDING_WINDOW_BLOCKS
|
||||||
|
|
||||||
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||||
|
|
||||||
|
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = MistralConfig.from_pretrained(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
config.quantize = quantize
|
||||||
|
|
||||||
|
# Set context windows
|
||||||
|
SLIDING_WINDOW = config.sliding_window
|
||||||
|
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
if config.quantize in ["gptq", "awq"]:
|
||||||
|
weights._set_gptq_params(model_id)
|
||||||
|
|
||||||
|
model = FlashMistralForCausalLM(config, weights)
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(FlashMistral, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_layers=len(model.model.layers),
|
||||||
|
num_kv_heads=model.model.num_key_value_heads,
|
||||||
|
head_size=model.model.head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
sliding_window=config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||||
|
return FlashMistralBatch
|
||||||
|
|
||||||
|
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Model Forward
|
||||||
|
logits = self.model.forward(
|
||||||
|
input_ids=batch.input_ids,
|
||||||
|
position_ids=batch.position_ids,
|
||||||
|
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||||
|
kv_cache=get_cache_manager().kv_cache,
|
||||||
|
block_tables=batch.block_tables_tensor,
|
||||||
|
slots=batch.slots[batch.slot_indices],
|
||||||
|
input_lengths=batch.input_lengths_tensor,
|
||||||
|
max_s=batch.max_seqlen,
|
||||||
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
|
lm_head_indices=batch.prefill_head_indices,
|
||||||
|
)
|
||||||
|
if batch.prefill_cache_indices is not None:
|
||||||
|
batch.prefill_cache_indices = None
|
||||||
|
return logits
|
|
@ -21,6 +21,7 @@ class Model(ABC):
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
@ -30,6 +31,7 @@ class Model(ABC):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
self.has_position_ids = (
|
self.has_position_ids = (
|
||||||
inspect.signature(model.forward).parameters.get("position_ids", None)
|
inspect.signature(model.forward).parameters.get("position_ids", None)
|
||||||
|
@ -40,10 +42,14 @@ class Model(ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def info(self) -> InfoResponse:
|
def info(self) -> InfoResponse:
|
||||||
|
if self.requires_padding and self.sliding_window is not None:
|
||||||
|
raise NotImplementedError("sliding_window is not implemented with padding")
|
||||||
|
|
||||||
return InfoResponse(
|
return InfoResponse(
|
||||||
requires_padding=self.requires_padding,
|
requires_padding=self.requires_padding,
|
||||||
dtype=str(self.dtype),
|
dtype=str(self.dtype),
|
||||||
device_type=self.device.type,
|
device_type=self.device.type,
|
||||||
|
window_size=self.sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -57,6 +57,7 @@ def attention(
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
):
|
):
|
||||||
if HAS_FLASH_ATTN_V2:
|
if HAS_FLASH_ATTN_V2:
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
@ -72,11 +73,18 @@ def attention(
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
False,
|
False,
|
||||||
True,
|
True,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if HAS_FLASH_ATTN:
|
if HAS_FLASH_ATTN:
|
||||||
|
if window_size_left != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if k.shape[1] != q.shape[1]:
|
||||||
# MQA expand
|
# MQA expand
|
||||||
|
|
|
@ -53,6 +53,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Monkey patching
|
# Monkey patching
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_layer_norm(cls, prefix, weights, eps):
|
def load_layer_norm(cls, prefix, weights, eps):
|
||||||
|
|
|
@ -8,7 +8,9 @@ def main():
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8")
|
output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"
|
final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"
|
||||||
|
|
||||||
filename = "docs/source/basic_tutorials/launcher.md"
|
filename = "docs/source/basic_tutorials/launcher.md"
|
||||||
|
@ -16,16 +18,20 @@ def main():
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
doc = f.read()
|
doc = f.read()
|
||||||
if doc != final_doc:
|
if doc != final_doc:
|
||||||
|
|
||||||
tmp = "launcher.md"
|
tmp = "launcher.md"
|
||||||
with open(tmp, "w") as g:
|
with open(tmp, "w") as g:
|
||||||
g.write(final_doc)
|
g.write(final_doc)
|
||||||
diff = subprocess.run(["diff",tmp, filename], capture_output=True).stdout.decode("utf-8")
|
diff = subprocess.run(
|
||||||
|
["diff", tmp, filename], capture_output=True
|
||||||
|
).stdout.decode("utf-8")
|
||||||
print(diff)
|
print(diff)
|
||||||
raise Exception("Doc is not up-to-date, run `python update_doc.py` in order to update it")
|
raise Exception(
|
||||||
|
"Doc is not up-to-date, run `python update_doc.py` in order to update it"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
f.write(final_doc)
|
f.write(final_doc)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue