feat(docker): add benchmarking tool to docker image (#298)
This commit is contained in:
parent
926fd9a010
commit
e250282213
|
@ -134,6 +134,17 @@ version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "average"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "843ec791d3f24503bbf72bbd5e49a3ab4dbb4bcd0a8ef6b0c908efa73caa27b1"
|
||||||
|
dependencies = [
|
||||||
|
"easy-cast",
|
||||||
|
"float-ord",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.6.15"
|
version = "0.6.15"
|
||||||
|
@ -293,6 +304,12 @@ dependencies = [
|
||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cassowary"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.79"
|
version = "1.0.79"
|
||||||
|
@ -461,6 +478,31 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossterm"
|
||||||
|
version = "0.26.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a84cda67535339806297f1b331d6dd6320470d2a0fe65381e79ee9e156dd3d13"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"crossterm_winapi",
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
"parking_lot",
|
||||||
|
"signal-hook",
|
||||||
|
"signal-hook-mio",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossterm_winapi"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2ae1b35a484aa10e07fe0638d02301c5ad24de82d310ccbd2f3693da5f09bf1c"
|
||||||
|
dependencies = [
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crypto-common"
|
name = "crypto-common"
|
||||||
version = "0.1.6"
|
version = "0.1.6"
|
||||||
|
@ -591,6 +633,15 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "easy-cast"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4bd102ee8c418348759919b83b81cdbdc933ffe29740b903df448b4bafaa348e"
|
||||||
|
dependencies = [
|
||||||
|
"libm",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "either"
|
name = "either"
|
||||||
version = "1.8.1"
|
version = "1.8.1"
|
||||||
|
@ -679,6 +730,12 @@ dependencies = [
|
||||||
"miniz_oxide",
|
"miniz_oxide",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "float-ord"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "float_eq"
|
name = "float_eq"
|
||||||
version = "1.0.1"
|
version = "1.0.1"
|
||||||
|
@ -1165,6 +1222,12 @@ version = "0.2.141"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
|
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libm"
|
||||||
|
version = "0.2.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
|
@ -1447,6 +1510,16 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"libm",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num_cpus"
|
name = "num_cpus"
|
||||||
version = "1.15.0"
|
version = "1.15.0"
|
||||||
|
@ -1903,6 +1976,19 @@ dependencies = [
|
||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ratatui"
|
||||||
|
version = "0.20.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "dcc0d032bccba900ee32151ec0265667535c230169f5a011154cdcd984e16829"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cassowary",
|
||||||
|
"crossterm",
|
||||||
|
"unicode-segmentation",
|
||||||
|
"unicode-width",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "raw-cpuid"
|
name = "raw-cpuid"
|
||||||
version = "10.7.0"
|
version = "10.7.0"
|
||||||
|
@ -2252,6 +2338,27 @@ dependencies = [
|
||||||
"dirs",
|
"dirs",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook"
|
||||||
|
version = "0.3.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"signal-hook-registry",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook-mio"
|
||||||
|
version = "0.2.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
"signal-hook",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-registry"
|
name = "signal-hook-registry"
|
||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
|
@ -2407,9 +2514,28 @@ dependencies = [
|
||||||
"windows-sys 0.45.0",
|
"windows-sys 0.45.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "text-generation-benchmark"
|
||||||
|
version = "0.7.0-dev"
|
||||||
|
dependencies = [
|
||||||
|
"average",
|
||||||
|
"clap",
|
||||||
|
"crossterm",
|
||||||
|
"float-ord",
|
||||||
|
"ratatui",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"text-generation-client",
|
||||||
|
"thiserror",
|
||||||
|
"tokenizers",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
"tracing-subscriber",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-client"
|
name = "text-generation-client"
|
||||||
version = "0.6.0"
|
version = "0.7.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures",
|
"futures",
|
||||||
"grpc-metadata",
|
"grpc-metadata",
|
||||||
|
@ -2421,12 +2547,11 @@ dependencies = [
|
||||||
"tonic-build",
|
"tonic-build",
|
||||||
"tower",
|
"tower",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-error",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-launcher"
|
name = "text-generation-launcher"
|
||||||
version = "0.6.0"
|
version = "0.7.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap",
|
"clap",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
|
@ -2442,7 +2567,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "0.6.0"
|
version = "0.7.0-dev"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"axum",
|
"axum",
|
||||||
|
@ -2803,16 +2928,6 @@ dependencies = [
|
||||||
"valuable",
|
"valuable",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tracing-error"
|
|
||||||
version = "0.2.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e"
|
|
||||||
dependencies = [
|
|
||||||
"tracing",
|
|
||||||
"tracing-subscriber",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-futures"
|
name = "tracing-futures"
|
||||||
version = "0.2.5"
|
version = "0.2.5"
|
||||||
|
|
10
Cargo.toml
10
Cargo.toml
|
@ -1,13 +1,17 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
|
"benchmark",
|
||||||
"router",
|
"router",
|
||||||
"router/client",
|
"router/client",
|
||||||
"router/grpc-metadata",
|
"router/grpc-metadata",
|
||||||
"launcher"
|
"launcher"
|
||||||
]
|
]
|
||||||
exclude = [
|
|
||||||
"benchmark"
|
[workspace.package]
|
||||||
]
|
version = "0.7.0-dev"
|
||||||
|
edition = "2021"
|
||||||
|
authors = ["Olivier Dehaene"]
|
||||||
|
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = 1
|
debug = 1
|
||||||
|
|
|
@ -164,6 +164,8 @@ RUN cd server && \
|
||||||
pip install -r requirements.txt && \
|
pip install -r requirements.txt && \
|
||||||
pip install ".[bnb, accelerate]" --no-cache-dir
|
pip install ".[bnb, accelerate]" --no-cache-dir
|
||||||
|
|
||||||
|
# Install benchmarker
|
||||||
|
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
# Install router
|
# Install router
|
||||||
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
|
||||||
# Install launcher
|
# Install launcher
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
target
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,15 +1,10 @@
|
||||||
[package]
|
[package]
|
||||||
name = "text-generation-benchmark"
|
name = "text-generation-benchmark"
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
authors = ["Olivier Dehaene"]
|
|
||||||
description = "Text Generation Benchmarking tool"
|
description = "Text Generation Benchmarking tool"
|
||||||
|
version.workspace = true
|
||||||
[profile.release]
|
edition.workspace = true
|
||||||
debug = 1
|
authors.workspace = true
|
||||||
incremental = true
|
homepage.workspace = true
|
||||||
lto = "off"
|
|
||||||
panic = "abort"
|
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/lib.rs"
|
path = "src/lib.rs"
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
[toolchain]
|
|
||||||
channel = "1.67.0"
|
|
||||||
components = ["rustfmt", "clippy"]
|
|
|
@ -10,7 +10,7 @@
|
||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
},
|
},
|
||||||
"version": "0.6.0"
|
"version": "0.7.0-dev"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/": {
|
"/": {
|
||||||
|
@ -33,7 +33,19 @@
|
||||||
},
|
},
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "See /generate or /generate_stream"
|
"description": "Generated Text",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/GenerateResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"text/event-stream": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/StreamResponse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"422": {
|
"422": {
|
||||||
"description": "Input validation error",
|
"description": "Input validation error",
|
||||||
|
@ -584,11 +596,73 @@
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
"model_id",
|
"model_id",
|
||||||
|
"model_dtype",
|
||||||
|
"model_device_type",
|
||||||
|
"max_concurrent_requests",
|
||||||
|
"max_best_of",
|
||||||
|
"max_stop_sequences",
|
||||||
|
"max_input_length",
|
||||||
|
"max_total_tokens",
|
||||||
|
"waiting_served_ratio",
|
||||||
|
"max_batch_total_tokens",
|
||||||
|
"max_waiting_tokens",
|
||||||
|
"validation_workers",
|
||||||
"version"
|
"version"
|
||||||
],
|
],
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"docker_label": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "null",
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
|
"max_batch_total_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"format": "int32",
|
||||||
|
"example": "32000",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_best_of": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_concurrent_requests": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Router Parameters",
|
||||||
|
"example": "128",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_input_length": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "1024",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_stop_sequences": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "4",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_total_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2048",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"max_waiting_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "20",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
|
"model_device_type": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "cuda"
|
||||||
|
},
|
||||||
|
"model_dtype": {
|
||||||
|
"type": "string",
|
||||||
|
"example": "torch.float16"
|
||||||
|
},
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
"description": "Model info",
|
||||||
"example": "bigscience/blomm-560m"
|
"example": "bigscience/blomm-560m"
|
||||||
},
|
},
|
||||||
"model_pipeline_tag": {
|
"model_pipeline_tag": {
|
||||||
|
@ -606,9 +680,20 @@
|
||||||
"example": "null",
|
"example": "null",
|
||||||
"nullable": true
|
"nullable": true
|
||||||
},
|
},
|
||||||
|
"validation_workers": {
|
||||||
|
"type": "integer",
|
||||||
|
"example": "2",
|
||||||
|
"minimum": 0.0
|
||||||
|
},
|
||||||
"version": {
|
"version": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
"description": "Router Info",
|
||||||
"example": "0.5.0"
|
"example": "0.5.0"
|
||||||
|
},
|
||||||
|
"waiting_served_ratio": {
|
||||||
|
"type": "number",
|
||||||
|
"format": "float",
|
||||||
|
"example": "1.2"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
[package]
|
[package]
|
||||||
name = "text-generation-launcher"
|
name = "text-generation-launcher"
|
||||||
version = "0.6.0"
|
|
||||||
edition = "2021"
|
|
||||||
authors = ["Olivier Dehaene"]
|
|
||||||
description = "Text Generation Launcher"
|
description = "Text Generation Launcher"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
[package]
|
[package]
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "0.6.0"
|
|
||||||
edition = "2021"
|
|
||||||
authors = ["Olivier Dehaene"]
|
|
||||||
description = "Text Generation Webserver"
|
description = "Text Generation Webserver"
|
||||||
build = "build.rs"
|
build = "build.rs"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/lib.rs"
|
path = "src/lib.rs"
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
[package]
|
[package]
|
||||||
name = "text-generation-client"
|
name = "text-generation-client"
|
||||||
version = "0.6.0"
|
version.workspace = true
|
||||||
edition = "2021"
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "^0.3"
|
futures = "^0.3"
|
||||||
|
@ -12,7 +14,6 @@ tokio = { version = "^1.25", features = ["sync"] }
|
||||||
tonic = "^0.8"
|
tonic = "^0.8"
|
||||||
tower = "^0.4"
|
tower = "^0.4"
|
||||||
tracing = "^0.1"
|
tracing = "^0.1"
|
||||||
tracing-error = "^0.2"
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tonic-build = "0.8.4"
|
tonic-build = "0.8.4"
|
||||||
|
|
|
@ -9,6 +9,7 @@ use opentelemetry::{global, KeyValue};
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
use opentelemetry_otlp::WithExportConfig;
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::time::Duration;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use text_generation_router::{server, HubModelInfo};
|
use text_generation_router::{server, HubModelInfo};
|
||||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||||
|
@ -125,7 +126,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
.block_on(async {
|
.block_on(async {
|
||||||
init_logging(otlp_endpoint, json_output);
|
init_logging(otlp_endpoint, json_output);
|
||||||
|
|
||||||
if let Some(max_batch_size) = max_batch_size{
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
|
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
|
||||||
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
|
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
|
||||||
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
|
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
|
||||||
|
@ -145,7 +146,10 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
sha: None,
|
sha: None,
|
||||||
pipeline_tag: None,
|
pipeline_tag: None,
|
||||||
},
|
},
|
||||||
false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
|
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
|
||||||
|
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||||
|
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||||
|
@ -256,22 +260,24 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
/// get model info from the Huggingface Hub
|
||||||
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
|
pub async fn get_model_info(
|
||||||
|
model_id: &str,
|
||||||
|
revision: &str,
|
||||||
|
token: Option<String>,
|
||||||
|
) -> Option<HubModelInfo> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
// Poor man's urlencode
|
// Poor man's urlencode
|
||||||
let revision = revision.replace("/", "%2F");
|
let revision = revision.replace('/', "%2F");
|
||||||
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
|
||||||
let mut builder = client.get(url);
|
let mut builder = client.get(url).timeout(Duration::from_secs(5));
|
||||||
if let Some(token) = token {
|
if let Some(token) = token {
|
||||||
builder = builder.bearer_auth(token);
|
builder = builder.bearer_auth(token);
|
||||||
}
|
}
|
||||||
|
|
||||||
let model_info = builder
|
let response = builder.send().await.ok()?;
|
||||||
.send()
|
|
||||||
.await
|
if response.status().is_success() {
|
||||||
.expect("Could not connect to hf.co")
|
return serde_json::from_str(&response.text().await.ok()?).ok();
|
||||||
.text()
|
}
|
||||||
.await
|
None
|
||||||
.expect("error when retrieving model info from hf.co");
|
|
||||||
serde_json::from_str(&model_info).expect("unable to parse model info")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ use utoipa_swagger_ui::SwaggerUi;
|
||||||
path = "/",
|
path = "/",
|
||||||
request_body = CompatGenerateRequest,
|
request_body = CompatGenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "See /generate or /generate_stream",
|
(status = 200, description = "Generated Text",
|
||||||
content(
|
content(
|
||||||
("application/json" = GenerateResponse),
|
("application/json" = GenerateResponse),
|
||||||
("text/event-stream" = StreamResponse),
|
("text/event-stream" = StreamResponse),
|
||||||
|
|
Loading…
Reference in New Issue