This commit is contained in:
OlivierDehaene 2023-07-01 19:25:41 +02:00 committed by GitHub
parent 2b53d71991
commit e28a809004
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 376 additions and 258 deletions

505
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,7 @@ members = [
]
[workspace.package]
version = "0.8.2"
version = "0.9.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -1,5 +1,5 @@
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

View File

@ -84,7 +84,7 @@ model=bigscience/bloom-560m
num_shard=2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.8 --model-id $model --num-shard $num_shard
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:0.9 --model-id $model --num-shard $num_shard
```
**Note:** To use GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 11.8 or higher.

View File

@ -1,15 +0,0 @@
# Azure ML endpoint
## Create all resources
```shell
az ml model create -f model.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```
## Update deployment
```shell
az ml online-deployment update -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```

View File

@ -1,38 +0,0 @@
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
name: bloom-deployment
endpoint_name: bloom-inference
model: azureml:bloom-safetensors:1
model_mount_path: /var/azureml-model
environment_variables:
WEIGHTS_CACHE_OVERRIDE: /var/azureml-model/bloom-safetensors
MODEL_ID: bigscience/bloom
NUM_SHARD: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2.0
inference_config:
liveness_route:
port: 80
path: /health
readiness_route:
port: 80
path: /health
scoring_route:
port: 80
path: /generate
instance_type: Standard_ND96amsr_A100_v4
request_settings:
request_timeout_ms: 90000
max_concurrent_requests_per_instance: 256
liveness_probe:
initial_delay: 600
timeout: 90
period: 120
success_threshold: 1
failure_threshold: 5
readiness_probe:
initial_delay: 600
timeout: 90
period: 120
success_threshold: 1
failure_threshold: 5
instance_count: 1

View File

@ -1,3 +0,0 @@
$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
name: bloom-inference
auth_mode: key

View File

@ -1,3 +0,0 @@
$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
name: bloom-safetensors
path: /data/bloom-safetensors

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.8.2"
"version": "0.9.0"
},
"paths": {
"/": {
@ -270,6 +270,35 @@
}
}
},
"/health": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Health check method",
"description": "Health check method",
"operationId": "health",
"responses": {
"200": {
"description": "Everything is working fine"
},
"503": {
"description": "Text generation inference is down",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "unhealthy",
"error_type": "healthcheck"
}
}
}
}
}
}
},
"/info": {
"get": {
"tags": [

View File

@ -1040,14 +1040,18 @@ fn main() -> Result<(), LauncherError> {
return Ok(());
}
let mut webserver = spawn_webserver(args, shutdown.clone(), &shutdown_receiver)?;
let mut webserver =
spawn_webserver(args, shutdown.clone(), &shutdown_receiver).map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
// Default exit code
let mut exit_code = Ok(());
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {rank} failed to start");
tracing::error!("Shard {rank} crashed");
if let Some(err) = err {
tracing::error!("{err}");
}

View File

@ -22,11 +22,11 @@ text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
metrics = "0.21.0"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
opentelemetry = { version = "0.19.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.12.0"
rand = "0.8.5"
reqwest = { version = "0.11.14", features = [] }
serde = "1.0.152"
@ -36,7 +36,7 @@ tokenizers = "0.13.3"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tower-http = { version = "0.4.0", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -11,10 +11,10 @@ grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11"
thiserror = "^1.0"
tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8"
tonic = "^0.9"
tower = "^0.4"
tracing = "^0.1"
[build-dependencies]
tonic-build = "0.8.4"
tonic-build = "0.9.2"
prost-build = "0.11.6"

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.8"
opentelemetry = "^0.19"
tonic = "^0.9"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"
tracing-opentelemetry = "^0.19"

View File

@ -532,6 +532,7 @@ pub async fn run(
#[derive(OpenApi)]
#[openapi(
paths(
health,
get_model_info,
compat_generate,
generate,

View File

@ -1,3 +1,3 @@
[toolchain]
channel = "1.69.0"
channel = "1.70.0"
components = ["rustfmt", "clippy"]

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "0.8.2"
version = "0.9.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]