feat: Add AML deployment
This commit is contained in:
parent
bf99afe916
commit
bcb53903b8
|
@ -1 +1,2 @@
|
|||
aml
|
||||
router/target
|
|
@ -0,0 +1,7 @@
|
|||
```shell
|
||||
docker build . -t db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
|
||||
docker push db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
|
||||
|
||||
az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
|
||||
az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
|
||||
```
|
|
@ -0,0 +1,39 @@
|
|||
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
|
||||
name: bloom-deployment
|
||||
endpoint_name: bloom-inference
|
||||
model:
|
||||
name: bloom
|
||||
path: ./bloom
|
||||
model_mount_path: /var/azureml-model
|
||||
environment_variables:
|
||||
MODEL_BASE_PATH: /var/azureml-model/bloom
|
||||
MODEL_NAME: bigscience/bloom
|
||||
NUM_GPUS: 8
|
||||
environment:
|
||||
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
|
||||
inference_config:
|
||||
liveness_route:
|
||||
port: 3000
|
||||
path: /health
|
||||
readiness_route:
|
||||
port: 3000
|
||||
path: /health
|
||||
scoring_route:
|
||||
port: 3000
|
||||
path: /generate
|
||||
instance_type: Standard_ND96amsr_A100_v4
|
||||
request_settings:
|
||||
request_timeout_ms: 90000
|
||||
liveness_probe:
|
||||
initial_delay: 300
|
||||
timeout: 20
|
||||
period: 60
|
||||
success_threshold: 1
|
||||
failure_threshold: 60
|
||||
readiness_probe:
|
||||
initial_delay: 300
|
||||
timeout: 20
|
||||
period: 60
|
||||
success_threshold: 1
|
||||
failure_threshold: 60
|
||||
instance_count: 1
|
|
@ -0,0 +1,3 @@
|
|||
$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
|
||||
name: bloom-inference
|
||||
auth_mode: aml_token
|
|
@ -1,7 +1,8 @@
|
|||
use crate::{Batcher, ShardedClient, Validation};
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use crate::{Batcher, Validation};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::StatusCode;
|
||||
use axum::routing::post;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -142,7 +143,7 @@ pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr)
|
|||
let app = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.layer(Extension(shared_state.clone()))
|
||||
.route("/health", post(liveness))
|
||||
.route("/health", get(liveness))
|
||||
.layer(Extension(shared_state.clone()));
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
|
|
|
@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
|||
from transformers.modeling_utils import no_init_weights
|
||||
|
||||
from bloom_inference.pb import generate_pb2
|
||||
from bloom_inference.shard_model import shard_model, match_suffix
|
||||
from bloom_inference.prepare_weights import prepare_weights, match_suffix
|
||||
from bloom_inference.utils import (
|
||||
StoppingCriteria,
|
||||
NextTokenChooser,
|
||||
|
@ -377,8 +377,8 @@ class BLOOMSharded(BLOOM):
|
|||
# shard state_dict
|
||||
if self.master:
|
||||
# TODO @thomasw21 do some caching
|
||||
shard_state_dict_paths = shard_model(
|
||||
model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
|
||||
shard_state_dict_paths = prepare_weights(
|
||||
model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
|
||||
)
|
||||
shard_state_dict_paths = [
|
||||
str(path.absolute()) for path in shard_state_dict_paths
|
||||
|
|
|
@ -62,7 +62,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
|
||||
if all(save_path.exists() for save_path in save_paths):
|
||||
print("Weights are already prepared")
|
||||
return
|
||||
return save_paths
|
||||
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
if model_name == "bigscience/bloom-560m":
|
||||
|
|
Loading…
Reference in New Issue