feat(docker): add benchmarking tool to docker image (#298)

This commit is contained in:
OlivierDehaene 2023-05-09 13:19:31 +02:00 committed by GitHub
parent 926fd9a010
commit e250282213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 262 additions and 2940 deletions

143
Cargo.lock generated
View File

@ -134,6 +134,17 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "average"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "843ec791d3f24503bbf72bbd5e49a3ab4dbb4bcd0a8ef6b0c908efa73caa27b1"
dependencies = [
"easy-cast",
"float-ord",
"num-traits",
]
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.6.15" version = "0.6.15"
@ -293,6 +304,12 @@ dependencies = [
"zip", "zip",
] ]
[[package]]
name = "cassowary"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.79" version = "1.0.79"
@ -461,6 +478,31 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crossterm"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a84cda67535339806297f1b331d6dd6320470d2a0fe65381e79ee9e156dd3d13"
dependencies = [
"bitflags",
"crossterm_winapi",
"libc",
"mio",
"parking_lot",
"signal-hook",
"signal-hook-mio",
"winapi",
]
[[package]]
name = "crossterm_winapi"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ae1b35a484aa10e07fe0638d02301c5ad24de82d310ccbd2f3693da5f09bf1c"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "crypto-common" name = "crypto-common"
version = "0.1.6" version = "0.1.6"
@ -591,6 +633,15 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "easy-cast"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bd102ee8c418348759919b83b81cdbdc933ffe29740b903df448b4bafaa348e"
dependencies = [
"libm",
]
[[package]] [[package]]
name = "either" name = "either"
version = "1.8.1" version = "1.8.1"
@ -679,6 +730,12 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "float-ord"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
[[package]] [[package]]
name = "float_eq" name = "float_eq"
version = "1.0.1" version = "1.0.1"
@ -1165,6 +1222,12 @@ version = "0.2.141"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
[[package]]
name = "libm"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.3.1" version = "0.3.1"
@ -1447,6 +1510,16 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.15.0" version = "1.15.0"
@ -1903,6 +1976,19 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "ratatui"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcc0d032bccba900ee32151ec0265667535c230169f5a011154cdcd984e16829"
dependencies = [
"bitflags",
"cassowary",
"crossterm",
"unicode-segmentation",
"unicode-width",
]
[[package]] [[package]]
name = "raw-cpuid" name = "raw-cpuid"
version = "10.7.0" version = "10.7.0"
@ -2252,6 +2338,27 @@ dependencies = [
"dirs", "dirs",
] ]
[[package]]
name = "signal-hook"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-mio"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
dependencies = [
"libc",
"mio",
"signal-hook",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.1" version = "1.4.1"
@ -2407,9 +2514,28 @@ dependencies = [
"windows-sys 0.45.0", "windows-sys 0.45.0",
] ]
[[package]]
name = "text-generation-benchmark"
version = "0.7.0-dev"
dependencies = [
"average",
"clap",
"crossterm",
"float-ord",
"ratatui",
"serde",
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.6.0" version = "0.7.0-dev"
dependencies = [ dependencies = [
"futures", "futures",
"grpc-metadata", "grpc-metadata",
@ -2421,12 +2547,11 @@ dependencies = [
"tonic-build", "tonic-build",
"tower", "tower",
"tracing", "tracing",
"tracing-error",
] ]
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.6.0" version = "0.7.0-dev"
dependencies = [ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
@ -2442,7 +2567,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "0.6.0" version = "0.7.0-dev"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -2803,16 +2928,6 @@ dependencies = [
"valuable", "valuable",
] ]
[[package]]
name = "tracing-error"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e"
dependencies = [
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "tracing-futures" name = "tracing-futures"
version = "0.2.5" version = "0.2.5"

View File

@ -1,13 +1,17 @@
[workspace] [workspace]
members = [ members = [
"benchmark",
"router", "router",
"router/client", "router/client",
"router/grpc-metadata", "router/grpc-metadata",
"launcher" "launcher"
] ]
exclude = [
"benchmark" [workspace.package]
] version = "0.7.0-dev"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
[profile.release] [profile.release]
debug = 1 debug = 1

View File

@ -164,6 +164,8 @@ RUN cd server && \
pip install -r requirements.txt && \ pip install -r requirements.txt && \
pip install ".[bnb, accelerate]" --no-cache-dir pip install ".[bnb, accelerate]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router # Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher # Install launcher

View File

@ -1 +0,0 @@
target

2884
benchmark/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,10 @@
[package] [package]
name = "text-generation-benchmark" name = "text-generation-benchmark"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Benchmarking tool" description = "Text Generation Benchmarking tool"
version.workspace = true
[profile.release] edition.workspace = true
debug = 1 authors.workspace = true
incremental = true homepage.workspace = true
lto = "off"
panic = "abort"
[lib] [lib]
path = "src/lib.rs" path = "src/lib.rs"

View File

@ -1,3 +0,0 @@
[toolchain]
channel = "1.67.0"
components = ["rustfmt", "clippy"]

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "0.6.0" "version": "0.7.0-dev"
}, },
"paths": { "paths": {
"/": { "/": {
@ -33,7 +33,19 @@
}, },
"responses": { "responses": {
"200": { "200": {
"description": "See /generate or /generate_stream" "description": "Generated Text",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/StreamResponse"
}
}
}
}, },
"422": { "422": {
"description": "Input validation error", "description": "Input validation error",
@ -584,11 +596,73 @@
"type": "object", "type": "object",
"required": [ "required": [
"model_id", "model_id",
"model_dtype",
"model_device_type",
"max_concurrent_requests",
"max_best_of",
"max_stop_sequences",
"max_input_length",
"max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"version" "version"
], ],
"properties": { "properties": {
"docker_label": {
"type": "string",
"example": "null",
"nullable": true
},
"max_batch_total_tokens": {
"type": "integer",
"format": "int32",
"example": "32000",
"minimum": 0.0
},
"max_best_of": {
"type": "integer",
"example": "2",
"minimum": 0.0
},
"max_concurrent_requests": {
"type": "integer",
"description": "Router Parameters",
"example": "128",
"minimum": 0.0
},
"max_input_length": {
"type": "integer",
"example": "1024",
"minimum": 0.0
},
"max_stop_sequences": {
"type": "integer",
"example": "4",
"minimum": 0.0
},
"max_total_tokens": {
"type": "integer",
"example": "2048",
"minimum": 0.0
},
"max_waiting_tokens": {
"type": "integer",
"example": "20",
"minimum": 0.0
},
"model_device_type": {
"type": "string",
"example": "cuda"
},
"model_dtype": {
"type": "string",
"example": "torch.float16"
},
"model_id": { "model_id": {
"type": "string", "type": "string",
"description": "Model info",
"example": "bigscience/blomm-560m" "example": "bigscience/blomm-560m"
}, },
"model_pipeline_tag": { "model_pipeline_tag": {
@ -606,9 +680,20 @@
"example": "null", "example": "null",
"nullable": true "nullable": true
}, },
"validation_workers": {
"type": "integer",
"example": "2",
"minimum": 0.0
},
"version": { "version": {
"type": "string", "type": "string",
"description": "Router Info",
"example": "0.5.0" "example": "0.5.0"
},
"waiting_served_ratio": {
"type": "number",
"format": "float",
"example": "1.2"
} }
} }
}, },

View File

@ -1,9 +1,10 @@
[package] [package]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.6.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Launcher" description = "Text Generation Launcher"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies] [dependencies]
clap = { version = "4.1.4", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }

View File

@ -1,10 +1,11 @@
[package] [package]
name = "text-generation-router" name = "text-generation-router"
version = "0.6.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Webserver" description = "Text Generation Webserver"
build = "build.rs" build = "build.rs"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib] [lib]
path = "src/lib.rs" path = "src/lib.rs"

View File

@ -1,7 +1,9 @@
[package] [package]
name = "text-generation-client" name = "text-generation-client"
version = "0.6.0" version.workspace = true
edition = "2021" edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies] [dependencies]
futures = "^0.3" futures = "^0.3"
@ -12,7 +14,6 @@ tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8" tonic = "^0.8"
tower = "^0.4" tower = "^0.4"
tracing = "^0.1" tracing = "^0.1"
tracing-error = "^0.2"
[build-dependencies] [build-dependencies]
tonic-build = "0.8.4" tonic-build = "0.8.4"

View File

@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo}; use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer}; use tokenizers::{FromPretrainedParameters, Tokenizer};
@ -125,7 +126,7 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async { .block_on(async {
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size{ if let Some(max_batch_size) = max_batch_size {
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead"); tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32; max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}"); tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
sha: None, sha: None,
pipeline_tag: None, pipeline_tag: None,
}, },
false => get_model_info(&tokenizer_name, &revision, authorization_token).await, false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
}; };
// if pipeline-tag == text-generation we default to return_full_text = true // if pipeline-tag == text-generation we default to return_full_text = true
@ -195,7 +199,7 @@ fn main() -> Result<(), std::io::Error> {
addr, addr,
cors_allow_origin, cors_allow_origin,
) )
.await; .await;
Ok(()) Ok(())
}) })
} }
@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
} }
/// get model info from the Huggingface Hub /// get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo { pub async fn get_model_info(
model_id: &str,
revision: &str,
token: Option<String>,
) -> Option<HubModelInfo> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
// Poor man's urlencode // Poor man's urlencode
let revision = revision.replace("/", "%2F"); let revision = revision.replace('/', "%2F");
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}"); let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
let mut builder = client.get(url); let mut builder = client.get(url).timeout(Duration::from_secs(5));
if let Some(token) = token { if let Some(token) = token {
builder = builder.bearer_auth(token); builder = builder.bearer_auth(token);
} }
let model_info = builder let response = builder.send().await.ok()?;
.send()
.await if response.status().is_success() {
.expect("Could not connect to hf.co") return serde_json::from_str(&response.text().await.ok()?).ok();
.text() }
.await None
.expect("error when retrieving model info from hf.co");
serde_json::from_str(&model_info).expect("unable to parse model info")
} }

View File

@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
path = "/", path = "/",
request_body = CompatGenerateRequest, request_body = CompatGenerateRequest,
responses( responses(
(status = 200, description = "See /generate or /generate_stream", (status = 200, description = "Generated Text",
content( content(
("application/json" = GenerateResponse), ("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse), ("text/event-stream" = StreamResponse),