feat: Add AML deployment

Olivier Dehaene 2022-10-15 20:21:50 +02:00
parent bf99afe916
commit bcb53903b8
7 changed files with 58 additions and 7 deletions


@@ -1 +1,2 @@
+aml
 router/target

aml/README.md (new file, +7)

@@ -0,0 +1,7 @@
```shell
docker build . -t db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
docker push db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```
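These four commands build and push the serving image, then create the AML endpoint and deployment from the two YAML files below. As a hedged follow-up that is not part of this commit: once the deployment is up, one way to sanity-check it is to invoke the endpoint through the same CLI (the request file and its JSON shape are assumptions, matching whatever the /generate scoring route expects):

```shell
# sample_request.json is a hypothetical payload for the /generate scoring route
az ml online-endpoint invoke \
  -n bloom-inference \
  --deployment-name bloom-deployment \
  -g HuggingFace-BLOOM-ModelPage -w HuggingFace \
  --request-file sample_request.json
```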

aml/deployment.yaml (new file, +39)

@@ -0,0 +1,39 @@
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
name: bloom-deployment
endpoint_name: bloom-inference
model:
  name: bloom
  path: ./bloom
model_mount_path: /var/azureml-model
environment_variables:
  MODEL_BASE_PATH: /var/azureml-model/bloom
  MODEL_NAME: bigscience/bloom
  NUM_GPUS: 8
environment:
  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
  inference_config:
    liveness_route:
      port: 3000
      path: /health
    readiness_route:
      port: 3000
      path: /health
    scoring_route:
      port: 3000
      path: /generate
instance_type: Standard_ND96amsr_A100_v4
request_settings:
  request_timeout_ms: 90000
liveness_probe:
  initial_delay: 300
  timeout: 20
  period: 60
  success_threshold: 1
  failure_threshold: 60
readiness_probe:
  initial_delay: 300
  timeout: 20
  period: 60
  success_threshold: 1
  failure_threshold: 60
instance_count: 1
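The probe settings are deliberately generous (300 s initial delay, up to 60 failures at a 60 s period) because loading the sharded BLOOM weights onto 8 GPUs takes far longer than the default grace period. A minimal sketch, not part of this commit and assuming the resource names above, of watching the container come up while the probes are still failing:

```shell
# Tail the deployment's container logs during startup
az ml online-deployment get-logs \
  -n bloom-deployment -e bloom-inference \
  -g HuggingFace-BLOOM-ModelPage -w HuggingFace \
  --lines 100
```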

aml/endpoint.yaml (new file, +3)

@@ -0,0 +1,3 @@
$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
name: bloom-inference
auth_mode: aml_token
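With auth_mode: aml_token, callers must present an Azure ML token rather than a static key. A sketch of fetching the token and scoring URI via the CLI and calling the endpoint directly; assumptions here are that get-credentials exposes an accessToken field in this auth mode, and the request body is a placeholder rather than anything defined in this commit:

```shell
# Fetch an AML token and the scoring URI, then call the endpoint directly
TOKEN=$(az ml online-endpoint get-credentials -n bloom-inference \
  -g HuggingFace-BLOOM-ModelPage -w HuggingFace --query accessToken -o tsv)
SCORING_URI=$(az ml online-endpoint show -n bloom-inference \
  -g HuggingFace-BLOOM-ModelPage -w HuggingFace --query scoring_uri -o tsv)
# Placeholder JSON body, not taken from this commit
curl -s -X POST "$SCORING_URI" \
  -H "Authorization: Bearer $TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"inputs": "Hello"}'
```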


@@ -1,7 +1,8 @@
-use crate::{Batcher, ShardedClient, Validation};
+use bloom_inference_client::ShardedClient;
+use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
-use axum::routing::post;
+use axum::routing::{get, post};
 use axum::{Json, Router};
 use serde::Deserialize;
 use std::net::SocketAddr;
@@ -142,7 +143,7 @@ pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr)
     let app = Router::new()
         .route("/generate", post(generate))
         .layer(Extension(shared_state.clone()))
-        .route("/health", post(liveness))
+        .route("/health", get(liveness))
         .layer(Extension(shared_state.clone()));

     axum::Server::bind(&addr)
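Switching /health from post to get matters for the AML deployment above: the liveness_route and readiness_route probes issue plain GET requests, so the route has to answer GET. A quick local check, assuming the router is listening on port 3000:

```shell
# The probes expect a 200 from a GET on /health
curl -i http://localhost:3000/health
```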


@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights
 from bloom_inference.pb import generate_pb2
-from bloom_inference.shard_model import shard_model, match_suffix
+from bloom_inference.prepare_weights import prepare_weights, match_suffix
 from bloom_inference.utils import (
     StoppingCriteria,
     NextTokenChooser,
@@ -377,8 +377,8 @@ class BLOOMSharded(BLOOM):
         # shard state_dict
         if self.master:
             # TODO @thomasw21 do some caching
-            shard_state_dict_paths = shard_model(
-                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
+            shard_state_dict_paths = prepare_weights(
+                model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
             )
             shard_state_dict_paths = [
                 str(path.absolute()) for path in shard_state_dict_paths


@@ -62,7 +62,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
     if all(save_path.exists() for save_path in save_paths):
         print("Weights are already prepared")
-        return
+        return save_paths

     cache_path.mkdir(parents=True, exist_ok=True)
     if model_name == "bigscience/bloom-560m":