feat: Docker image

This commit is contained in:
Olivier Dehaene 2022-10-14 15:56:21 +02:00
parent 39df4d9975
commit bf99afe916
13 changed files with 265 additions and 180 deletions

1
.dockerignore Normal file
View File

@ -0,0 +1 @@
router/target

59
Dockerfile Normal file
View File

@ -0,0 +1,59 @@
FROM rust:1.64 as builder
WORKDIR /usr/src
COPY proto proto
COPY router router
WORKDIR /usr/src/router
RUN cargo install --path .
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
MODEL_BASE_PATH=/var/azureml-model \
MODEL_NAME=bigscience/bloom \
NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y
# Install specific version of torch
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
# Install specific version of transformers
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
/opt/miniconda/envs/text-generation/bin/python setup.py install
WORKDIR /usr/src
# Install server
COPY server server
RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
COPY run.sh .
RUN chmod +x run.sh
CMD ["./run.sh"]

View File

@ -43,8 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
## TODO:
- [ ] Improve model download
- Store "shardable" layers separately and layer by layer
- [ ] Add batching args to router CLI
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
- [ ] Add tests

View File

@ -0,0 +1,3 @@
[toolchain]
channel = "1.64.0"
components = ["rustfmt", "clippy"]

View File

@ -83,7 +83,12 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
cached_batch = match batch_size {
size if size > 16 => {
wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await
wrap_future(
client.generate_until_finished_with_cache(batches),
request_ids,
&db,
)
.await
}
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
};

View File

@ -1,5 +1,5 @@
use std::net::SocketAddr;
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::time::Duration;
use tokenizers::Tokenizer;
@ -37,7 +37,7 @@ fn main() -> Result<(), std::io::Error> {
.expect("Unable to clear cache");
tracing::info!("Connected");
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
server::run(sharded_client, tokenizer, addr).await;
Ok(())

View File

@ -1,10 +1,10 @@
use std::net::SocketAddr;
use axum::{Router, Json};
use axum::http::StatusCode;
use axum::extract::Extension;
use axum::routing::post;
use crate::{Batcher, ShardedClient, Validation};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::post;
use axum::{Json, Router};
use serde::Deserialize;
use std::net::SocketAddr;
use tokenizers::Tokenizer;
use tokio::time::Instant;
use tracing::instrument;
@ -60,6 +60,31 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
let output = state
.infer
.infer(
1,
GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
do_sample: false,
max_new_tokens: 1,
},
},
)
.await;
match output {
Ok(_) => Ok(()),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
state: Extension<ServerState>,
@ -67,14 +92,16 @@ async fn generate(
) -> Result<Json<serde_json::Value>, StatusCode> {
let start = Instant::now();
let (input_length, validated_request) = match state.validation
let (input_length, validated_request) = match state
.validation
.validate(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await {
.await
{
Ok(result) => result,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let output = state.infer.infer(input_length, validated_request).await;
@ -102,11 +129,7 @@ struct ServerState {
infer: Batcher,
}
pub async fn run(
client: ShardedClient,
tokenizer: Tokenizer,
addr: SocketAddr,
) {
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
client.clear_cache().await.expect("Unable to clear cache");
tracing::info!("Connected");
@ -114,13 +137,16 @@ pub async fn run(
let validation = Validation::new(tokenizer);
let shared_state = ServerState {
validation,
infer,
};
let shared_state = ServerState { validation, infer };
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
let app = Router::new()
.route("/generate", post(generate))
.layer(Extension(shared_state.clone()))
.route("/health", post(liveness))
.layer(Extension(shared_state.clone()));
axum::Server::bind(&addr)
.serve(app.into_make_service()).await.unwrap();
.serve(app.into_make_service())
.await
.unwrap();
}

21
run.sh Executable file
View File

@ -0,0 +1,21 @@
#!/usr/bin/env bash
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
$server_cmd &
FILE=/tmp/bloom-inference-0
while :
do
if test -S "$FILE"; then
echo "Text Generation Python gRPC server started"
break
else
echo "Waiting for Text Generation Python gRPC server to start"
sleep 5
fi
done
sleep 1
exec "bloom-inference"

View File

@ -220,12 +220,14 @@ class BLOOM:
def __init__(self, model_name: str):
if torch.cuda.is_available():
self.device = torch.device("cuda")
dtype = torch.bfloat16
else:
self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = (
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
)
self.num_heads = self.model.base_model.num_heads
@ -427,7 +429,8 @@ class BLOOMSharded(BLOOM):
if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
model.tie_weights()
self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group)

View File

@ -1,18 +1,62 @@
import torch
import os
import tempfile
import json
from typing import BinaryIO
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path
from tqdm import tqdm
MODEL_NAME = "bigscience/bloom"
from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
return text[-len(suffix):] == suffix
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
def http_get(
url: str,
temp_file: BinaryIO,
*,
timeout=10.0,
max_retries=0,
):
"""
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
"""
r = _request_wrapper(
method="GET",
url=url,
stream=True,
timeout=timeout,
max_retries=max_retries,
)
hf_raise_for_status(r)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
def cache_download_url(url: str, root_dir: Path):
filename = root_dir / url.split("/")[-1]
if not filename.exists():
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
)
with temp_file_manager() as temp_file:
http_get(url, temp_file)
os.replace(temp_file.name, filename)
return filename
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
save_paths = [
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
@ -20,45 +64,67 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
print("Weights are already prepared")
return
cache_path.mkdir(parents=True, exist_ok=True)
if model_name == "bigscience/bloom-560m":
url = hf_hub_url(model_name, filename="pytorch_model.bin")
cache_download_url(url, cache_path)
elif model_name == "bigscience/bloom":
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
index_path = cache_download_url(url, cache_path)
with index_path.open("r") as f:
index = json.load(f)
# Get unique file names
weight_files = list(set([filename for filename in index["weight_map"].values()]))
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
else:
raise ValueError(f"Unknown model name: {model_name}")
shards_state_dicts = [{} for _ in range(tp_world_size)]
for weight_path in tqdm(hub_path.glob("*.bin")):
for weight_path in tqdm(Path(cache_path).glob("*.bin")):
state_dict = torch.load(weight_path, map_location="cpu")
keys = list(state_dict.keys())
for state_name in keys:
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
"lm_head.weight",
]
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
elif match_suffix(state_name, "lm_head.weight"):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
@ -66,40 +132,31 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][
"transformer." + state_name
] = state.detach().clone()
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = torch.zeros_like(state)
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = state.detach().clone()
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
del state_dict
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
zip(save_paths, shards_state_dicts)
):
save_paths.append(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
@ -116,9 +173,10 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hub-path", required=True, type=str)
parser.add_argument("--model-name", required=True, type=str)
parser.add_argument("--cache-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args()
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)

View File

@ -1,102 +0,0 @@
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModelForCausalLM
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
"""BLOOM specific sharding mechanism"""
save_paths = [
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Loading already cached values")
return save_paths
model: nn.Module = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, local_files_only=True
)
shards_state_dicts = [{} for _ in range(tp_world_size)]
state_dict = model.state_dict()
keys = list(state_dict.keys())
for state_name in keys:
print(state_name)
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"transformer.word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][state_name] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(shard_state_dict, save_path)
save_paths.append(save_path)
return save_paths
if __name__ == "__main__":
model_name = "bigscience/bloom"
save_path = Path("/data/shards")
tp_world_size = 8
dtype = torch.bfloat16
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)

14
server/poetry.lock generated
View File

@ -80,6 +80,14 @@ grpcio = ">=1.49.1"
protobuf = ">=4.21.3,<5.0dev"
setuptools = "*"
[[package]]
name = "joblib"
version = "1.2.0"
description = "Lightweight pipelining with Python functions"
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "numpy"
version = "1.23.3"
@ -197,7 +205,7 @@ python-versions = ">=3.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
[metadata.files]
accelerate = [
@ -310,6 +318,10 @@ grpcio-tools = [
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
]
joblib = [
{file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
{file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
]
numpy = [
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},

View File

@ -12,6 +12,7 @@ torch = "^1.12.1"
typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
joblib = "^1.2.0"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1"