feat: Add AML deployment
commit bcb53903b8 (parent bf99afe916)
@@ -1 +1,2 @@
+aml
 router/target
@@ -0,0 +1,7 @@
+```shell
+docker build . -t db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+docker push db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+
+az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
+az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
+```
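The image is built and pushed to the Azure Container Registry referenced as `environment.image` in `deployment.yaml` below, and the endpoint is created before the deployment that targets it. In the `az ml` commands, `-f` points at the YAML spec, `-g` is the resource group, and `-w` is the Azure ML workspace.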
@@ -0,0 +1,39 @@
+$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
+name: bloom-deployment
+endpoint_name: bloom-inference
+model:
+  name: bloom
+  path: ./bloom
+model_mount_path: /var/azureml-model
+environment_variables:
+  MODEL_BASE_PATH: /var/azureml-model/bloom
+  MODEL_NAME: bigscience/bloom
+  NUM_GPUS: 8
+environment:
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+  inference_config:
+    liveness_route:
+      port: 3000
+      path: /health
+    readiness_route:
+      port: 3000
+      path: /health
+    scoring_route:
+      port: 3000
+      path: /generate
+instance_type: Standard_ND96amsr_A100_v4
+request_settings:
+  request_timeout_ms: 90000
+liveness_probe:
+  initial_delay: 300
+  timeout: 20
+  period: 60
+  success_threshold: 1
+  failure_threshold: 60
+readiness_probe:
+  initial_delay: 300
+  timeout: 20
+  period: 60
+  success_threshold: 1
+  failure_threshold: 60
+instance_count: 1
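Note the probe settings: with `initial_delay: 300`, `period: 60`, and `failure_threshold: 60`, a container gets roughly 300 s + 60 × 60 s ≈ 65 minutes to become healthy before the deployment gives up, presumably to leave time for loading the BLOOM weights across the 8 GPUs (`NUM_GPUS: 8`) of the `Standard_ND96amsr_A100_v4` instance. The routes under `inference_config` (port 3000, `/health` and `/generate`) correspond to the routes served by the Rust router changed further down.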
@@ -0,0 +1,3 @@
+$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
+name: bloom-inference
+auth_mode: aml_token
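Because the endpoint uses `auth_mode: aml_token`, requests to the scoring route must carry an Azure ML token. Below is a minimal sketch of calling the deployed `/generate` route with `requests`; the scoring URI, the token retrieval (for example via `az ml online-endpoint get-credentials`), and the request body shape are assumptions for illustration, not defined by this diff.

```python
import requests

# Placeholder values: the real scoring URI comes from the endpoint created above,
# and the token from the Azure ML CLI or SDK (auth_mode: aml_token).
SCORING_URI = "https://bloom-inference.<region>.inference.ml.azure.com/generate"
AML_TOKEN = "<token>"

# Assumed body shape for the router's /generate route; not defined by this diff.
payload = {"inputs": "Hello, I am", "parameters": {"max_new_tokens": 20}}

response = requests.post(
    SCORING_URI,
    headers={"Authorization": f"Bearer {AML_TOKEN}"},
    json=payload,
    timeout=90,  # aligned with request_timeout_ms: 90000 in deployment.yaml
)
response.raise_for_status()
print(response.json())
```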
@@ -1,7 +1,8 @@
-use crate::{Batcher, ShardedClient, Validation};
+use bloom_inference_client::ShardedClient;
+use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
-use axum::routing::post;
+use axum::routing::{get, post};
 use axum::{Json, Router};
 use serde::Deserialize;
 use std::net::SocketAddr;
@@ -142,7 +143,7 @@ pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr)
     let app = Router::new()
         .route("/generate", post(generate))
         .layer(Extension(shared_state.clone()))
-        .route("/health", post(liveness))
+        .route("/health", get(liveness))
         .layer(Extension(shared_state.clone()));

     axum::Server::bind(&addr)
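`/health` switches from `post(liveness)` to `get(liveness)`, presumably because the liveness and readiness probes configured in `deployment.yaml` issue plain HTTP GET requests against that route; `/generate` remains a POST route.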
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights

 from bloom_inference.pb import generate_pb2
-from bloom_inference.shard_model import shard_model, match_suffix
+from bloom_inference.prepare_weights import prepare_weights, match_suffix
 from bloom_inference.utils import (
     StoppingCriteria,
     NextTokenChooser,
@@ -377,8 +377,8 @@ class BLOOMSharded(BLOOM):
         # shard state_dict
         if self.master:
             # TODO @thomasw21 do some caching
-            shard_state_dict_paths = shard_model(
-                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
+            shard_state_dict_paths = prepare_weights(
+                model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
             )
             shard_state_dict_paths = [
                 str(path.absolute()) for path in shard_state_dict_paths
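For reference, a standalone call to the renamed helper might look like the sketch below. Only the argument order and the `tp_world_size` keyword are taken from the call above; the model choice and paths are illustrative.

```python
from pathlib import Path

from bloom_inference.prepare_weights import prepare_weights

# Illustrative values; only the call shape (model name, cache path, save path,
# tp_world_size) mirrors the BLOOMSharded call in the diff above.
shard_directory = Path("/data/bloom-shards")

save_paths = prepare_weights(
    "bigscience/bloom-560m",    # the small BLOOM variant handled in prepare_weights.py
    shard_directory / "cache",  # where downloaded checkpoints are cached
    shard_directory,            # where the tensor-parallel shards are written
    tp_world_size=2,            # one shard per tensor-parallel rank
)
print([str(p.absolute()) for p in save_paths])
```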
@@ -62,7 +62,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world

     if all(save_path.exists() for save_path in save_paths):
         print("Weights are already prepared")
-        return
+        return save_paths

     cache_path.mkdir(parents=True, exist_ok=True)
     if model_name == "bigscience/bloom-560m":
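Returning `save_paths` from this early-exit branch matters for the new caller: `BLOOMSharded` immediately iterates over the returned paths, so a bare `return` (i.e. `None`) would break whenever the shards are already present on disk.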