feat(docker): add benchmarking tool to docker image (#298)
This commit is contained in:
parent
926fd9a010
commit
e250282213
|
@ -134,6 +134,17 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "average"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "843ec791d3f24503bbf72bbd5e49a3ab4dbb4bcd0a8ef6b0c908efa73caa27b1"
|
||||
dependencies = [
|
||||
"easy-cast",
|
||||
"float-ord",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.6.15"
|
||||
|
@ -293,6 +304,12 @@ dependencies = [
|
|||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cassowary"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.79"
|
||||
|
@ -461,6 +478,31 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossterm"
|
||||
version = "0.26.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a84cda67535339806297f1b331d6dd6320470d2a0fe65381e79ee9e156dd3d13"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"crossterm_winapi",
|
||||
"libc",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"signal-hook",
|
||||
"signal-hook-mio",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossterm_winapi"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ae1b35a484aa10e07fe0638d02301c5ad24de82d310ccbd2f3693da5f09bf1c"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
|
@ -591,6 +633,15 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "easy-cast"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4bd102ee8c418348759919b83b81cdbdc933ffe29740b903df448b4bafaa348e"
|
||||
dependencies = [
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.8.1"
|
||||
|
@ -679,6 +730,12 @@ dependencies = [
|
|||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float-ord"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||
|
||||
[[package]]
|
||||
name = "float_eq"
|
||||
version = "1.0.1"
|
||||
|
@ -1165,6 +1222,12 @@ version = "0.2.141"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.3.1"
|
||||
|
@ -1447,6 +1510,16 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.15.0"
|
||||
|
@ -1903,6 +1976,19 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ratatui"
|
||||
version = "0.20.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcc0d032bccba900ee32151ec0265667535c230169f5a011154cdcd984e16829"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cassowary",
|
||||
"crossterm",
|
||||
"unicode-segmentation",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "10.7.0"
|
||||
|
@ -2252,6 +2338,27 @@ dependencies = [
|
|||
"dirs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook"
|
||||
version = "0.3.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"signal-hook-registry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-mio"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"mio",
|
||||
"signal-hook",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.1"
|
||||
|
@ -2407,9 +2514,28 @@ dependencies = [
|
|||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "0.7.0-dev"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap",
|
||||
"crossterm",
|
||||
"float-ord",
|
||||
"ratatui",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0-dev"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"grpc-metadata",
|
||||
|
@ -2421,12 +2547,11 @@ dependencies = [
|
|||
"tonic-build",
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0-dev"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"ctrlc",
|
||||
|
@ -2442,7 +2567,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0-dev"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
|
@ -2803,16 +2928,6 @@ dependencies = [
|
|||
"valuable",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-error"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e"
|
||||
dependencies = [
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-futures"
|
||||
version = "0.2.5"
|
||||
|
|
10
Cargo.toml
10
Cargo.toml
|
@ -1,13 +1,17 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"router",
|
||||
"router/client",
|
||||
"router/grpc-metadata",
|
||||
"launcher"
|
||||
]
|
||||
exclude = [
|
||||
"benchmark"
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.7.0-dev"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
|
|
|
@ -164,6 +164,8 @@ RUN cd server && \
|
|||
pip install -r requirements.txt && \
|
||||
pip install ".[bnb, accelerate]" --no-cache-dir
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
target
|
File diff suppressed because it is too large
Load Diff
|
@ -1,15 +1,10 @@
|
|||
[package]
|
||||
name = "text-generation-benchmark"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Benchmarking tool"
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
incremental = true
|
||||
lto = "off"
|
||||
panic = "abort"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
[toolchain]
|
||||
channel = "1.67.0"
|
||||
components = ["rustfmt", "clippy"]
|
|
@ -10,7 +10,7 @@
|
|||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "0.6.0"
|
||||
"version": "0.7.0-dev"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -33,7 +33,19 @@
|
|||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "See /generate or /generate_stream"
|
||||
"description": "Generated Text",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/GenerateResponse"
|
||||
}
|
||||
},
|
||||
"text/event-stream": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/StreamResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Input validation error",
|
||||
|
@ -584,11 +596,73 @@
|
|||
"type": "object",
|
||||
"required": [
|
||||
"model_id",
|
||||
"model_dtype",
|
||||
"model_device_type",
|
||||
"max_concurrent_requests",
|
||||
"max_best_of",
|
||||
"max_stop_sequences",
|
||||
"max_input_length",
|
||||
"max_total_tokens",
|
||||
"waiting_served_ratio",
|
||||
"max_batch_total_tokens",
|
||||
"max_waiting_tokens",
|
||||
"validation_workers",
|
||||
"version"
|
||||
],
|
||||
"properties": {
|
||||
"docker_label": {
|
||||
"type": "string",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"max_batch_total_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": "32000",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_best_of": {
|
||||
"type": "integer",
|
||||
"example": "2",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_concurrent_requests": {
|
||||
"type": "integer",
|
||||
"description": "Router Parameters",
|
||||
"example": "128",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_input_length": {
|
||||
"type": "integer",
|
||||
"example": "1024",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_stop_sequences": {
|
||||
"type": "integer",
|
||||
"example": "4",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_total_tokens": {
|
||||
"type": "integer",
|
||||
"example": "2048",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"max_waiting_tokens": {
|
||||
"type": "integer",
|
||||
"example": "20",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"model_device_type": {
|
||||
"type": "string",
|
||||
"example": "cuda"
|
||||
},
|
||||
"model_dtype": {
|
||||
"type": "string",
|
||||
"example": "torch.float16"
|
||||
},
|
||||
"model_id": {
|
||||
"type": "string",
|
||||
"description": "Model info",
|
||||
"example": "bigscience/blomm-560m"
|
||||
},
|
||||
"model_pipeline_tag": {
|
||||
|
@ -606,9 +680,20 @@
|
|||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"validation_workers": {
|
||||
"type": "integer",
|
||||
"example": "2",
|
||||
"minimum": 0.0
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"description": "Router Info",
|
||||
"example": "0.5.0"
|
||||
},
|
||||
"waiting_served_ratio": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"example": "1.2"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
[package]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Launcher"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
[package]
|
||||
name = "text-generation-router"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Webserver"
|
||||
build = "build.rs"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
[package]
|
||||
name = "text-generation-client"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
futures = "^0.3"
|
||||
|
@ -12,7 +14,6 @@ tokio = { version = "^1.25", features = ["sync"] }
|
|||
tonic = "^0.8"
|
||||
tower = "^0.4"
|
||||
tracing = "^0.1"
|
||||
tracing-error = "^0.2"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.8.4"
|
||||
|
|
|
@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
|
|||
use opentelemetry_otlp::WithExportConfig;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_router::{server, HubModelInfo};
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
|
|||
sha: None,
|
||||
pipeline_tag: None,
|
||||
},
|
||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
|
||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
|
||||
}),
|
||||
};
|
||||
|
||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||
|
@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
|||
}
|
||||
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
|
||||
pub async fn get_model_info(
|
||||
model_id: &str,
|
||||
revision: &str,
|
||||
token: Option<String>,
|
||||
) -> Option<HubModelInfo> {
|
||||
let client = reqwest::Client::new();
|
||||
// Poor man's urlencode
|
||||
let revision = revision.replace("/", "%2F");
|
||||
let revision = revision.replace('/', "%2F");
|
||||
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
||||
let mut builder = client.get(url);
|
||||
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
||||
if let Some(token) = token {
|
||||
builder = builder.bearer_auth(token);
|
||||
}
|
||||
|
||||
let model_info = builder
|
||||
.send()
|
||||
.await
|
||||
.expect("Could not connect to hf.co")
|
||||
.text()
|
||||
.await
|
||||
.expect("error when retrieving model info from hf.co");
|
||||
serde_json::from_str(&model_info).expect("unable to parse model info")
|
||||
let response = builder.send().await.ok()?;
|
||||
|
||||
if response.status().is_success() {
|
||||
return serde_json::from_str(&response.text().await.ok()?).ok();
|
||||
}
|
||||
None
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
|||
path = "/",
|
||||
request_body = CompatGenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "See /generate or /generate_stream",
|
||||
(status = 200, description = "Generated Text",
|
||||
content(
|
||||
("application/json" = GenerateResponse),
|
||||
("text/event-stream" = StreamResponse),
|
||||
|
|
Loading…
Reference in New Issue