Mllama flash version (#2585)
* Working loading state. * Preprocessing. * Working state ? (Broke idefics1 temporarily). * Cleaner condition. * Fix idefics. * Updating config, removing TODO * Mllama * Ugrade transformers 4.45 * Flashing mllama. * Starting to get there. * Working state. * Integrations tests for mllama (cutting to 10 tokens because there seems' to be instability after (meaning size of the batch matters. * Updating model link. * Earlier assert. * Fix vlm ? * remove log. * Force ignore all images but last. * Default dtype bfloat16. * Update integration test after switch to bf16. * Remove dead code. * Removed dead code. * Upgrade the flake to latest transformers/tokenizers * Move to hf tgi-nix * Upgrade to 0.5.0
This commit is contained in:
parent
584b4d7a68
commit
d18ed5cfc5
|
@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -205,9 +205,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autocfg"
|
name = "autocfg"
|
||||||
version = "1.3.0"
|
version = "1.4.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "av1-grain"
|
name = "av1-grain"
|
||||||
|
@ -316,12 +316,12 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.7.6"
|
version = "0.7.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
|
checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core 0.4.4",
|
"axum-core 0.4.5",
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
|
@ -367,9 +367,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.4.4"
|
version = "0.4.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
|
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
@ -392,7 +392,7 @@ version = "0.16.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
|
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum 0.7.6",
|
"axum 0.7.7",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
|
@ -456,7 +456,7 @@ dependencies = [
|
||||||
"regex",
|
"regex",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
"which",
|
"which",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -605,9 +605,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.1.21"
|
version = "1.1.22"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
|
checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"jobserver",
|
"jobserver",
|
||||||
"libc",
|
"libc",
|
||||||
|
@ -704,7 +704,7 @@ dependencies = [
|
||||||
"heck 0.5.0",
|
"heck 0.5.0",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -971,7 +971,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"scratch",
|
"scratch",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1012,7 +1012,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"strsim",
|
"strsim",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"darling_core",
|
"darling_core",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1053,7 +1053,7 @@ dependencies = [
|
||||||
"darling",
|
"darling",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
|
checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"derive_builder_core",
|
"derive_builder_core",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fdeflate"
|
name = "fdeflate"
|
||||||
version = "0.3.4"
|
version = "0.3.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645"
|
checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"simd-adler32",
|
"simd-adler32",
|
||||||
]
|
]
|
||||||
|
@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flate2"
|
name = "flate2"
|
||||||
version = "1.0.33"
|
version = "1.0.34"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253"
|
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crc32fast",
|
"crc32fast",
|
||||||
"miniz_oxide 0.8.0",
|
"miniz_oxide 0.8.0",
|
||||||
|
@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
|
checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
|
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"adler",
|
"adler",
|
||||||
"simd-adler32",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
|
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"adler2",
|
"adler2",
|
||||||
|
"simd-adler32",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2599,9 +2599,12 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.19.0"
|
version = "1.20.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
|
||||||
|
dependencies = [
|
||||||
|
"portable-atomic",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "onig"
|
name = "onig"
|
||||||
|
@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2808,7 +2811,7 @@ dependencies = [
|
||||||
"glob",
|
"glob",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.21.0",
|
"opentelemetry 0.21.0",
|
||||||
"ordered-float 4.2.2",
|
"ordered-float 4.3.0",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -2828,7 +2831,7 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.23.0",
|
"opentelemetry 0.23.0",
|
||||||
"ordered-float 4.2.2",
|
"ordered-float 4.3.0",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -2851,9 +2854,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ordered-float"
|
name = "ordered-float"
|
||||||
version = "4.2.2"
|
version = "4.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6"
|
checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
]
|
]
|
||||||
|
@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -2988,22 +2991,22 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "png"
|
name = "png"
|
||||||
version = "0.17.13"
|
version = "0.17.14"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1"
|
checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 1.3.2",
|
"bitflags 1.3.2",
|
||||||
"crc32fast",
|
"crc32fast",
|
||||||
"fdeflate",
|
"fdeflate",
|
||||||
"flate2",
|
"flate2",
|
||||||
"miniz_oxide 0.7.4",
|
"miniz_oxide 0.8.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "portable-atomic"
|
name = "portable-atomic"
|
||||||
version = "1.8.0"
|
version = "1.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
|
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "powerfmt"
|
name = "powerfmt"
|
||||||
|
@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
|
checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
|
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3119,7 +3122,7 @@ dependencies = [
|
||||||
"prost 0.12.6",
|
"prost 0.12.6",
|
||||||
"prost-types",
|
"prost-types",
|
||||||
"regex",
|
"regex",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -3146,7 +3149,7 @@ dependencies = [
|
||||||
"itertools 0.12.1",
|
"itertools 0.12.1",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3205,7 +3208,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-macros-backend",
|
"pyo3-macros-backend",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3218,7 +3221,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-build-config",
|
"pyo3-build-config",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3402,9 +3405,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.5"
|
version = "0.5.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
|
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.6.0",
|
"bitflags 2.6.0",
|
||||||
]
|
]
|
||||||
|
@ -3422,14 +3425,14 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.10.6"
|
version = "1.11.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
|
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
"regex-automata 0.4.7",
|
"regex-automata 0.4.8",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3443,13 +3446,13 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-automata"
|
name = "regex-automata"
|
||||||
version = "0.4.7"
|
version = "0.4.8"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aho-corasick",
|
"aho-corasick",
|
||||||
"memchr",
|
"memchr",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex-syntax"
|
name = "regex-syntax"
|
||||||
version = "0.8.4"
|
version = "0.8.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "reqwest"
|
name = "reqwest"
|
||||||
|
@ -3563,7 +3566,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rust-embed-utils",
|
"rust-embed-utils",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -3686,9 +3689,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-pki-types"
|
name = "rustls-pki-types"
|
||||||
version = "1.8.0"
|
version = "1.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
|
checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustls-webpki"
|
name = "rustls-webpki"
|
||||||
|
@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -3840,9 +3843,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_spanned"
|
name = "serde_spanned"
|
||||||
version = "0.6.7"
|
version = "0.6.8"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
|
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
@ -4028,7 +4031,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4050,9 +4053,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.77"
|
version = "2.0.79"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
|
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tempfile"
|
name = "tempfile"
|
||||||
version = "3.12.0"
|
version = "3.13.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
|
checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"fastrand",
|
"fastrand",
|
||||||
|
@ -4259,7 +4262,7 @@ version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.6",
|
"axum 0.7.7",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.18",
|
"clap 4.5.18",
|
||||||
|
@ -4308,7 +4311,7 @@ version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.6",
|
"axum 0.7.7",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.18",
|
"clap 4.5.18",
|
||||||
|
@ -4357,7 +4360,7 @@ version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.6",
|
"axum 0.7.7",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.18",
|
"clap 4.5.18",
|
||||||
|
@ -4428,7 +4431,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4533,7 +4536,7 @@ dependencies = [
|
||||||
"rayon",
|
"rayon",
|
||||||
"rayon-cond",
|
"rayon-cond",
|
||||||
"regex",
|
"regex",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
|
@ -4566,7 +4569,7 @@ dependencies = [
|
||||||
"rayon",
|
"rayon",
|
||||||
"rayon-cond",
|
"rayon-cond",
|
||||||
"regex",
|
"regex",
|
||||||
"regex-syntax 0.8.4",
|
"regex-syntax 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
|
@ -4612,7 +4615,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4771,7 +4774,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"prost-build",
|
"prost-build",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -4858,7 +4861,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -5151,7 +5154,7 @@ dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -5160,7 +5163,7 @@ version = "6.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
|
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum 0.7.6",
|
"axum 0.7.7",
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
"regex",
|
"regex",
|
||||||
"rust-embed",
|
"rust-embed",
|
||||||
|
@ -5189,7 +5192,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -5290,7 +5293,7 @@ dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -5324,7 +5327,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
"wasm-bindgen-backend",
|
"wasm-bindgen-backend",
|
||||||
"wasm-bindgen-shared",
|
"wasm-bindgen-shared",
|
||||||
]
|
]
|
||||||
|
@ -5668,9 +5671,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winnow"
|
name = "winnow"
|
||||||
version = "0.6.19"
|
version = "0.6.20"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
|
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
@ -5703,7 +5706,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn 2.0.77",
|
"syn 2.0.79",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
||||||
function: dict
|
function: dict
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
|
||||||
|
type: str
|
||||||
|
text: Optional[str] = None
|
||||||
|
image_url: Any = None
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
# Role of the message sender
|
# Role of the message sender
|
||||||
role: str
|
role: str
|
||||||
# Content of the message
|
# Content of the message
|
||||||
content: Optional[str] = None
|
content: Optional[Union[str, List[Chunk]]] = None
|
||||||
# Optional name of the message sender
|
# Optional name of the message sender
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
# Tool calls associated with the chat completion
|
# Tool calls associated with the chat completion
|
||||||
|
|
|
@ -35,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
||||||
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
||||||
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
|
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
|
||||||
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
|
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
|
||||||
|
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
|
||||||
|
|
||||||
|
|
||||||
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
|
||||||
|
|
33
flake.lock
33
flake.lock
|
@ -497,11 +497,11 @@
|
||||||
"systems": "systems_7"
|
"systems": "systems_7"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1726560853,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -718,11 +718,11 @@
|
||||||
},
|
},
|
||||||
"nixpkgs_6": {
|
"nixpkgs_6": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1724915739,
|
"lastModified": 1727675176,
|
||||||
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
|
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
|
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -853,11 +853,11 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1726626348,
|
"lastModified": 1727836133,
|
||||||
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
|
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
|
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -978,17 +978,16 @@
|
||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1727710820,
|
"lastModified": 1727859277,
|
||||||
"narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
|
"narHash": "sha256-AsrPuQqhg8x5RRR3aX0vvDDRQb+HREq2wGxXOpZnWus=",
|
||||||
"owner": "danieldk",
|
"owner": "huggingface",
|
||||||
"repo": "tgi-nix",
|
"repo": "text-generation-inference-nix",
|
||||||
"rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
|
"rev": "14196ab62f31d005f46207f7a251f82a81d0a09f",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "danieldk",
|
"owner": "huggingface",
|
||||||
"ref": "moe-kernels-0.5.0",
|
"repo": "text-generation-inference-nix",
|
||||||
"repo": "tgi-nix",
|
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
};
|
};
|
||||||
nix-filter.url = "github:numtide/nix-filter";
|
nix-filter.url = "github:numtide/nix-filter";
|
||||||
tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
|
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
|
||||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
rust-overlay = {
|
rust-overlay = {
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a bustling city, a chicken named Cluck",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727773835,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.3.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a world where even chickens could dream big,",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727773835,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.3.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a world where even chickens could dream big,",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727773835,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.3.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a world where even chickens could dream big,",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727773835,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.3.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,26 @@
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a bustling city, a chicken named Cluck",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727556016,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.3.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
import pytest
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mllama_handle(launcher):
|
||||||
|
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def mllama(mllama_handle):
|
||||||
|
await mllama_handle.health(300)
|
||||||
|
return mllama_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
|
def get_chicken():
|
||||||
|
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cow_beach():
|
||||||
|
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mllama_simpl(mllama, response_snapshot):
|
||||||
|
# chicken = get_chicken()
|
||||||
|
response = await mllama.chat(
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.0,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Can you tell me a very short story based on the image?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.usage == {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"total_tokens": 60,
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
response.choices[0].message.content
|
||||||
|
== "In a bustling city, a chicken named Cluck"
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mllama_load(mllama, generate_load, response_snapshot):
|
||||||
|
futures = [
|
||||||
|
mllama.chat(
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.0,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Can you tell me a very short story based on the image?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for i in range(4)
|
||||||
|
]
|
||||||
|
responses = await asyncio.gather(*futures)
|
||||||
|
|
||||||
|
generated_texts = [response.choices[0].message.content for response in responses]
|
||||||
|
|
||||||
|
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
|
||||||
|
assert len(generated_texts) == 4
|
||||||
|
assert generated_texts, all(
|
||||||
|
[text == generated_texts[0] for text in generated_texts]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -146,6 +146,7 @@ pub enum Config {
|
||||||
ClipVisionModel(ClipVisionModel),
|
ClipVisionModel(ClipVisionModel),
|
||||||
Mistral,
|
Mistral,
|
||||||
Idefics,
|
Idefics,
|
||||||
|
Mllama,
|
||||||
Idefics2(Idefics2),
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
|
|
|
@ -29,7 +29,7 @@ impl ChatTemplate {
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
tracing::debug!("Loading template: {:#?}", template_str);
|
tracing::debug!("Loading template: {}", template_str);
|
||||||
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
let template = Box::leak(env)
|
let template = Box::leak(env)
|
||||||
|
|
|
@ -567,6 +567,7 @@ fn image_tokens(
|
||||||
use HubPreprocessorConfig::*;
|
use HubPreprocessorConfig::*;
|
||||||
match config {
|
match config {
|
||||||
Idefics => "<image>".to_string(),
|
Idefics => "<image>".to_string(),
|
||||||
|
Mllama => "<|image|>".to_string(),
|
||||||
Idefics2(config) => {
|
Idefics2(config) => {
|
||||||
const FAKE: &str = "<fake_token_around_image>";
|
const FAKE: &str = "<fake_token_around_image>";
|
||||||
const IMAGE: &str = "<image>";
|
const IMAGE: &str = "<image>";
|
||||||
|
@ -618,7 +619,7 @@ fn prepare_input(
|
||||||
use Config::*;
|
use Config::*;
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
|
||||||
opentelemetry-exporter-otlp = "^1.25.0"
|
opentelemetry-exporter-otlp = "^1.25.0"
|
||||||
opentelemetry-instrumentation-grpc = "^0.46b0"
|
opentelemetry-instrumentation-grpc = "^0.46b0"
|
||||||
hf-transfer = "^0.1.2"
|
hf-transfer = "^0.1.2"
|
||||||
sentencepiece = "^0.1.97"
|
sentencepiece = "^0.2"
|
||||||
tokenizers = "^0.19.1"
|
tokenizers = "^0.20"
|
||||||
huggingface-hub = "^0.23"
|
huggingface-hub = "^0.23"
|
||||||
transformers = "^4.43"
|
transformers = "^4.45"
|
||||||
einops = "^0.6.1"
|
einops = "^0.6.1"
|
||||||
texttable = { version = "^1.6.7", optional = true }
|
texttable = { version = "^1.6.7", optional = true }
|
||||||
datasets = { version = "^2.14.0", optional = true }
|
datasets = { version = "^2.14.0", optional = true }
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,19 +1,19 @@
|
||||||
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
|
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
|
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
|
||||||
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
|
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||||
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
|
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -76,6 +76,7 @@ FLASH_ATTENTION = True
|
||||||
try:
|
try:
|
||||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||||
|
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||||
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||||
FlashDeepseekV2ForCausalLM,
|
FlashDeepseekV2ForCausalLM,
|
||||||
DeepseekV2Config,
|
DeepseekV2Config,
|
||||||
|
@ -112,7 +113,11 @@ try:
|
||||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||||
FlashPhiForCausalLM,
|
FlashPhiForCausalLM,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.idefics import IDEFICSSharded
|
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
||||||
|
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||||
|
from text_generation_server.models.custom_modeling.mllama import (
|
||||||
|
MllamaForConditionalGeneration,
|
||||||
|
)
|
||||||
from text_generation_server.models.custom_modeling.llava_next import (
|
from text_generation_server.models.custom_modeling.llava_next import (
|
||||||
LlavaNextForConditionalGeneration,
|
LlavaNextForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
@ -149,7 +154,7 @@ except ImportError as e:
|
||||||
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashCausalLM)
|
__all__.append(FlashCausalLM)
|
||||||
__all__.append(IDEFICSSharded)
|
__all__.append(IdeficsCausalLM)
|
||||||
|
|
||||||
MAMBA_AVAILABLE = True
|
MAMBA_AVAILABLE = True
|
||||||
try:
|
try:
|
||||||
|
@ -316,6 +321,12 @@ class ModelType(enum.Enum):
|
||||||
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
|
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
|
||||||
"multimodal": True,
|
"multimodal": True,
|
||||||
}
|
}
|
||||||
|
MLLAMA = {
|
||||||
|
"type": "mllama",
|
||||||
|
"name": "Mllama",
|
||||||
|
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||||
|
"multimodal": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
__GLOBALS = locals()
|
__GLOBALS = locals()
|
||||||
|
@ -1116,7 +1127,7 @@ def get_model(
|
||||||
)
|
)
|
||||||
if model_type == IDEFICS:
|
if model_type == IDEFICS:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return IDEFICSSharded(
|
return IdeficsCausalLM(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
|
@ -1126,6 +1137,22 @@ def get_model(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||||
|
if model_type == MLLAMA:
|
||||||
|
if FLASH_ATTENTION:
|
||||||
|
return MllamaCausalLM(
|
||||||
|
model_id=model_id,
|
||||||
|
model_class=MllamaForConditionalGeneration,
|
||||||
|
batch_class=MllamaCausalLMBatch,
|
||||||
|
revision=revision,
|
||||||
|
quantize=quantize,
|
||||||
|
speculator=speculator,
|
||||||
|
dtype=dtype,
|
||||||
|
default_dtype=torch.bfloat16,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
lora_adapter_ids=lora_adapter_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
|
||||||
if model_type == IDEFICS2:
|
if model_type == IDEFICS2:
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return VlmCausalLM(
|
return VlmCausalLM(
|
||||||
|
|
|
@ -450,6 +450,7 @@ class FlashLlamaLayer(nn.Module):
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
cross_attention_states,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -487,6 +488,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
|
|
||||||
# Skip fp8 quant for first and last layers
|
# Skip fp8 quant for first and last layers
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
|
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
|
||||||
with no_fp8(weights):
|
with no_fp8(weights):
|
||||||
self.layers.append(
|
self.layers.append(
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
|
@ -499,8 +501,27 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.layers.extend(
|
# Skip first and last layers
|
||||||
[
|
for layer_id in range(1, config.num_hidden_layers - 1):
|
||||||
|
if layer_id in self.cross_attention_layers:
|
||||||
|
from text_generation_server.models.custom_modeling.mllama import (
|
||||||
|
FlashLlamaCrossLayer,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers.append(
|
||||||
|
FlashLlamaCrossLayer(
|
||||||
|
index=layer_id,
|
||||||
|
prefix=(
|
||||||
|
f"model.layers.{layer_id}"
|
||||||
|
if not prefix
|
||||||
|
else f"{prefix}.model.layers.{layer_id}"
|
||||||
|
),
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.layers.append(
|
||||||
FlashLlamaLayer(
|
FlashLlamaLayer(
|
||||||
index=layer_id,
|
index=layer_id,
|
||||||
prefix=(
|
prefix=(
|
||||||
|
@ -511,9 +532,6 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
# Skip first and last layers
|
|
||||||
for layer_id in range(1, config.num_hidden_layers - 1)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with no_fp8(weights):
|
with no_fp8(weights):
|
||||||
|
@ -556,6 +574,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
true_max_s: int,
|
true_max_s: int,
|
||||||
prefill_cache_indices: Optional[torch.Tensor],
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
cross_attention_states=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
@ -579,6 +598,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
adapter_data,
|
adapter_data,
|
||||||
|
cross_attention_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
@ -625,6 +645,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||||
lm_head_indices: Optional[torch.Tensor] = None,
|
lm_head_indices: Optional[torch.Tensor] = None,
|
||||||
adapter_data: Optional[torch.Tensor] = None,
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_states=None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
|
@ -639,6 +660,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||||
true_max_s=max_s,
|
true_max_s=max_s,
|
||||||
prefill_cache_indices=prefill_cache_indices,
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
adapter_data=adapter_data,
|
adapter_data=adapter_data,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.text_config.vocab_size
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
text_config = config.text_config
|
text_config = config.text_config
|
||||||
|
|
|
@ -0,0 +1,995 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch Mllama model."""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||||
|
FlashLlamaForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_aspect_ratio_attention_mask(
|
||||||
|
aspect_ratio_mask: torch.Tensor,
|
||||||
|
num_patches: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Expand aspect ratio mask to target_length
|
||||||
|
batch_size, max_num_tiles = aspect_ratio_mask.shape
|
||||||
|
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
|
||||||
|
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
|
||||||
|
|
||||||
|
# Mask padding patches
|
||||||
|
pad_patches = target_length - num_patches
|
||||||
|
attention_mask[:, :, -pad_patches:] = 0
|
||||||
|
|
||||||
|
# Invert the mask (0 -> 1, 1 -> 0)
|
||||||
|
attention_mask = 1 - attention_mask
|
||||||
|
|
||||||
|
# Reshape to 2D and create 4D attention mask
|
||||||
|
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
|
||||||
|
attention_mask = attention_mask.reshape(
|
||||||
|
batch_size, max_num_tiles * target_length, 1
|
||||||
|
)
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
min_dtype: float,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
min_dtype (`float`):
|
||||||
|
The minimum value representable with the dtype `dtype`.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length),
|
||||||
|
fill_value=min_dtype,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(
|
||||||
|
target_length, device=device
|
||||||
|
) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = (
|
||||||
|
causal_mask.clone()
|
||||||
|
) # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = (
|
||||||
|
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[
|
||||||
|
:, :, :, :mask_length
|
||||||
|
].masked_fill(padding_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_cross_attention_mask(
|
||||||
|
cross_attention_mask: torch.Tensor,
|
||||||
|
num_vision_tokens: int,
|
||||||
|
dtype: str,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# reshape so it can be used by attn module
|
||||||
|
batch_size, text_total_length, *_ = cross_attention_mask.shape
|
||||||
|
cross_attention_mask = cross_attention_mask.repeat_interleave(
|
||||||
|
num_vision_tokens, dim=3
|
||||||
|
)
|
||||||
|
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
|
||||||
|
cross_attention_mask = cross_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# invert the mask
|
||||||
|
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
|
||||||
|
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
|
||||||
|
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
|
||||||
|
# last dimension contains negative infinity values, otherwise it's 1
|
||||||
|
negative_inf_value = torch.finfo(dtype).min
|
||||||
|
full_text_row_masked_out_mask = (
|
||||||
|
(cross_attention_mask != negative_inf_value)
|
||||||
|
.any(dim=-1)
|
||||||
|
.type_as(cross_attention_mask)[..., None]
|
||||||
|
)
|
||||||
|
cross_attention_mask *= full_text_row_masked_out_mask
|
||||||
|
|
||||||
|
return cross_attention_mask, full_text_row_masked_out_mask
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
|
||||||
|
class MllamaVisionMLP(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = TensorParallelColumnLinear.load(
|
||||||
|
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
|
||||||
|
)
|
||||||
|
self.fc2 = TensorParallelRowLinear.load(
|
||||||
|
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionSdpaAttention(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embed_dim = config.hidden_size
|
||||||
|
self.head_dim = config.hidden_size // config.attention_heads
|
||||||
|
self.num_heads = config.attention_heads // weights.process_group.size()
|
||||||
|
|
||||||
|
self.qkv_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_state: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv = self.qkv_proj(hidden_state)
|
||||||
|
query, key, value = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_dim * self.num_heads,
|
||||||
|
self.head_dim * self.num_heads,
|
||||||
|
self.head_dim * self.num_heads,
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_size, q_seq_len, _ = query.shape
|
||||||
|
_, kv_seq_len, _ = key.shape
|
||||||
|
|
||||||
|
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
|
||||||
|
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||||
|
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
query = query.transpose(1, 2)
|
||||||
|
key = key.transpose(1, 2)
|
||||||
|
value = value.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
|
||||||
|
|
||||||
|
output = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights, is_gated: bool):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_attention_heads = config.attention_heads
|
||||||
|
self.is_gated = is_gated
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
|
||||||
|
self.self_attn = MllamaVisionSdpaAttention(
|
||||||
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
|
)
|
||||||
|
self.mlp = MllamaVisionMLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
|
||||||
|
)
|
||||||
|
|
||||||
|
# there used to be an if else here, no code path
|
||||||
|
if is_gated:
|
||||||
|
self.gate_attn = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.gate_ffn = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_state: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_state
|
||||||
|
hidden_state = self.input_layernorm(hidden_state)
|
||||||
|
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
|
||||||
|
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
|
||||||
|
hidden_state = residual + gate_attn * hidden_state
|
||||||
|
|
||||||
|
# Feed forward
|
||||||
|
residual = hidden_state
|
||||||
|
hidden_state = self.post_attention_layernorm(hidden_state)
|
||||||
|
hidden_state = self.mlp(hidden_state)
|
||||||
|
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
|
||||||
|
hidden_state = residual + gate_ffn * hidden_state
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionEncoder(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = [
|
||||||
|
MllamaVisionEncoderLayer(
|
||||||
|
prefix=f"{prefix}.layers.{i}",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
is_gated=is_gated,
|
||||||
|
)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
encoder_states = [hidden_states]
|
||||||
|
for encoder_layer in self.layers:
|
||||||
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs
|
||||||
|
encoder_states.append(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, encoder_states
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.max_num_tiles = config.max_num_tiles
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.max_aspect_ratio_id = config.max_aspect_ratio_id
|
||||||
|
|
||||||
|
self.embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.embedding", weights=weights
|
||||||
|
)
|
||||||
|
self.gate = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
embeddings = self.embedding(aspect_ratio_ids)
|
||||||
|
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
|
||||||
|
|
||||||
|
# Always gated.
|
||||||
|
embeddings = embeddings * self.gate.tanh()
|
||||||
|
|
||||||
|
hidden_state = hidden_state + embeddings
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPrecomputedPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.max_num_tiles = config.max_num_tiles
|
||||||
|
self.max_aspect_ratio_id = config.max_aspect_ratio_id
|
||||||
|
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.scale = config.hidden_size**-0.5
|
||||||
|
|
||||||
|
self.gate = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# position embedding
|
||||||
|
embedding = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
|
||||||
|
self.tile_embedding = TensorParallelEmbedding(
|
||||||
|
prefix=f"{prefix}.tile_embedding", weights=weights
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# position embeddings
|
||||||
|
hidden_state = hidden_state + self.gated_position_embedding.view(
|
||||||
|
1, 1, self.num_patches, self.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# precomputed tile position embeddings
|
||||||
|
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
|
||||||
|
batch_size = hidden_state.shape[0]
|
||||||
|
tile_position_embedding = tile_position_embedding.reshape(
|
||||||
|
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
|
||||||
|
)
|
||||||
|
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
|
||||||
|
hidden_state = hidden_state + gated_tile_position_embedding
|
||||||
|
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaVisionModel(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.image_size = config.image_size
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.max_num_tiles = config.max_num_tiles
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_channels = config.num_channels
|
||||||
|
self.intermediate_layers_indices = config.intermediate_layers_indices
|
||||||
|
|
||||||
|
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
|
||||||
|
self.scale = config.hidden_size**-0.5
|
||||||
|
self.dtype = weights.dtype
|
||||||
|
|
||||||
|
self.patch_embedding = nn.Conv2d(
|
||||||
|
in_channels=config.num_channels,
|
||||||
|
out_channels=self.hidden_size,
|
||||||
|
kernel_size=self.patch_size,
|
||||||
|
stride=self.patch_size,
|
||||||
|
padding="valid",
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.patch_embedding.weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.class_embedding = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
|
||||||
|
prefix=f"{prefix}.gated_positional_embedding",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
|
||||||
|
prefix=f"{prefix}.pre_tile_positional_embedding",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
|
||||||
|
prefix=f"{prefix}.post_tile_positional_embedding",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
## layer norms
|
||||||
|
self.layernorm_pre = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layernorm_pre",
|
||||||
|
weights=weights,
|
||||||
|
# torch default
|
||||||
|
eps=1e-05,
|
||||||
|
)
|
||||||
|
self.layernorm_post = nn.LayerNorm.load(
|
||||||
|
prefix=f"{prefix}.layernorm_post",
|
||||||
|
weights=weights,
|
||||||
|
# torch default
|
||||||
|
eps=1e-05,
|
||||||
|
)
|
||||||
|
|
||||||
|
## encoders
|
||||||
|
self.transformer = MllamaVisionEncoder(
|
||||||
|
prefix=f"{prefix}.transformer",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
is_gated=False,
|
||||||
|
num_layers=config.num_hidden_layers,
|
||||||
|
)
|
||||||
|
self.global_transformer = MllamaVisionEncoder(
|
||||||
|
prefix=f"{prefix}.global_transformer",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
is_gated=True,
|
||||||
|
num_layers=config.num_global_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, _, hidden_size = hidden_state.shape
|
||||||
|
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
|
||||||
|
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
aspect_ratio_ids: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
|
||||||
|
pixel_values.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
pixel_values = pixel_values.reshape(
|
||||||
|
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
||||||
|
)
|
||||||
|
aspect_ratio_ids = aspect_ratio_ids.reshape(
|
||||||
|
batch_size * num_concurrent_media, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
# patch embedding
|
||||||
|
patch_embeds = self.patch_embedding(pixel_values)
|
||||||
|
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# tile embeddings
|
||||||
|
_, num_patches, dim = hidden_state.shape
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media, num_tiles, -1, dim
|
||||||
|
)
|
||||||
|
hidden_state = self.pre_tile_positional_embedding(
|
||||||
|
hidden_state, aspect_ratio_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply cls token
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media * num_tiles, num_patches, dim
|
||||||
|
)
|
||||||
|
hidden_state = self.apply_class_embedding(hidden_state)
|
||||||
|
num_patches += 1
|
||||||
|
|
||||||
|
# apply position embeddings
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media, num_tiles, num_patches, dim
|
||||||
|
)
|
||||||
|
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
|
||||||
|
|
||||||
|
# apply encoder
|
||||||
|
hidden_state = self.layernorm_pre(hidden_state)
|
||||||
|
|
||||||
|
# Compute the number of tokens to pad
|
||||||
|
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
|
||||||
|
# Compute padding tuple for pad function
|
||||||
|
padding = (
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
num_padding_patches,
|
||||||
|
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
|
||||||
|
# Pad the tensor
|
||||||
|
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
|
||||||
|
slice_index = -num_padding_patches if num_padding_patches > 0 else None
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.reshape(
|
||||||
|
batch_size * num_concurrent_media, -1
|
||||||
|
)
|
||||||
|
attention_mask = _prepare_aspect_ratio_attention_mask(
|
||||||
|
aspect_ratio_mask=attention_mask,
|
||||||
|
num_patches=self.num_patches,
|
||||||
|
target_length=hidden_state.shape[2],
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
|
||||||
|
hidden_state, all_intermediate_hidden_states = self.transformer(
|
||||||
|
hidden_state,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
intermediate_hidden_states = [
|
||||||
|
hidden_state
|
||||||
|
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
|
||||||
|
if idx in self.intermediate_layers_indices
|
||||||
|
]
|
||||||
|
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
|
||||||
|
|
||||||
|
# apply global encoder
|
||||||
|
hidden_state = self.layernorm_post(hidden_state)
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media,
|
||||||
|
num_tiles,
|
||||||
|
num_patches + num_padding_patches,
|
||||||
|
dim,
|
||||||
|
)
|
||||||
|
hidden_state = self.post_tile_positional_embedding(
|
||||||
|
hidden_state, aspect_ratio_ids
|
||||||
|
)
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media,
|
||||||
|
num_tiles * (num_patches + num_padding_patches),
|
||||||
|
dim,
|
||||||
|
)
|
||||||
|
hidden_state, _ = self.global_transformer(
|
||||||
|
hidden_state, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size * num_concurrent_media,
|
||||||
|
num_tiles,
|
||||||
|
num_patches + num_padding_patches,
|
||||||
|
dim,
|
||||||
|
)
|
||||||
|
hidden_state = hidden_state[:, :, :slice_index]
|
||||||
|
|
||||||
|
# adding intermediate layer outputs
|
||||||
|
hidden_state = hidden_state.reshape(
|
||||||
|
batch_size, num_concurrent_media, num_tiles, num_patches, dim
|
||||||
|
)
|
||||||
|
intermediate_hidden_states = intermediate_hidden_states.reshape(
|
||||||
|
batch_size * num_concurrent_media,
|
||||||
|
num_tiles,
|
||||||
|
num_patches + num_padding_patches,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
|
||||||
|
intermediate_hidden_states = intermediate_hidden_states.reshape(
|
||||||
|
batch_size, num_concurrent_media, num_tiles, num_patches, -1
|
||||||
|
)
|
||||||
|
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaTextCrossAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, *, prefix, config, weights, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.num_heads = self.config.num_attention_heads
|
||||||
|
self.num_key_value_heads = self.config.num_key_value_heads
|
||||||
|
self.dropout = config.dropout
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_size = config.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
|
self.num_key_value_heads = (
|
||||||
|
self.num_key_value_heads // weights.process_group.size()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.q_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.k_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.k_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.v_proj = TensorParallelColumnLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.v_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.q_norm = MllamaTextRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.k_norm = MllamaTextRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.softmax_scale = self.head_size**-0.5
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
# past_key_value=None,
|
||||||
|
# attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
# cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
# hidden_states = hidden_states.unsqueeze(0)
|
||||||
|
# bsz, q_len, _ = hidden_states.size()
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
query_states = query_states.view(-1, self.num_heads, self.head_size)
|
||||||
|
query_states = self.q_norm(query_states)
|
||||||
|
|
||||||
|
(
|
||||||
|
cross_attention_states,
|
||||||
|
cu_seqlen_q,
|
||||||
|
cu_seqlen_k,
|
||||||
|
max_q,
|
||||||
|
max_k,
|
||||||
|
indices,
|
||||||
|
) = cross_attention_states
|
||||||
|
|
||||||
|
key_states = self.k_proj(cross_attention_states)
|
||||||
|
value_states = self.v_proj(cross_attention_states)
|
||||||
|
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
|
||||||
|
key_states = self.k_norm(key_states)
|
||||||
|
|
||||||
|
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
|
||||||
|
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
|
||||||
|
|
||||||
|
causal = False
|
||||||
|
# logger.info(
|
||||||
|
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||||
|
# )
|
||||||
|
attn_output = flash_attn_2_cuda.varlen_fwd(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
None,
|
||||||
|
cu_seqlen_q,
|
||||||
|
cu_seqlen_k,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None, # block_tables
|
||||||
|
None,
|
||||||
|
max_q,
|
||||||
|
max_k,
|
||||||
|
0.0,
|
||||||
|
self.softmax_scale,
|
||||||
|
False,
|
||||||
|
causal, # Causal
|
||||||
|
-1, # window_size_left,
|
||||||
|
-1,
|
||||||
|
0.0, # softcap
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)[0]
|
||||||
|
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
|
||||||
|
class MllamaTextMLP(nn.Module):
|
||||||
|
def __init__(self, *, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = (
|
||||||
|
config.intermediate_size // weights.process_group.size()
|
||||||
|
)
|
||||||
|
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
|
weights=weights,
|
||||||
|
dim=0,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = TensorParallelRowLinear.load(
|
||||||
|
config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
shape = x.shape
|
||||||
|
gate_up_states = self.gate_up_proj(x)
|
||||||
|
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
|
||||||
|
result = self.down_proj(
|
||||||
|
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class FlashLlamaCrossLayer(torch.nn.Module):
|
||||||
|
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
|
||||||
|
|
||||||
|
def __init__(self, *, prefix, config, weights, index) -> None:
|
||||||
|
layer_idx = index
|
||||||
|
super().__init__()
|
||||||
|
self.cross_attn = MllamaTextCrossAttention(
|
||||||
|
prefix=f"{prefix}.cross_attn",
|
||||||
|
config=config,
|
||||||
|
weights=weights,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_layernorm = MllamaTextRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
)
|
||||||
|
self.cross_attn_attn_gate = torch.nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||||
|
self.post_attention_layernorm = MllamaTextRMSNorm.load(
|
||||||
|
prefix=f"{prefix}.post_attention_layernorm",
|
||||||
|
weights=weights,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
self.cross_attn_mlp_gate = torch.nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
|
||||||
|
)
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
adapter_data,
|
||||||
|
cross_attention_states, # [ IB, ...]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if cross_attention_states is None:
|
||||||
|
return hidden_states, residual
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states += residual
|
||||||
|
|
||||||
|
indices = cross_attention_states[-1]
|
||||||
|
out_hidden_states = hidden_states[:]
|
||||||
|
if len(indices) > 0:
|
||||||
|
assert max(indices) < hidden_states.shape[0]
|
||||||
|
hidden_states = hidden_states[indices]
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.cross_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
# attention_mask=cross_attention_mask,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
)
|
||||||
|
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
|
||||||
|
|
||||||
|
out_hidden_states[indices] = hidden_states
|
||||||
|
hidden_states = out_hidden_states
|
||||||
|
|
||||||
|
return hidden_states, None
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
|
||||||
|
class MllamaTextRMSNorm(nn.Module):
|
||||||
|
def __init__(self, weight, eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = weight
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, *, prefix, weights, eps):
|
||||||
|
weight = nn.Parameter(
|
||||||
|
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||||
|
)
|
||||||
|
return cls(weight=weight, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
config.vision_config.quantize = None
|
||||||
|
config.vision_config.speculator = config.speculator
|
||||||
|
config.text_config.quantize = config.quantize
|
||||||
|
config.text_config.speculator = config.speculator
|
||||||
|
config.text_config._attn_implementation = "sdpa"
|
||||||
|
self.hidden_size = config.text_config.hidden_size
|
||||||
|
self.vision_model = MllamaVisionModel(
|
||||||
|
prefix="vision_model", config=config.vision_config, weights=weights
|
||||||
|
)
|
||||||
|
self.multi_modal_projector = FastLinear.load(
|
||||||
|
prefix="multi_modal_projector", config=config, weights=weights, bias=True
|
||||||
|
)
|
||||||
|
self.text_model = FlashLlamaForCausalLM(
|
||||||
|
prefix="language_model", config=config.text_config, weights=weights
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
self.dtype = weights.dtype
|
||||||
|
self.device = weights.device
|
||||||
|
|
||||||
|
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
|
||||||
|
if aspect_ratio_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||||
|
)
|
||||||
|
# logger.info(f"PIxel values {pixel_values.shape}")
|
||||||
|
batch_size = pixel_values.shape[0]
|
||||||
|
vision_states = self.vision_model(
|
||||||
|
pixel_values, aspect_ratio_ids, aspect_ratio_mask
|
||||||
|
)
|
||||||
|
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
|
||||||
|
-1, vision_states.shape[-2], self.hidden_size
|
||||||
|
)
|
||||||
|
_, _, h = cross_attention_states.shape
|
||||||
|
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
|
||||||
|
# logger.info(f"cross {cross_attention_states.shape}")
|
||||||
|
return cross_attention_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||||
|
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
seqlen: Seqlen,
|
||||||
|
max_s: int,
|
||||||
|
prefill_cache_indices: Optional[torch.Tensor],
|
||||||
|
lm_head_indices: Optional[torch.Tensor],
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
|
# XXX: Putting these as optional so that the cuda warmup calls can go through.
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
image_indices=None,
|
||||||
|
):
|
||||||
|
if cross_attention_states is not None:
|
||||||
|
seqlen_q = len(image_indices)
|
||||||
|
n_images = cross_attention_states.shape[0]
|
||||||
|
seqlen_k = cross_attention_states.shape[1]
|
||||||
|
device = cross_attention_states.device
|
||||||
|
if cu_seqlen_prefill is not None:
|
||||||
|
offset = 0
|
||||||
|
cu_q = []
|
||||||
|
indices = []
|
||||||
|
for index in image_indices:
|
||||||
|
cu_q.append(offset)
|
||||||
|
length = seqlen.input_lengths[index].item()
|
||||||
|
assert index < seqlen.cu_seqlen_q.shape[0]
|
||||||
|
input_ids_offset = seqlen.cu_seqlen_q[index]
|
||||||
|
indices.extend(range(input_ids_offset, input_ids_offset + length))
|
||||||
|
offset += length
|
||||||
|
cu_q.append(offset)
|
||||||
|
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
|
||||||
|
|
||||||
|
assert max(indices) < input_ids.shape[0]
|
||||||
|
|
||||||
|
cu_seqlen_k = (
|
||||||
|
torch.arange(
|
||||||
|
n_images + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
* seqlen_k
|
||||||
|
)
|
||||||
|
max_q = cu_seqlen_q[-1].item()
|
||||||
|
max_k = seqlen_k
|
||||||
|
else:
|
||||||
|
cu_seqlen_q = torch.arange(
|
||||||
|
seqlen_q + 1, device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
seqlen_k = cross_attention_states.shape[1]
|
||||||
|
n_images = cross_attention_states.shape[0]
|
||||||
|
cu_seqlen_k = (
|
||||||
|
torch.arange(
|
||||||
|
n_images + 1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
* seqlen_k
|
||||||
|
)
|
||||||
|
max_q = seqlen_q
|
||||||
|
max_k = seqlen_k
|
||||||
|
indices = image_indices[:]
|
||||||
|
|
||||||
|
cross_attention_states = (
|
||||||
|
cross_attention_states,
|
||||||
|
cu_seqlen_q,
|
||||||
|
cu_seqlen_k,
|
||||||
|
max_q,
|
||||||
|
max_k,
|
||||||
|
indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = self.text_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
|
lm_head_indices=lm_head_indices,
|
||||||
|
adapter_data=adapter_data,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
return outputs
|
|
@ -1,105 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
|
|
||||||
from text_generation_server.models.custom_modeling.idefics_processing import (
|
|
||||||
IdeficsProcessor,
|
|
||||||
)
|
|
||||||
from transformers import LlamaTokenizerFast
|
|
||||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
|
||||||
IdeficsForVisionText2Text,
|
|
||||||
)
|
|
||||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
|
||||||
from text_generation_server.utils import (
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
Weights,
|
|
||||||
)
|
|
||||||
from text_generation_server.utils.quantization import get_loader
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
|
|
||||||
class IDEFICSSharded(IdeficsCausalLM):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
speculator: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.quantize = quantize
|
|
||||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
# 9b seems to work correctly enough in float16, but 80b seems
|
|
||||||
# to be really saturating for f16.
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
elif SYSTEM == "ipex":
|
|
||||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
device = torch.device(f"xpu:{rank}")
|
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
# Float16 doesn't exist on target.
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
self.device, self.dtype = device, dtype
|
|
||||||
|
|
||||||
config = IdeficsConfig.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
config.quantize = quantize
|
|
||||||
config.speculator = speculator
|
|
||||||
config.vision_config.quantize = quantize
|
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
self.processor = IdeficsProcessor.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
weights_loader = get_loader(
|
|
||||||
quantize=quantize, model_id=model_id, revision=revision
|
|
||||||
)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
|
||||||
weights = Weights(
|
|
||||||
filenames,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
process_group=self.process_group,
|
|
||||||
weights_loader=weights_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = IdeficsForVisionText2Text(config, weights)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
super(IdeficsCausalLM, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
requires_padding=True,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
rank=rank,
|
|
||||||
world_size=world_size,
|
|
||||||
)
|
|
|
@ -6,6 +6,7 @@ import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
|
@ -22,6 +23,18 @@ from text_generation_server.models.types import (
|
||||||
)
|
)
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
import torch.distributed
|
||||||
|
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
||||||
|
IdeficsForVisionText2Text,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.quantization import get_loader
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
@ -577,23 +590,38 @@ class IdeficsCausalLM(Model):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
self.quantize = quantize
|
self.quantize = quantize
|
||||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||||
IdeficsForVisionText2Text,
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
# 9b seems to work correctly enough in float16, but 80b seems
|
||||||
|
# to be really saturating for f16.
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif SYSTEM == "ipex":
|
||||||
|
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
device = torch.device(f"xpu:{rank}")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
# Float16 doesn't exist on target.
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
if quantize:
|
|
||||||
raise ValueError("quantization is not available on CPU")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
self.device, self.dtype = device, dtype
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
revision=revision,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
config.quantize = quantize
|
||||||
|
config.speculator = speculator
|
||||||
|
config.vision_config.quantize = quantize
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
@ -609,38 +637,34 @@ class IdeficsCausalLM(Model):
|
||||||
truncation_side="left",
|
truncation_side="left",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
model = IdeficsForVisionText2Text.from_pretrained(
|
|
||||||
model_id,
|
weights_loader = get_loader(
|
||||||
revision=revision,
|
quantize=quantize, model_id=model_id, revision=revision
|
||||||
torch_dtype=dtype,
|
)
|
||||||
device_map=(
|
torch.distributed.barrier(group=self.process_group)
|
||||||
"auto"
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
weights = Weights(
|
||||||
else None
|
filenames,
|
||||||
),
|
device=device,
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
process_group=self.process_group,
|
||||||
|
weights_loader=weights_loader,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
model = IdeficsForVisionText2Text(config, weights)
|
||||||
if model.config.pad_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = model.config.pad_token_id
|
|
||||||
elif model.config.eos_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = model.config.eos_token_id
|
|
||||||
elif tokenizer.eos_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
||||||
else:
|
|
||||||
tokenizer.add_special_tokens({"pad_token": "<unk>"})
|
|
||||||
|
|
||||||
super(IdeficsCausalLM, self).__init__(
|
self.config = config
|
||||||
|
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super().__init__(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=True,
|
requires_padding=True,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -0,0 +1,357 @@
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from typing import Iterable, Optional, Tuple, List, Dict
|
||||||
|
from text_generation_server.pb.generate_pb2 import Request
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from opentelemetry import trace
|
||||||
|
from transformers import (
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.models.flash_causal_lm import (
|
||||||
|
block_tables_to_ragged,
|
||||||
|
)
|
||||||
|
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
|
||||||
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
|
||||||
|
|
||||||
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MllamaCausalLMBatch(VlmCausalLMBatch):
|
||||||
|
image_indices: List[int] = 42
|
||||||
|
aspect_ratio_ids: Optional[torch.Tensor] = None
|
||||||
|
aspect_ratio_mask: Optional[torch.Tensor] = None
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@tracer.start_as_current_span("concatenate")
|
||||||
|
def concatenate(cls, batches):
|
||||||
|
batch = super().concatenate(batches)
|
||||||
|
batch.pixel_values = None
|
||||||
|
batch.pixel_attention_mask = None
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
image_indices = []
|
||||||
|
attention_states = []
|
||||||
|
for b in batches:
|
||||||
|
if b.cross_attention_states is not None:
|
||||||
|
attention_states.append(b.cross_attention_states)
|
||||||
|
image_indices.extend([i + offset for i in b.image_indices])
|
||||||
|
offset += len(b.image_indices)
|
||||||
|
if len(attention_states) > 0:
|
||||||
|
assert len(image_indices) > 0
|
||||||
|
batch.cross_attention_states = torch.cat(attention_states, dim=0)
|
||||||
|
batch.image_indices = image_indices
|
||||||
|
else:
|
||||||
|
batch.cross_attention_states = None
|
||||||
|
batch.image_indices = []
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@tracer.start_as_current_span("filter")
|
||||||
|
def filter(self, request_ids: List[int]):
|
||||||
|
assert self.image_indices is not None
|
||||||
|
batch = super().filter(request_ids)
|
||||||
|
assert self.image_indices is not None
|
||||||
|
indices = []
|
||||||
|
for i, request_id in enumerate(request_ids):
|
||||||
|
idx = self.requests_idx_mapping[request_id]
|
||||||
|
indices.append(idx)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
new_image_indices = []
|
||||||
|
prev_i = None
|
||||||
|
for i in self.image_indices:
|
||||||
|
if i in indices:
|
||||||
|
new_image_indices.append(offset)
|
||||||
|
if i != prev_i:
|
||||||
|
offset += 1
|
||||||
|
prev_i = i
|
||||||
|
|
||||||
|
batch.image_indices = new_image_indices
|
||||||
|
if len(new_image_indices) > 0:
|
||||||
|
assert max(new_image_indices) < self.cross_attention_states.shape[0]
|
||||||
|
assert offset <= self.cross_attention_states.shape[0]
|
||||||
|
batch.cross_attention_states = self.cross_attention_states[
|
||||||
|
new_image_indices
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch.cross_attention_states = None
|
||||||
|
return batch
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_tokenized_inputs(
|
||||||
|
cls, requests: Iterable[Request], tokenizer, processor, config
|
||||||
|
):
|
||||||
|
image_inputs = []
|
||||||
|
texts = []
|
||||||
|
image_indices = []
|
||||||
|
batch_tokenized_inputs = []
|
||||||
|
|
||||||
|
for i, r in enumerate(requests):
|
||||||
|
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||||
|
curr_text = ""
|
||||||
|
curr_image = None
|
||||||
|
curr_i = None
|
||||||
|
for chunk in r.input_chunks.chunks:
|
||||||
|
chunk_type = chunk.WhichOneof("chunk")
|
||||||
|
if chunk_type == "text":
|
||||||
|
curr_text += chunk.text
|
||||||
|
elif chunk_type == "image":
|
||||||
|
image = Image.open(BytesIO(chunk.image.data))
|
||||||
|
# TODO unsure about BOS
|
||||||
|
curr_text += "<|image|>"
|
||||||
|
image_input = processor.image_processor(image, return_tensors="pt")
|
||||||
|
curr_image = image_input
|
||||||
|
curr_i = i
|
||||||
|
# image_inputs.append(image_input)
|
||||||
|
# image_indices.append(i)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||||
|
texts.append(curr_text)
|
||||||
|
if curr_image is not None:
|
||||||
|
image_inputs.append(curr_image)
|
||||||
|
image_indices.append(curr_i)
|
||||||
|
|
||||||
|
input_ids = tokenizer(
|
||||||
|
curr_text,
|
||||||
|
truncation=True,
|
||||||
|
max_length=r.truncate,
|
||||||
|
add_special_tokens=r.add_special_tokens,
|
||||||
|
)["input_ids"]
|
||||||
|
batch_tokenized_inputs.append(input_ids)
|
||||||
|
if image_inputs:
|
||||||
|
image_input = image_inputs[0]
|
||||||
|
new_image_inputs = {
|
||||||
|
"pixel_values": torch.cat(
|
||||||
|
[img["pixel_values"] for img in image_inputs], dim=0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
if "aspect_ratio_ids" in image_input:
|
||||||
|
new_image_inputs["aspect_ratio_ids"] = torch.cat(
|
||||||
|
[img["aspect_ratio_ids"] for img in image_inputs], dim=0
|
||||||
|
)
|
||||||
|
if "aspect_ratio_mask" in image_input:
|
||||||
|
new_image_inputs["aspect_ratio_mask"] = torch.cat(
|
||||||
|
[img["aspect_ratio_mask"] for img in image_inputs], dim=0
|
||||||
|
)
|
||||||
|
image_inputs = new_image_inputs
|
||||||
|
image_inputs["image_indices"] = image_indices
|
||||||
|
else:
|
||||||
|
image_inputs = None
|
||||||
|
|
||||||
|
if image_inputs is not None:
|
||||||
|
assert len(image_indices) == image_inputs["pixel_values"].shape[0]
|
||||||
|
|
||||||
|
return batch_tokenized_inputs, image_inputs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pb_processor(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
processor,
|
||||||
|
config,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "VlmCausalLMBatch":
|
||||||
|
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
|
||||||
|
pb.requests, tokenizer, processor, config
|
||||||
|
)
|
||||||
|
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||||
|
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
|
||||||
|
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
|
||||||
|
max=config.text_config.vocab_size - 1
|
||||||
|
)
|
||||||
|
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
|
||||||
|
|
||||||
|
if image_inputs is not None:
|
||||||
|
batch.pixel_values = image_inputs["pixel_values"].to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
batch.aspect_ratio_ids = image_inputs["aspect_ratio_ids"].to(device=device)
|
||||||
|
batch.aspect_ratio_mask = image_inputs["aspect_ratio_mask"].to(
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
batch.image_indices = image_inputs["image_indices"]
|
||||||
|
else:
|
||||||
|
batch.pixel_values = None
|
||||||
|
batch.aspect_ratio_ids = None
|
||||||
|
batch.aspect_ratio_mask = None
|
||||||
|
batch.image_indices = []
|
||||||
|
assert batch.image_indices is not None
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaCausalLM(VlmCausalLM):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch: VlmCausalLMBatch,
|
||||||
|
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
# Model Forward
|
||||||
|
if batch.speculative_ids is not None:
|
||||||
|
input_ids = batch.input_ids
|
||||||
|
position_ids = batch.position_ids
|
||||||
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
|
kv_cache = self.kv_cache
|
||||||
|
block_tables = batch.block_tables_tensor
|
||||||
|
slots = batch.slots[batch.slot_indices]
|
||||||
|
input_lengths = batch.input_lengths_tensor
|
||||||
|
max_s = batch.max_seqlen
|
||||||
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
|
speculative_ids = batch.speculative_ids
|
||||||
|
|
||||||
|
B, speculative_length = speculative_ids.shape
|
||||||
|
new_length = speculative_length + 1
|
||||||
|
new_input_ids = torch.cat(
|
||||||
|
[input_ids.unsqueeze(-1), speculative_ids], dim=1
|
||||||
|
).reshape(-1)
|
||||||
|
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
|
||||||
|
arange_int = arange.to(dtype=torch.int32)
|
||||||
|
new_position_ids = (
|
||||||
|
position_ids.unsqueeze(-1).expand(B, new_length) + arange
|
||||||
|
).view(-1)
|
||||||
|
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
|
||||||
|
input_lengths = (
|
||||||
|
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
|
||||||
|
).view(-1)
|
||||||
|
prefix_lens_tensor = (
|
||||||
|
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
|
# Add Copy the block tables for all members
|
||||||
|
block_tables = (
|
||||||
|
block_tables.unsqueeze(1)
|
||||||
|
.expand(B, new_length, -1)
|
||||||
|
.reshape(B * new_length, -1)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
max_s = max_s + speculative_length
|
||||||
|
|
||||||
|
input_ids = new_input_ids
|
||||||
|
position_ids = new_position_ids
|
||||||
|
else:
|
||||||
|
input_ids = batch.input_ids
|
||||||
|
position_ids = batch.position_ids
|
||||||
|
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||||
|
kv_cache = self.kv_cache
|
||||||
|
block_tables = batch.block_tables_tensor
|
||||||
|
slots = batch.slots[batch.slot_indices]
|
||||||
|
input_lengths = batch.input_lengths_tensor
|
||||||
|
prefix_lens_tensor = batch.prefix_lens_tensor
|
||||||
|
max_s = batch.max_seqlen
|
||||||
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
|
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||||
|
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||||
|
# in a circular buffer mode.
|
||||||
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
|
max_s = min(self.max_past(), max_s)
|
||||||
|
|
||||||
|
bs = input_ids.shape[0]
|
||||||
|
# Try to find an associated cuda graph
|
||||||
|
bs = input_ids.shape[0]
|
||||||
|
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||||
|
if sorted_padded_bs:
|
||||||
|
# Get associated cuda graph
|
||||||
|
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
||||||
|
else:
|
||||||
|
cuda_graph = None
|
||||||
|
if (
|
||||||
|
cu_seqlen_prefill is not None
|
||||||
|
or cuda_graph is None
|
||||||
|
# Only run cuda graphs when there's no images.
|
||||||
|
or batch.cross_attention_states is not None
|
||||||
|
):
|
||||||
|
input_lengths = input_lengths + prefix_lens_tensor
|
||||||
|
if PREFIX_CACHING:
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
|
with self._forward_context(
|
||||||
|
block_tables=block_tables,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
input_lengths_tensor=input_lengths,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
|
):
|
||||||
|
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
||||||
|
seqlen = Seqlen(
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
prefix_lengths=prefix_lens_tensor,
|
||||||
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
|
max_q=max_s,
|
||||||
|
max_k=max_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
if batch.pixel_values is not None:
|
||||||
|
cross_attention_states = self.model.vision_forward(
|
||||||
|
pixel_values=batch.pixel_values,
|
||||||
|
aspect_ratio_ids=batch.aspect_ratio_ids,
|
||||||
|
aspect_ratio_mask=batch.aspect_ratio_mask,
|
||||||
|
)
|
||||||
|
batch.cross_attention_states = cross_attention_states
|
||||||
|
|
||||||
|
cross_attention_states = batch.cross_attention_states
|
||||||
|
|
||||||
|
logits, speculative_logits = self.model.forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
|
lm_head_indices=lm_head_indices,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
adapter_data=adapter_data,
|
||||||
|
image_indices=batch.image_indices[:],
|
||||||
|
)
|
||||||
|
if batch.prefill_cache_indices is not None:
|
||||||
|
batch.prefill_cache_indices = None
|
||||||
|
if batch.pixel_values is not None:
|
||||||
|
batch.pixel_values = None
|
||||||
|
return logits, speculative_logits
|
||||||
|
|
||||||
|
# Copy inputs to the static inputs of the cuda graph
|
||||||
|
# Static inputs are potentially padded
|
||||||
|
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||||
|
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=batch.input_lengths,
|
||||||
|
prefix_lens=batch.prefix_lens,
|
||||||
|
)
|
||||||
|
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||||
|
else:
|
||||||
|
cuda_graph["block_tables"][
|
||||||
|
: block_tables.shape[0], : block_tables.shape[1]
|
||||||
|
] = block_tables
|
||||||
|
cuda_graph["slots"].fill_(0)
|
||||||
|
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||||
|
cuda_graph["input_lengths"].zero_()
|
||||||
|
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||||
|
input_lengths + prefix_lens_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replay the graph
|
||||||
|
cuda_graph["graph"].replay()
|
||||||
|
|
||||||
|
# Slice output to the correct shape
|
||||||
|
speculative_logits = (
|
||||||
|
cuda_graph["speculative_logits"][:bs]
|
||||||
|
if cuda_graph["speculative_logits"] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
logits = cuda_graph["logits"][:bs]
|
||||||
|
return logits, speculative_logits
|
|
@ -22,8 +22,14 @@ try:
|
||||||
VlmCausalLMBatch,
|
VlmCausalLMBatch,
|
||||||
)
|
)
|
||||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
||||||
|
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||||
|
|
||||||
VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
|
VLM_BATCH_TYPES = {
|
||||||
|
PaliGemmaBatch,
|
||||||
|
VlmCausalLMBatch,
|
||||||
|
IdeficsCausalLMBatch,
|
||||||
|
MllamaCausalLMBatch,
|
||||||
|
}
|
||||||
except (ImportError, NotImplementedError):
|
except (ImportError, NotImplementedError):
|
||||||
# These imports can fail on CPU/Non flash.
|
# These imports can fail on CPU/Non flash.
|
||||||
VLM_BATCH_TYPES = set()
|
VLM_BATCH_TYPES = set()
|
||||||
|
|
Loading…
Reference in New Issue