refacto
This commit is contained in:
parent
b562680be4
commit
93e0a7de8b
|
@ -194,17 +194,6 @@ dependencies = [
|
|||
"v_frame",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "average"
|
||||
version = "0.14.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c309b1c7fca12ebeec3ecba29ea917b3a4cb458ccf504df68bb4d8a0ca565a00"
|
||||
dependencies = [
|
||||
"easy-cast",
|
||||
"float-ord",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "avif-serialize"
|
||||
version = "0.8.1"
|
||||
|
@ -503,12 +492,6 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cassowary"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.98"
|
||||
|
@ -570,7 +553,7 @@ version = "4.5.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
|
@ -675,31 +658,6 @@ version = "0.8.20"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
|
||||
|
||||
[[package]]
|
||||
name = "crossterm"
|
||||
version = "0.27.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"crossterm_winapi",
|
||||
"libc",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"signal-hook",
|
||||
"signal-hook-mio",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossterm_winapi"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crunchy"
|
||||
version = "0.2.2"
|
||||
|
@ -832,15 +790,6 @@ dependencies = [
|
|||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "easy-cast"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6"
|
||||
dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.12.0"
|
||||
|
@ -944,12 +893,6 @@ dependencies = [
|
|||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float-ord"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||
|
||||
[[package]]
|
||||
name = "float_eq"
|
||||
version = "1.0.1"
|
||||
|
@ -1208,12 +1151,6 @@ version = "0.14.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
|
@ -1498,12 +1435,6 @@ dependencies = [
|
|||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "2.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
|
||||
|
||||
[[package]]
|
||||
name = "init-tracing-opentelemetry"
|
||||
version = "0.14.1"
|
||||
|
@ -1674,12 +1605,6 @@ dependencies = [
|
|||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.3"
|
||||
|
@ -1896,7 +1821,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"wasi",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
@ -2148,7 +2072,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2412,17 +2335,6 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "papergrid"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2ccbe15f2b6db62f9a9871642746427e297b0ceb85f9a7f1ee5ff47d184d0c8"
|
||||
dependencies = [
|
||||
"bytecount",
|
||||
"fnv",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.3"
|
||||
|
@ -2626,7 +2538,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck 0.5.0",
|
||||
"heck",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap",
|
||||
|
@ -2745,23 +2657,6 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e2e4cd95294a85c3b4446e63ef054eea43e0205b1fd60120c16b74ff7ff96ad"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"cassowary",
|
||||
"crossterm",
|
||||
"indoc",
|
||||
"itertools 0.11.0",
|
||||
"paste",
|
||||
"strum",
|
||||
"unicode-segmentation",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rav1e"
|
||||
version = "0.7.1"
|
||||
|
@ -3269,27 +3164,6 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook"
|
||||
version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"signal-hook-registry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-mio"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"mio",
|
||||
"signal-hook",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.2"
|
||||
|
@ -3387,28 +3261,6 @@ version = "0.11.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.25.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.66",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
|
@ -3491,36 +3343,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349"
|
||||
dependencies = [
|
||||
"cfg-expr",
|
||||
"heck 0.5.0",
|
||||
"heck",
|
||||
"pkg-config",
|
||||
"toml",
|
||||
"version-compare",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabled"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dfe9c3632da101aba5131ed63f9eed38665f8b3c68703a6bb18124835c1a5d22"
|
||||
dependencies = [
|
||||
"papergrid",
|
||||
"tabled_derive",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tabled_derive"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99f688a08b54f4f02f0a3c382aefdb7884d3d69609f785bd253dc033243e3fe4"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.14"
|
||||
|
@ -3539,45 +3367,6 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "2.0.5-dev0"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap",
|
||||
"crossterm",
|
||||
"float-ord",
|
||||
"hf-hub",
|
||||
"ratatui",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tabled",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "2.0.5-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
"prost 0.12.6",
|
||||
"prost-build",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tonic 0.10.2",
|
||||
"tonic-build",
|
||||
"tower",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "2.0.5-dev0"
|
||||
|
@ -3627,7 +3416,6 @@ dependencies = [
|
|||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
@ -3641,6 +3429,54 @@ dependencies = [
|
|||
"vergen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-v3"
|
||||
version = "2.0.5-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"axum 0.7.5",
|
||||
"axum-tracing-opentelemetry",
|
||||
"base64 0.22.1",
|
||||
"clap",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"grpc-metadata",
|
||||
"hf-hub",
|
||||
"image",
|
||||
"init-tracing-opentelemetry",
|
||||
"jsonschema",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"minijinja",
|
||||
"minijinja-contrib",
|
||||
"nohash-hasher",
|
||||
"once_cell",
|
||||
"opentelemetry 0.20.0",
|
||||
"opentelemetry-otlp",
|
||||
"prost 0.12.6",
|
||||
"prost-build",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic 0.10.2",
|
||||
"tonic-build",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-opentelemetry 0.21.0",
|
||||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
"utoipa-swagger-ui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.61"
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"router",
|
||||
"router/client",
|
||||
"router/grpc-metadata",
|
||||
# "benchmark",
|
||||
"backends/v3",
|
||||
# "backends/client",
|
||||
"backends/grpc-metadata",
|
||||
"launcher"
|
||||
]
|
||||
resolver = "2"
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
[package]
|
||||
name = "text-generation-router-v3"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = "0.21.1"
|
||||
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
|
@ -0,0 +1,19 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,16 +1,15 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v3::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
use crate::queue::{Entry, Queue};
|
||||
use text_generation_router::infer::{
|
||||
GeneratedText, InferError, InferStreamResponse, Backend,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::{ClientError, Health};
|
||||
use crate::client::{Batch, CachedBatch, Generation, ShardedClient, ClientError, Health};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -18,7 +17,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub(crate) struct BackendV3 {
|
||||
pub struct BackendV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
|
@ -78,7 +77,6 @@ impl Backend for BackendV3 {
|
|||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
let input_length = request.input_length;
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
|
@ -480,14 +478,14 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
});
|
||||
}
|
||||
|
||||
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
||||
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v3_finish_reason =
|
||||
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
|
@ -1,6 +1,7 @@
|
|||
use crate::v3::{pb, Chunk};
|
||||
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
/// Single shard Client
|
||||
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
|
@ -1,15 +1,24 @@
|
|||
//! Text Generation gRPC client library
|
||||
|
||||
pub mod v2;
|
||||
pub mod v3;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
pub use v3::{Chunk, Image, Input, InputChunk};
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
|
@ -63,29 +72,6 @@ impl From<Chunk> for InputChunk {
|
|||
}
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<InputChunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c.chunk {
|
||||
Some(Chunk::Text(text)) => output.push_str(text),
|
||||
Some(Chunk::Image(Image { data, mimetype })) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
// We don't create empty chunks, so this should be unreachable.
|
||||
None => unreachable!("Chunks should never be empty"),
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
|
@ -1,14 +1,14 @@
|
|||
/// Multi shard Client
|
||||
use crate::{v3, Health, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
use crate::client::{ClientError, Result};
|
||||
|
||||
use crate::v3::{Chunk, InfoResponse, Input};
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
use v3::client::{DecodeTimings, PrefillTimings};
|
||||
use v3::{
|
||||
use crate::client::client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
|
@ -1,14 +1,13 @@
|
|||
mod block_allocator;
|
||||
mod queue;
|
||||
mod backend;
|
||||
mod client;
|
||||
|
||||
use futures_util::TryFutureExt;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
pub(crate) use backend::BackendV3;
|
||||
use text_generation_client::ClientError;
|
||||
use text_generation_client::v3::ShardedClient;
|
||||
use crate::client::{ShardedClient, ClientError};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
|
@ -127,7 +126,7 @@ pub async fn connect_backend(
|
|||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum V3Error {
|
||||
pub enum V3Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
|
@ -20,6 +20,7 @@ use tower_http::cors::AllowOrigin;
|
|||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||
use text_generation_router_v3::{connect_backend, V3Error};
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
|
@ -336,9 +337,11 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
};
|
||||
|
||||
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
master_shard_uds_path,
|
||||
backend,
|
||||
model_info,
|
||||
compat_return_full_text,
|
||||
max_concurrent_requests,
|
||||
|
@ -347,11 +350,6 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
tokenizer,
|
||||
config,
|
||||
validation_workers,
|
||||
|
@ -508,6 +506,8 @@ pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConf
|
|||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
|
@ -1,17 +1,14 @@
|
|||
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v3::{
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use crate::client as client;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
@ -337,8 +334,11 @@ impl State {
|
|||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
input_chunks: Some(Input {
|
||||
chunks: entry.request.inputs.clone(),
|
||||
input_chunks: Some(client::Input {
|
||||
chunks: entry.request.inputs.clone().into_iter().map(|c| client::InputChunk { chunk: Some(match c {
|
||||
Chunk::Text(text) => client::Chunk::Text(text),
|
||||
Chunk::Image(image) => client::Chunk::Image(client::Image { data: image.data, mimetype: image.mimetype })
|
||||
})}).collect()
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
|
@ -21,7 +21,7 @@ float-ord = "0.3.2"
|
|||
serde = {version = "1.0.188", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
tabled = "0.14.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
text-generation-client = { path = "../backends/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
|
|
|
@ -7,19 +7,11 @@ edition.workspace = true
|
|||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/v2/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v2/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.map_err(|e| match e.kind(){
|
||||
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
|
||||
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
|
||||
e => {e}
|
||||
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
fs::create_dir_all("src/v3/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v3/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -1,13 +0,0 @@
|
|||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -1,5 +1,4 @@
|
|||
// pub(crate) mod v2;
|
||||
pub(crate) mod v3;
|
||||
mod chat_template;
|
||||
pub mod tool_grammar;
|
||||
|
||||
|
@ -23,7 +22,7 @@ use chat_template::ChatTemplate;
|
|||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait Backend {
|
||||
pub trait Backend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
|
@ -286,15 +285,15 @@ pub(crate) type GenerateStreamResponse = (
|
|||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub(crate) text: String,
|
||||
pub(crate) generated_tokens: u32,
|
||||
pub(crate) finish_reason: FinishReason,
|
||||
pub(crate) seed: Option<u64>,
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
pub generated_tokens: u32,
|
||||
pub finish_reason: FinishReason,
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
pub enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(Vec<PrefillToken>),
|
||||
// Intermediate messages
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
/// Text Generation Inference Webserver
|
||||
pub mod config;
|
||||
mod infer;
|
||||
pub mod infer;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
pub mod validation;
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
mod kserve;
|
||||
|
@ -1055,23 +1055,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct PrefillToken {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
pub struct Token {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
#[schema(example = "false")]
|
||||
special: bool,
|
||||
pub special: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
|
@ -1089,7 +1089,7 @@ pub struct SimpleToken {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub(crate) enum FinishReason {
|
||||
pub enum FinishReason {
|
||||
#[schema(rename = "length")]
|
||||
Length,
|
||||
#[serde(rename = "eos_token")]
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::v3::{connect_backend, V3Error};
|
||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, Backend};
|
||||
use crate::infer::tool_grammar::ToolGrammar;
|
||||
#[cfg(feature = "kserve")]
|
||||
|
@ -38,9 +37,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
|||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{v2, v3, ClientError, ShardInfo};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
|
@ -1398,7 +1394,7 @@ pub(crate) struct ComputeType(String);
|
|||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
master_shard_uds_path: String,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
model_info: HubModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
max_concurrent_requests: usize,
|
||||
|
@ -1407,11 +1403,6 @@ pub async fn run(
|
|||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
|
@ -1495,11 +1486,6 @@ pub async fn run(
|
|||
struct ApiDoc;
|
||||
|
||||
// Create state
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (backend, backend_info) = connect_backend(max_input_tokens, max_total_tokens, master_shard_uds_path, waiting_served_ratio, max_batch_prefill_tokens, max_batch_total_tokens, max_waiting_tokens, max_batch_size).await?;
|
||||
// tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
|
@ -1827,8 +1813,6 @@ impl From<InferError> for Event {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Backend error: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ use jsonschema::{Draft, JSONSchema};
|
|||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{Chunk, Image, InputChunk};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc;
|
||||
|
@ -86,7 +85,7 @@ impl Validation {
|
|||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
|
@ -112,7 +111,7 @@ impl Validation {
|
|||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
|
@ -514,7 +513,7 @@ fn prepare_input(
|
|||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
config: &Option<Config>,
|
||||
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
||||
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
|
@ -626,18 +625,51 @@ fn prepare_input(
|
|||
|
||||
type TokenizerRequest = (
|
||||
(String, Option<usize>),
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
||||
Span,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct Image {
|
||||
pub data: Vec<u8>,
|
||||
pub mimetype: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum Chunk {
|
||||
Text(String),
|
||||
Image(Image)
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<Chunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c {
|
||||
Chunk::Text(text) => output.push_str(text),
|
||||
Chunk::Image(Image { data, mimetype }) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum ValidGrammar {
|
||||
pub enum ValidGrammar {
|
||||
Json(String),
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidParameters {
|
||||
pub struct ValidParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
|
@ -661,7 +693,7 @@ pub(crate) struct ValidParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidStoppingParameters {
|
||||
pub struct ValidStoppingParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
|
@ -672,8 +704,8 @@ pub(crate) struct ValidStoppingParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: Vec<InputChunk>,
|
||||
pub struct ValidGenerateRequest {
|
||||
pub inputs: Vec<Chunk>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
|
Loading…
Reference in New Issue