Adding Llava-Next (Llava 1.6) with full support. (#1709)
# What does this PR do? - Changed all models to extract `embed_tokens` in order to enable llava to separately call the embeddings and the core model layers. - Added VlmCausalLM to inherit from FlashMistral in order to be maximally supported. The only added logics sits on top and parses images into pixel values, preallocates input_ids space for the image embeddings, and passes them for the model. - Added Clip for the vision tower. - Didn't add flash for the vision tower since there's no padding anyway. - Added heuristic (potentially incomplete) to calculate number of features *before* calculating the clip patches (allows for easier logic reuse of the LLM under the hood). Still needs to be done: - [x] Implement the image parsing in the controller side, to avoid downloading n times per TP shard and also refusing requests too large early and avoid issues where the truncation actually truncates the image. - [ ] Make sure it works with quantization properly. - [x] Make sure it works with TP>1 <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
106d8ee818
commit
4634b00c2a
|
@ -40,6 +40,12 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aligned-vec"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.13"
|
||||
|
@ -94,12 +100,35 @@ version = "1.0.81"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247"
|
||||
|
||||
[[package]]
|
||||
name = "arbitrary"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
version = "1.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
|
||||
|
||||
[[package]]
|
||||
name = "arg_enum_proc_macro"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.55",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
|
||||
|
||||
[[package]]
|
||||
name = "async-rustls"
|
||||
version = "0.3.0"
|
||||
|
@ -150,6 +179,20 @@ version = "1.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
|
||||
|
||||
[[package]]
|
||||
name = "av1-grain"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arrayvec",
|
||||
"log",
|
||||
"nom",
|
||||
"num-rational",
|
||||
"v_frame",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "average"
|
||||
version = "0.14.2"
|
||||
|
@ -161,6 +204,15 @@ dependencies = [
|
|||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "avif-serialize"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "876c75a42f6364451a033496a14c44bffe41f5f4a8236f697391f11024e596d2"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "awaitdrop"
|
||||
version = "0.1.2"
|
||||
|
@ -267,6 +319,12 @@ version = "0.21.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
@ -282,6 +340,12 @@ version = "0.6.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||
|
||||
[[package]]
|
||||
name = "bit_field"
|
||||
version = "0.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
|
@ -294,6 +358,12 @@ version = "2.5.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
||||
|
||||
[[package]]
|
||||
name = "bitstream-io"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06c9989a51171e2e81038ab168b6ae22886fe9ded214430dbb4f41c28cf176da"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
|
@ -303,6 +373,12 @@ dependencies = [
|
|||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "built"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38d17f4d6e4dc36d1a02fbedc2753a096848e7c1b0772f7654eab8e2c927dd53"
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.15.4"
|
||||
|
@ -315,6 +391,12 @@ version = "0.6.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1e5f035d16fc623ae5f74981db80a439803888314e3a555fd6f04acd51a3205"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
|
@ -370,6 +452,20 @@ name = "cc"
|
|||
version = "1.0.90"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-expr"
|
||||
version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa50868b64a9a6fda9d593ce778849ea8715cd2a3d2cc17ffdb4a2f2f2f1961d"
|
||||
dependencies = [
|
||||
"smallvec",
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
|
@ -423,6 +519,12 @@ version = "0.7.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
|
||||
|
||||
[[package]]
|
||||
name = "color_quant"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.0"
|
||||
|
@ -535,6 +637,12 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
|
@ -736,6 +844,22 @@ dependencies = [
|
|||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exr"
|
||||
version = "1.72.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4"
|
||||
dependencies = [
|
||||
"bit_field",
|
||||
"flume",
|
||||
"half",
|
||||
"lebe",
|
||||
"miniz_oxide",
|
||||
"rayon-core",
|
||||
"smallvec",
|
||||
"zune-inflate",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.11.0"
|
||||
|
@ -752,6 +876,15 @@ version = "2.0.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984"
|
||||
|
||||
[[package]]
|
||||
name = "fdeflate"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645"
|
||||
dependencies = [
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
version = "0.4.2"
|
||||
|
@ -780,6 +913,15 @@ version = "1.0.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
|
||||
dependencies = [
|
||||
"spin 0.9.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -941,6 +1083,16 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gif"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fb2d69b19215e18bb912fa30f7ce15846e301408695e44e0ef719f1da9e19f2"
|
||||
dependencies = [
|
||||
"color_quant",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.28.1"
|
||||
|
@ -976,6 +1128,16 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.12.3"
|
||||
|
@ -1161,6 +1323,45 @@ dependencies = [
|
|||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image"
|
||||
version = "0.25.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"byteorder",
|
||||
"color_quant",
|
||||
"exr",
|
||||
"gif",
|
||||
"image-webp",
|
||||
"num-traits",
|
||||
"png",
|
||||
"qoi",
|
||||
"ravif",
|
||||
"rayon",
|
||||
"rgb",
|
||||
"tiff",
|
||||
"zune-core",
|
||||
"zune-jpeg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image-webp"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a84a25dcae3ac487bc24ef280f9e20c79c9b1a3e5e32cbed3041d1c514aa87c"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "imgref"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44feda355f4159a7c757171a77de25daf6411e217b4cabd03bd6650690468126"
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.3"
|
||||
|
@ -1223,6 +1424,17 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "interpolate_name"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.55",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.9.0"
|
||||
|
@ -1271,6 +1483,21 @@ version = "1.0.11"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
|
||||
|
||||
[[package]]
|
||||
name = "jobserver"
|
||||
version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jpeg-decoder"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.69"
|
||||
|
@ -1316,12 +1543,29 @@ version = "1.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "lebe"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.153"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
|
||||
|
||||
[[package]]
|
||||
name = "libfuzzer-sys"
|
||||
version = "0.4.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a96cfd5557eb82f2b83fed4955246c988d331975a002961b07c81584d107e7f7"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"cc",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
|
@ -1361,6 +1605,15 @@ version = "0.4.21"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||
|
||||
[[package]]
|
||||
name = "loop9"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062"
|
||||
dependencies = [
|
||||
"imgref",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mach2"
|
||||
version = "0.4.2"
|
||||
|
@ -1407,6 +1660,16 @@ version = "0.7.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "maybe-rayon"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.2"
|
||||
|
@ -1486,8 +1749,8 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "minijinja"
|
||||
version = "1.0.16"
|
||||
source = "git+https://github.com/mitsuhiko/minijinja.git?branch=main#82d0160b5513844e5429db084f1cbdd3313ed482"
|
||||
version = "1.0.12"
|
||||
source = "git+https://github.com/mitsuhiko/minijinja.git?rev=5cd4efb#5cd4efb9e2639247df275fe6e22a5dbe0ce71b28"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
@ -1505,6 +1768,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7"
|
||||
dependencies = [
|
||||
"adler",
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1583,6 +1847,12 @@ dependencies = [
|
|||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "new_debug_unreachable"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
|
||||
|
||||
[[package]]
|
||||
name = "ngrok"
|
||||
version = "0.13.1"
|
||||
|
@ -1642,6 +1912,12 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "noop_proc_macro"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8"
|
||||
|
||||
[[package]]
|
||||
name = "ntapi"
|
||||
version = "0.4.1"
|
||||
|
@ -1707,6 +1983,17 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.55",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
|
@ -2071,6 +2358,19 @@ version = "0.3.30"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.17.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"crc32fast",
|
||||
"fdeflate",
|
||||
"flate2",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.6.0"
|
||||
|
@ -2132,6 +2432,25 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "profiling"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43d84d1d7a6ac92673717f9f6d1518374ef257669c24ebc5ac25d5033828be58"
|
||||
dependencies = [
|
||||
"profiling-procmacros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "profiling-procmacros"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"syn 2.0.55",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prost"
|
||||
version = "0.11.9"
|
||||
|
@ -2209,6 +2528,15 @@ dependencies = [
|
|||
"prost 0.12.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qoi"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.11.1"
|
||||
|
@ -2225,6 +2553,12 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.35"
|
||||
|
@ -2281,6 +2615,56 @@ dependencies = [
|
|||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rav1e"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"arg_enum_proc_macro",
|
||||
"arrayvec",
|
||||
"av1-grain",
|
||||
"bitstream-io",
|
||||
"built",
|
||||
"cfg-if",
|
||||
"interpolate_name",
|
||||
"itertools 0.12.1",
|
||||
"libc",
|
||||
"libfuzzer-sys",
|
||||
"log",
|
||||
"maybe-rayon",
|
||||
"new_debug_unreachable",
|
||||
"noop_proc_macro",
|
||||
"num-derive",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"profiling",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"simd_helpers",
|
||||
"system-deps",
|
||||
"thiserror",
|
||||
"v_frame",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ravif"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc13288f5ab39e6d7c9d501759712e6969fcc9734220846fc9ed26cae2cc4234"
|
||||
dependencies = [
|
||||
"avif-serialize",
|
||||
"imgref",
|
||||
"loop9",
|
||||
"quick-error",
|
||||
"rav1e",
|
||||
"rayon",
|
||||
"rgb",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "10.7.0"
|
||||
|
@ -2431,6 +2815,15 @@ dependencies = [
|
|||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rgb"
|
||||
version = "0.8.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.16.20"
|
||||
|
@ -2695,6 +3088,15 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
|
@ -2766,6 +3168,21 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
|
||||
|
||||
[[package]]
|
||||
name = "simd_helpers"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6"
|
||||
dependencies = [
|
||||
"quote",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sketches-ddsketch"
|
||||
version = "0.2.2"
|
||||
|
@ -2817,6 +3234,9 @@ name = "spin"
|
|||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spm_precompiled"
|
||||
|
@ -2933,6 +3353,19 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-deps"
|
||||
version = "6.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349"
|
||||
dependencies = [
|
||||
"cfg-expr",
|
||||
"heck 0.5.0",
|
||||
"pkg-config",
|
||||
"toml",
|
||||
"version-compare",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabled"
|
||||
version = "0.14.0"
|
||||
|
@ -2957,6 +3390,12 @@ dependencies = [
|
|||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f"
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.10.1"
|
||||
|
@ -3029,10 +3468,12 @@ dependencies = [
|
|||
"async-stream",
|
||||
"axum",
|
||||
"axum-tracing-opentelemetry",
|
||||
"base64 0.22.0",
|
||||
"clap",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"hf-hub",
|
||||
"image",
|
||||
"init-tracing-opentelemetry",
|
||||
"jsonschema",
|
||||
"metrics",
|
||||
|
@ -3092,6 +3533,17 @@ dependencies = [
|
|||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiff"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e"
|
||||
dependencies = [
|
||||
"flate2",
|
||||
"jpeg-decoder",
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.34"
|
||||
|
@ -3295,6 +3747,40 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.8.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"toml_edit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_datetime"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.22.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4"
|
||||
dependencies = [
|
||||
"indexmap 2.2.6",
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
version = "0.9.2"
|
||||
|
@ -3699,6 +4185,17 @@ version = "1.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0"
|
||||
|
||||
[[package]]
|
||||
name = "v_frame"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6f32aaa24bacd11e488aa9ba66369c7cd514885742c9fe08cfe85884db3e92b"
|
||||
dependencies = [
|
||||
"aligned-vec",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.0"
|
||||
|
@ -3727,6 +4224,12 @@ dependencies = [
|
|||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version-compare"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
|
@ -3853,6 +4356,12 @@ dependencies = [
|
|||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
|
@ -4113,6 +4622,15 @@ version = "0.52.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.50.0"
|
||||
|
@ -4160,3 +4678,27 @@ dependencies = [
|
|||
"crossbeam-utils",
|
||||
"flate2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zune-core"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a"
|
||||
|
||||
[[package]]
|
||||
name = "zune-inflate"
|
||||
version = "0.2.54"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02"
|
||||
dependencies = [
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zune-jpeg"
|
||||
version = "0.4.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448"
|
||||
dependencies = [
|
||||
"zune-core",
|
||||
]
|
||||
|
|
|
@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
|
|||
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
|
||||
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
|
||||
- [Phi](https://huggingface.co/microsoft/phi-2)
|
||||
- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal)
|
||||
- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)
|
||||
|
||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||
|
||||
|
|
|
@ -277,6 +277,8 @@ def launcher(event_loop):
|
|||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
|
@ -314,6 +316,12 @@ def launcher(event_loop):
|
|||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
||||
|
@ -347,6 +355,8 @@ def launcher(event_loop):
|
|||
disable_grammar_support: bool = False,
|
||||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
|
@ -367,6 +377,12 @@ def launcher(event_loop):
|
|||
args.append(revision)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
||||
client = docker.from_env()
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 6,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -10.5,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -12.140625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0654297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1014,
|
||||
"logprob": -2.7460938,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 6032,
|
||||
"logprob": -1.359375,
|
||||
"special": false,
|
||||
"text": " purpose"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 456,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 1369,
|
||||
"logprob": -0.40063477,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request\nThe purpose of this test"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,73 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.00756073,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.20117188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 16114,
|
||||
"logprob": -1.2597656,
|
||||
"special": false,
|
||||
"text": "Once"
|
||||
},
|
||||
{
|
||||
"id": 3714,
|
||||
"logprob": -0.20825195,
|
||||
"special": false,
|
||||
"text": " upon"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.00178051,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 727,
|
||||
"logprob": -0.011955261,
|
||||
"special": false,
|
||||
"text": " time"
|
||||
},
|
||||
{
|
||||
"id": 28725,
|
||||
"logprob": -0.17541504,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 736,
|
||||
"logprob": -0.91308594,
|
||||
"special": false,
|
||||
"text": " there"
|
||||
},
|
||||
{
|
||||
"id": 403,
|
||||
"logprob": -0.058410645,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.009689331,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nOnce upon a time, there was a"
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 9,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
|
@ -14,7 +14,7 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 16017,
|
||||
"logprob": -0.30908203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " blue"
|
||||
},
|
||||
|
@ -26,39 +26,45 @@
|
|||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -0.28271484,
|
||||
"logprob": -0.4716797,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 15484,
|
||||
"logprob": -1.7929688,
|
||||
"id": 261,
|
||||
"logprob": -0.044677734,
|
||||
"special": false,
|
||||
"text": "appear"
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": -0.8935547,
|
||||
"id": 35622,
|
||||
"logprob": -0.79589844,
|
||||
"special": false,
|
||||
"text": "ed"
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 281,
|
||||
"id": 263,
|
||||
"logprob": -1.2958984,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 305,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"id": 35622,
|
||||
"logprob": -1.1630859,
|
||||
"special": false,
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 20495,
|
||||
"logprob": -0.32299805,
|
||||
"special": false,
|
||||
"text": " sky"
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
|
@ -66,7 +72,8 @@
|
|||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
|
||||
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
|
||||
}
|
||||
|
|
|
@ -33,6 +33,9 @@ async def test_idefics(idefics, response_snapshot):
|
|||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert (
|
||||
response.generated_text == " \nAssistant: A rooster stands"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -48,6 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
|
|||
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert (
|
||||
generated_texts[0] == " \nAssistant: A rooster stands"
|
||||
), f"{response.generated_text}"
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llava_next_handle(launcher):
|
||||
with launcher(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
num_shard=4,
|
||||
max_input_length=4000,
|
||||
max_total_tokens=4096,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llava_next(flash_llava_next_handle):
|
||||
await flash_llava_next_handle.health(300)
|
||||
return flash_llava_next_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
response = await flash_llava_next.generate(
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
assert (
|
||||
response.generated_text == "\n\nOnce upon a time, there was a"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||
response = await flash_llava_next.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 6
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_load(
|
||||
flash_llava_next, generate_load, response_snapshot
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_llava_next,
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
assert generated_texts[0] == "\n\nOnce upon a time, there was a"
|
||||
assert len(generated_texts) == 4
|
||||
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -45,7 +45,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
|
|||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 9
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
|
|
@ -44,10 +44,12 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
|||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = "0.22.0"
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
@ -112,10 +112,15 @@ impl Client {
|
|||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str("![](");
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
||||
inputs,
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LlavaNext {
|
||||
text_config: TextConfig,
|
||||
vision_config: VisionConfig,
|
||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
fn get_anyres_image_grid_shape(
|
||||
height: usize,
|
||||
width: usize,
|
||||
grid_pinpoints: &[(usize, usize)],
|
||||
patch_size: usize,
|
||||
) -> (usize, usize) {
|
||||
let (height, width) = select_best_resolution(height, width, grid_pinpoints);
|
||||
(height / patch_size, width / patch_size)
|
||||
}
|
||||
|
||||
/// Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
/// This is done by calculating the effective and wasted resolution for each possible resolution.
|
||||
/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
|
||||
fn select_best_resolution(
|
||||
original_height: usize,
|
||||
original_width: usize,
|
||||
possible_resolutions: &[(usize, usize)],
|
||||
) -> (usize, usize) {
|
||||
let mut best_fit = None;
|
||||
let mut max_effective_resolution = 0;
|
||||
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
||||
|
||||
for (height, width) in possible_resolutions {
|
||||
let wscale = *width as f32 / original_width as f32;
|
||||
let hscale = *height as f32 / original_height as f32;
|
||||
// f32 partial ord.
|
||||
let scale = if wscale > hscale { hscale } else { wscale };
|
||||
let downscaled_width = (*width as f32 * scale) as usize;
|
||||
let downscaled_height = (*height as f32 * scale) as usize;
|
||||
let effective_resolution = std::cmp::min(
|
||||
downscaled_width * downscaled_height,
|
||||
original_width * original_height,
|
||||
);
|
||||
let wasted_resolution = (width * height) - effective_resolution;
|
||||
|
||||
if effective_resolution > max_effective_resolution
|
||||
|| (effective_resolution == max_effective_resolution
|
||||
&& (wasted_resolution as f32) < min_wasted_resolution)
|
||||
{
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution as f32;
|
||||
best_fit = Some((*height, *width));
|
||||
}
|
||||
}
|
||||
|
||||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
|
||||
|
||||
impl LlavaNext {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let image_size = self.vision_config.image_size;
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
// Ceil
|
||||
let height_of_patch = (height * npatches + width - 1) / width;
|
||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
||||
// They are only added after width
|
||||
let newline_features = height_of_patch * num_patch_width;
|
||||
// The base patch covers the entire image
|
||||
let base_features = npatches.pow(2);
|
||||
unpadded_features + newline_features + base_features
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ClipVisionModel {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Config {
|
||||
LlavaNext(LlavaNext),
|
||||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
PhiMsft,
|
||||
Llama,
|
||||
Baichuan,
|
||||
Gemma,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
Mixtral,
|
||||
Starcoder2,
|
||||
Qwen2,
|
||||
Opt,
|
||||
T5,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TextConfig {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VisionConfig {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_llava_next_features() {
|
||||
let config = LlavaNext {
|
||||
text_config: TextConfig {},
|
||||
vision_config: VisionConfig {
|
||||
image_size: 336,
|
||||
patch_size: 14,
|
||||
},
|
||||
image_grid_pinpoints: vec![
|
||||
(336, 672),
|
||||
(672, 336),
|
||||
(672, 672),
|
||||
(1008, 336),
|
||||
(336, 1008),
|
||||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
pub mod config;
|
||||
mod health;
|
||||
/// Text Generation Inference Webserver
|
||||
mod infer;
|
||||
|
|
|
@ -13,6 +13,7 @@ use std::io::BufReader;
|
|||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
|
@ -191,15 +192,19 @@ async fn main() -> Result<(), RouterError> {
|
|||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer, model_info) = if local_model {
|
||||
let (tokenizer, model_info, config) = if local_model {
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| serde_json::from_str(c).ok());
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else if let Some(api) = api.clone() {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
|
@ -212,6 +217,19 @@ async fn main() -> Result<(), RouterError> {
|
|||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
|
||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
|
||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo {
|
||||
|
@ -221,7 +239,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
});
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else {
|
||||
// No API and no local model
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
|
@ -229,6 +247,8 @@ async fn main() -> Result<(), RouterError> {
|
|||
));
|
||||
};
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||
tracing::info!("Using local tokenizer config from user specified path");
|
||||
|
@ -363,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_batch_size,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
config,
|
||||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
|
@ -164,7 +165,8 @@ async fn generate(
|
|||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
// Do not long ultra long inputs, like image payloads.
|
||||
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
|
||||
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
|
@ -1154,6 +1156,7 @@ pub async fn run(
|
|||
max_batch_size: Option<usize>,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
|
@ -1236,6 +1239,7 @@ pub async fn run(
|
|||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
use crate::config::Config;
|
||||
/// Payload validation logic
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{
|
||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::TruncationDirection;
|
||||
// use tokenizers::TruncationDirection;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{io::Reader as ImageReader, ImageFormat};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{instrument, Span};
|
||||
|
@ -34,6 +38,7 @@ impl Validation {
|
|||
pub(crate) fn new(
|
||||
workers: usize,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
|
@ -50,12 +55,13 @@ impl Validation {
|
|||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
let config_clone = config.clone();
|
||||
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
|
||||
senders.push(tokenizer_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokenizer_worker(tokenizer_clone, tokenizer_receiver)
|
||||
tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -408,48 +414,137 @@ async fn round_robin_task(
|
|||
}
|
||||
|
||||
/// Start tokenization workers
|
||||
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
|
||||
fn tokenizer_worker(
|
||||
tokenizer: Tokenizer,
|
||||
config: Option<Config>,
|
||||
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
||||
) {
|
||||
// Loop over requests
|
||||
let is_multimodal = {
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
vocab.contains_key("<image>")
|
||||
};
|
||||
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, &config))
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
|
||||
match mimetype {
|
||||
"image/png" => Some(ImageFormat::Png),
|
||||
"image/jpeg" => Some(ImageFormat::Jpeg),
|
||||
"image/jpg" => Some(ImageFormat::Jpeg),
|
||||
"image/gif" => Some(ImageFormat::Gif),
|
||||
"image/webp" => Some(ImageFormat::WebP),
|
||||
"image/tiff" => Some(ImageFormat::Tiff),
|
||||
// "image/pnm"=>Some(ImageFormat::Pnm),
|
||||
// "image/tga"=>Some(ImageFormat::Tga),
|
||||
// "image/dds"=>Some(ImageFormat::Dds),
|
||||
// "image/bmp"=>Some(ImageFormat::Bmp),
|
||||
// "image/ico"=>Some(ImageFormat::Ico),
|
||||
// "image/x-exr"=>Some(ImageFormat::OpenExr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
fn format_to_mimetype(format: ImageFormat) -> String {
|
||||
match format {
|
||||
ImageFormat::Png => "image/png",
|
||||
ImageFormat::Jpeg => "image/jpeg",
|
||||
ImageFormat::Gif => "image/gif",
|
||||
ImageFormat::WebP => "image/webp",
|
||||
ImageFormat::Tiff => "image/tiff",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||
let url = &input["![](".len()..input.len() - 1];
|
||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
||||
|
||||
let format = image::guess_format(&data)?;
|
||||
// TODO Remove this clone
|
||||
let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
let mimetype = format_to_mimetype(format);
|
||||
let encoded = STANDARD.encode(data);
|
||||
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
|
||||
Ok((data_uri, height, width))
|
||||
} else if input.starts_with("![](data:") {
|
||||
// Remove ![](....)
|
||||
let content = &input["![](data:".len()..input.len() - 1];
|
||||
let tokens: Vec<_> = content.split(';').collect();
|
||||
if tokens.len() != 2 {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
let mimetype = tokens[0];
|
||||
let content = tokens[1];
|
||||
|
||||
if !content.starts_with("base64,") {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
|
||||
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
||||
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
||||
ImageReader::with_format(Cursor::new(data), format).decode()?
|
||||
} else {
|
||||
ImageReader::new(Cursor::new(data))
|
||||
.with_guessed_format()
|
||||
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
||||
.decode()?
|
||||
};
|
||||
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
Ok((input.to_string(), height, width))
|
||||
} else {
|
||||
Err(ValidationError::InvalidImageContent(input.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
mut inputs: String,
|
||||
truncate: Option<usize>,
|
||||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
is_multimodal: bool,
|
||||
config: &Option<Config>,
|
||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
||||
let simplified_query = if is_multimodal {
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
RE.replace_all(&inputs, "<image>").into()
|
||||
} else {
|
||||
inputs.clone()
|
||||
};
|
||||
// Get the number of tokens in the input
|
||||
let mut encoding = tokenizer
|
||||
.encode(simplified_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
// Optionally truncate
|
||||
if let Some(truncate) = truncate {
|
||||
if truncate < encoding.len() && !is_multimodal {
|
||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||
inputs = tokenizer
|
||||
.decode(encoding.get_ids(), false)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let tokenizer_query = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
}
|
||||
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
|
||||
_ => inputs.clone(),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
let encoding = tokenizer
|
||||
.encode(tokenizer_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
Ok((encoding, inputs))
|
||||
}
|
||||
|
@ -523,6 +618,16 @@ pub enum ValidationError {
|
|||
Grammar,
|
||||
#[error("grammar is not valid: {0}")]
|
||||
InvalidGrammar(String),
|
||||
#[error("base64 encoding is invalid: {0}")]
|
||||
InvalidBase64(#[from] base64::DecodeError),
|
||||
#[error("invalid image: {0}")]
|
||||
InvalidImage(#[from] image::ImageError),
|
||||
#[error("invalid integer: {0}")]
|
||||
InvalidInt(#[from] core::num::TryFromIntError),
|
||||
#[error("invalid image content: {0}")]
|
||||
InvalidImageContent(String),
|
||||
#[error("Could not fetch image: {0}")]
|
||||
FailedFetchImage(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -541,9 +646,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -572,9 +679,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -603,9 +712,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -639,9 +750,11 @@ mod tests {
|
|||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -704,9 +817,11 @@ mod tests {
|
|||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
|
|
@ -67,6 +67,7 @@ try:
|
|||
FlashSantacoderSharded,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
|
@ -579,6 +580,19 @@ def get_model(
|
|||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
|
||||
|
||||
if sharded:
|
||||
raise NotImplementedError("sharded is not supported for AutoModel")
|
||||
if quantize == "gptq":
|
||||
|
|
|
@ -0,0 +1,827 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_create_4d_causal_attention_mask,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
# TODO Should we TP this ?
|
||||
self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype)
|
||||
) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, embed_dim
|
||||
)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = (
|
||||
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = self.embed_dim // self.num_heads
|
||||
if self.head_size * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
self.scale = self.head_size**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
]
|
||||
* 3,
|
||||
dim=2,
|
||||
)
|
||||
query_states = query_states * self.scale
|
||||
key_states = self._shape(key_states, -1, bsz)
|
||||
value_states = self._shape(value_states, -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_size)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ causal_attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_probs = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPPreTrainedModel(nn.Module):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = CLIPConfig
|
||||
base_model_prefix = "clip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
"""
|
||||
|
||||
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
"""
|
||||
|
||||
CLIP_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
"""
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`CLIPEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
CLIPEncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
"""
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = CLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
# For `pooled_output` computation
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
||||
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _create_4d_causal_attention_mask(
|
||||
input_shape, hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(
|
||||
attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
if self.eos_token_id == 2:
|
||||
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
|
||||
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
||||
# ------------------------------------------------------------
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
|
||||
dim=-1
|
||||
),
|
||||
]
|
||||
else:
|
||||
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
|
||||
(
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
|
||||
== self.eos_token_id
|
||||
)
|
||||
.int()
|
||||
.argmax(dim=-1),
|
||||
]
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class CLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(
|
||||
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||
)
|
||||
self.pre_layrnorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
last_hidden_state = encoder_outputs
|
||||
# pooled_output = last_hidden_state[:, 0, :]
|
||||
# pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
# pooler_output=pooled_output,
|
||||
# hidden_states=encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_model = CLIPVisionTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPVisionModel
|
||||
|
||||
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
text_config = config.text_config
|
||||
vision_config = config.vision_config
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
self.text_embed_dim = text_config.hidden_size
|
||||
self.vision_embed_dim = vision_config.hidden_size
|
||||
|
||||
self.text_model = CLIPTextTransformer(text_config)
|
||||
self.vision_model = CLIPVisionTransformer(vision_config)
|
||||
|
||||
self.visual_projection = nn.Linear(
|
||||
self.vision_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.text_projection = nn.Linear(
|
||||
self.text_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.logit_scale = nn.Parameter(
|
||||
torch.tensor(self.config.logit_scale_init_value)
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
```"""
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = self.text_projection(pooled_output)
|
||||
|
||||
return text_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||
... )
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs[1]
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
return logits_per_image, logits_per_text
|
|
@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
@ -337,27 +336,30 @@ class FlashLlamaLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -368,7 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
|
@ -376,8 +378,10 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
@ -406,13 +410,19 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
|
@ -426,10 +436,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -437,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -285,9 +285,8 @@ class MistralMLP(nn.Module):
|
|||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
@ -343,27 +342,24 @@ class MistralLayer(nn.Module):
|
|||
|
||||
|
||||
class MistralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MistralLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -374,7 +370,7 @@ class MistralModel(torch.nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
|
@ -384,9 +380,8 @@ class MistralModel(torch.nn.Module):
|
|||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
|
@ -410,18 +405,27 @@ class MistralModel(torch.nn.Module):
|
|||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = MistralModel(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
@ -453,8 +457,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_text_model(prefix, config, weights):
|
||||
if config.model_type == "llama":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.use_medusa = config.use_medusa
|
||||
self.language_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
|
@ -106,6 +106,19 @@ class FlashCausalLMBatch(Batch):
|
|||
max_tokens=self.blocks * BLOCK_SIZE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(cls, requests, tokenizer):
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
return batch_tokenized_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
|
@ -114,16 +127,7 @@ class FlashCausalLMBatch(Batch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
position_ids = []
|
||||
speculative_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
|
|
|
@ -67,7 +67,8 @@ class FlashLlama(FlashCausalLM):
|
|||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
prefix = ""
|
||||
model = FlashLlamaForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model=model,
|
||||
|
|
|
@ -6,8 +6,7 @@ import numpy as np
|
|||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
@ -65,19 +64,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@classmethod
|
||||
def from_tokenized(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
batch_tokenized_inputs,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window, sliding_window_blocks = get_sliding_windows()
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
|
@ -301,14 +302,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config_cls,
|
||||
model_cls,
|
||||
model_id: str,
|
||||
config_cls=AutoConfig,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
@ -317,22 +319,13 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
else:
|
||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||
|
||||
try:
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = config_cls.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
|
@ -341,10 +334,12 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
config.use_medusa = use_medusa
|
||||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
if getattr(config, "sliding_window", None) is not None:
|
||||
set_sliding_window(
|
||||
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
)
|
||||
else:
|
||||
config.sliding_window = None
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
@ -353,17 +348,19 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = model_cls(config, weights)
|
||||
prefix = ""
|
||||
model = model_cls(prefix, config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
|
@ -371,6 +368,16 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
sliding_window=config.sliding_window,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.model.layers),
|
||||
model.model.num_key_value_heads,
|
||||
model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> int:
|
||||
return self.model.max_past
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
|
@ -485,11 +492,11 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.model.max_past is not None:
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.model.max_past, max_s)
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
@ -20,29 +21,13 @@ from text_generation_server.models.types import (
|
|||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.models.vlm_causal_lm import split
|
||||
|
||||
import re
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string):
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append(string[cursor:start])
|
||||
|
||||
parts.append(pattern.group(1))
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append(string[cursor:])
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
|
@ -93,10 +78,21 @@ class IdeficsCausalLMBatch(Batch):
|
|||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor: ProcessorMixin, # Hack
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
|
@ -127,10 +123,14 @@ class IdeficsCausalLMBatch(Batch):
|
|||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
# TODO Check impact on idefics
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
prompts.append(split(inp))
|
||||
prompt = []
|
||||
for chunk in split(inp):
|
||||
prompt.append(chunk["content"])
|
||||
prompts.append(prompt)
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
|
@ -141,7 +141,8 @@ class IdeficsCausalLMBatch(Batch):
|
|||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
# TODO Check impact on idefics
|
||||
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
|
@ -156,7 +157,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs["pixel_values"]
|
||||
pixel_values = tokenized_inputs.get("pixel_values", None)
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
|
@ -165,16 +166,19 @@ class IdeficsCausalLMBatch(Batch):
|
|||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
# Do the same for image_attention_mask
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
tokenized_inputs["pixel_values"].size(1),
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
pixel_values.size(1),
|
||||
)
|
||||
)
|
||||
)
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
|
@ -677,19 +681,22 @@ class IdeficsCausalLM(Model):
|
|||
start = time.time_ns()
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
if batch.image_attention_mask is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
|
||||
logits, speculative_logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import torch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class LlavaNext(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
|
@ -0,0 +1,329 @@
|
|||
import re
|
||||
import torch
|
||||
import math
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
FlashMistralBatch,
|
||||
)
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string) -> List[Dict[str, str]]:
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append({"type": "text", "content": string[cursor:start]})
|
||||
|
||||
parts.append({"type": "image", "content": pattern.group(1)})
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append({"type": "text", "content": string[cursor:]})
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def get_number_of_features(height: int, width: int, config) -> int:
|
||||
# From config
|
||||
# Hardcoded for CLIP for now
|
||||
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
|
||||
image_grid_pinpoints = config.image_grid_pinpoints
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
|
||||
assert image_size % patch_size == 0
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
[height, width],
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
)
|
||||
|
||||
height_of_patch = math.ceil(height / width * npatches)
|
||||
|
||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
||||
# They are only added after width
|
||||
newline_features = height_of_patch * num_patch_width
|
||||
# The base patch covers the entire image
|
||||
base_features = npatches**2
|
||||
return unpadded_features + newline_features + base_features
|
||||
|
||||
|
||||
def load_data_uri(image_uri: str) -> Image.Image:
|
||||
image_uri = image_uri.split(",")[-1]
|
||||
content = base64.b64decode(image_uri)
|
||||
image = Image.open(BytesIO(content))
|
||||
return image
|
||||
|
||||
|
||||
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
|
||||
# assert get_number_of_features(640, 640) == 2928
|
||||
|
||||
|
||||
class VlmCausalLMBatch(FlashMistralBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches):
|
||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]):
|
||||
batch = super().filter(request_ids)
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
|
||||
batch_inputs = []
|
||||
image_inputs = []
|
||||
max_truncation = 0
|
||||
for r in requests:
|
||||
chunks = split(r.inputs)
|
||||
full_text = ""
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "text":
|
||||
full_text += chunk["content"]
|
||||
elif chunk["type"] == "image":
|
||||
image = chunk["content"]
|
||||
# Should never receive URLs anymore, processing should be done
|
||||
# On the rust layer.
|
||||
# This avoid making n queries per TP
|
||||
# if image.startswith("https://") or image.startswith("http://"):
|
||||
# image = processor.image_processor.fetch_images(image)
|
||||
if image.startswith("data:"):
|
||||
image = load_data_uri(image)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot process input image not starting with data:"
|
||||
)
|
||||
image_input = processor.image_processor(image, return_tensors="pt")
|
||||
height, width = image_input["image_sizes"][0]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
full_text += "<image>" * num_features
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||
|
||||
batch_inputs.append(full_text)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
if image_inputs:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(
|
||||
[img["pixel_values"] for img in image_inputs], dim=0
|
||||
),
|
||||
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
|
||||
}
|
||||
else:
|
||||
image_inputs = None
|
||||
return batch_tokenized_inputs, image_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor,
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "VlmCausalLMBatch":
|
||||
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
||||
pb.requests, tokenizer, processor, config
|
||||
)
|
||||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
if image_inputs is not None:
|
||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
else:
|
||||
batch.pixel_values = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
|
||||
class VlmCausalLM(BaseFlashMistral):
|
||||
@property
|
||||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
return VlmCausalLMBatch
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.language_model.model.layers),
|
||||
model.language_model.model.num_key_value_heads,
|
||||
model.language_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.language_model, "max_past", None)
|
||||
|
||||
def forward(
|
||||
self, batch: VlmCausalLMBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
speculative_ids = batch.speculative_ids
|
||||
|
||||
B, speculative_length = speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
new_input_ids = torch.cat(
|
||||
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||
).reshape(-1)
|
||||
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||
arange_int = arange.to(dtype=torch.int32)
|
||||
new_position_ids = (
|
||||
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||
).view(-1)
|
||||
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||
input_lengths = (
|
||||
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||
).view(-1)
|
||||
|
||||
# Add Copy the block tables for all members
|
||||
block_tables = (
|
||||
block_tables.unsqueeze(1)
|
||||
.expand(B, new_length, -1)
|
||||
.reshape(B * new_length, -1)
|
||||
.contiguous()
|
||||
)
|
||||
max_s = max_s + speculative_length
|
||||
|
||||
input_ids = new_input_ids
|
||||
position_ids = new_position_ids
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = get_cache_manager().kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
pixel_values=batch.pixel_values,
|
||||
image_sizes=batch.image_sizes,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
return logits, speculative_logits
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
speculative_logits = (
|
||||
cuda_graph["speculative_logits"][:bs]
|
||||
if cuda_graph["speculative_logits"] is not None
|
||||
else None
|
||||
)
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, speculative_logits
|
|
@ -13,6 +13,7 @@ from typing import List, Optional
|
|||
from text_generation_server.cache import Cache
|
||||
from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||
|
@ -78,13 +79,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
if self.model.batch_type in {
|
||||
IdeficsCausalLMBatch,
|
||||
VlmCausalLMBatch,
|
||||
}: # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb_processor(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.model.config,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
|
@ -100,13 +103,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
|
||||
async def Prefill(self, request, context):
|
||||
start = time.time_ns()
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb(
|
||||
if self.model.batch_type in {
|
||||
IdeficsCausalLMBatch,
|
||||
VlmCausalLMBatch,
|
||||
}: # Hack, i would rather use kwargs in the `from_pb` call
|
||||
batch = self.model.batch_type.from_pb_processor(
|
||||
request.batch,
|
||||
self.model.tokenizer,
|
||||
self.model.processor,
|
||||
self.model.model.config,
|
||||
self.model.dtype,
|
||||
self.model.device,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue