diff --git a/.dockerignore b/.dockerignore
index d2704bf1..38e8f824 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1 +1,2 @@
+aml
 router/target
\ No newline at end of file
diff --git a/aml/README.md b/aml/README.md
new file mode 100644
index 00000000..c38f9fef
--- /dev/null
+++ b/aml/README.md
@@ -0,0 +1,7 @@
+```shell
+docker build . -t db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+docker push db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+
+az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
+az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
+```
\ No newline at end of file
diff --git a/aml/deployment.yaml b/aml/deployment.yaml
new file mode 100644
index 00000000..f6b55faa
--- /dev/null
+++ b/aml/deployment.yaml
@@ -0,0 +1,39 @@
+$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
+name: bloom-deployment
+endpoint_name: bloom-inference
+model:
+  name: bloom
+  path: ./bloom
+model_mount_path: /var/azureml-model
+environment_variables:
+  MODEL_BASE_PATH: /var/azureml-model/bloom
+  MODEL_NAME: bigscience/bloom
+  NUM_GPUS: 8
+environment:
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+  inference_config:
+    liveness_route:
+      port: 3000
+      path: /health
+    readiness_route:
+      port: 3000
+      path: /health
+    scoring_route:
+      port: 3000
+      path: /generate
+instance_type: Standard_ND96amsr_A100_v4
+request_settings:
+  request_timeout_ms: 90000
+liveness_probe:
+  initial_delay: 300
+  timeout: 20
+  period: 60
+  success_threshold: 1
+  failure_threshold: 60
+readiness_probe:
+  initial_delay: 300
+  timeout: 20
+  period: 60
+  success_threshold: 1
+  failure_threshold: 60
+instance_count: 1
diff --git a/aml/endpoint.yaml b/aml/endpoint.yaml
new file mode 100644
index 00000000..934b31ad
--- /dev/null
+++ b/aml/endpoint.yaml
@@ -0,0 +1,3 @@
+$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
+name: bloom-inference
+auth_mode: aml_token
diff --git a/router/src/server.rs b/router/src/server.rs
index 1cbec333..f113ce44 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,7 +1,8 @@
-use crate::{Batcher, ShardedClient, Validation};
+use bloom_inference_client::ShardedClient;
+use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
-use axum::routing::post;
+use axum::routing::{get, post};
 use axum::{Json, Router};
 use serde::Deserialize;
 use std::net::SocketAddr;
@@ -142,7 +143,7 @@ pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr)
     let app = Router::new()
         .route("/generate", post(generate))
         .layer(Extension(shared_state.clone()))
-        .route("/health", post(liveness))
+        .route("/health", get(liveness))
         .layer(Extension(shared_state.clone()));

     axum::Server::bind(&addr)
diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py
index 40d69e8b..8b0e7ab0 100644
--- a/server/bloom_inference/model.py
+++ b/server/bloom_inference/model.py
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights

 from bloom_inference.pb import generate_pb2
-from bloom_inference.shard_model import shard_model, match_suffix
+from bloom_inference.prepare_weights import prepare_weights, match_suffix
 from bloom_inference.utils import (
     StoppingCriteria,
     NextTokenChooser,
@@ -377,8 +377,8 @@ class BLOOMSharded(BLOOM):
         # shard state_dict
         if self.master:
             # TODO @thomasw21 do some caching
-            shard_state_dict_paths = shard_model(
-                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
+            shard_state_dict_paths = prepare_weights(
+                model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
             )
             shard_state_dict_paths = [
                 str(path.absolute()) for path in shard_state_dict_paths
diff --git a/server/bloom_inference/prepare_weights.py b/server/bloom_inference/prepare_weights.py
index 5fa2be51..7cf3dbb5 100644
--- a/server/bloom_inference/prepare_weights.py
+++ b/server/bloom_inference/prepare_weights.py
@@ -62,7 +62,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
     if all(save_path.exists() for save_path in save_paths):
         print("Weights are already prepared")
-        return
+        return save_paths

     cache_path.mkdir(parents=True, exist_ok=True)

     if model_name == "bigscience/bloom-560m":
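For a quick check of the routes this diff wires up, here is a minimal smoke-test sketch. It assumes the router listens on port 3000 (the port used by the `inference_config` routes in `deployment.yaml`) and that `/generate` accepts a JSON body with `inputs` and `parameters` fields; the payload shape is an assumption, not something this diff shows.

```shell
# Sketch of a local smoke test; port and JSON body shape are assumptions.

# /health is now a GET route (changed from POST in router/src/server.rs)
curl -i http://localhost:3000/health

# /generate remains a POST route
curl -s http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello", "parameters": {"max_new_tokens": 20}}'
```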
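Once the endpoint and deployment from `aml/` are live, the same request can go through Azure ML instead of hitting the container directly. A sketch using the Azure ML CLI v2 `invoke` subcommand, reusing the resource group and workspace names from the README and the assumed request body above:

```shell
# Hypothetical scoring call via Azure ML CLI v2; the endpoint, resource group,
# and workspace names come from the files added in this diff.
echo '{"inputs": "Hello", "parameters": {"max_new_tokens": 20}}' > request.json
az ml online-endpoint invoke -n bloom-inference -r request.json \
    -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```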