feat: Docker image
This commit is contained in:
parent
39df4d9975
commit
bf99afe916
|
@ -0,0 +1 @@
|
||||||
|
router/target
|
|
@ -0,0 +1,59 @@
|
||||||
|
FROM rust:1.64 as builder
|
||||||
|
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
COPY proto proto
|
||||||
|
COPY router router
|
||||||
|
|
||||||
|
WORKDIR /usr/src/router
|
||||||
|
|
||||||
|
RUN cargo install --path .
|
||||||
|
|
||||||
|
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||||
|
|
||||||
|
ENV LANG=C.UTF-8 \
|
||||||
|
LC_ALL=C.UTF-8 \
|
||||||
|
DEBIAN_FRONTEND=noninteractive \
|
||||||
|
MODEL_BASE_PATH=/var/azureml-model \
|
||||||
|
MODEL_NAME=bigscience/bloom \
|
||||||
|
NUM_GPUS=8 \
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||||
|
CUDA_HOME=/usr/local/cuda \
|
||||||
|
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
|
||||||
|
CONDA_DEFAULT_ENV=text-generation \
|
||||||
|
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
|
||||||
|
|
||||||
|
SHELL ["/bin/bash", "-c"]
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN cd ~ && \
|
||||||
|
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||||
|
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
|
||||||
|
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
|
||||||
|
conda create -n text-generation python=3.9 -y
|
||||||
|
|
||||||
|
# Install specific version of torch
|
||||||
|
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
|
||||||
|
|
||||||
|
# Install specific version of transformers
|
||||||
|
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||||
|
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||||
|
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||||
|
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
|
||||||
|
/opt/miniconda/envs/text-generation/bin/python setup.py install
|
||||||
|
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
# Install server
|
||||||
|
COPY server server
|
||||||
|
RUN cd server && \
|
||||||
|
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
|
||||||
|
|
||||||
|
# Install router
|
||||||
|
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
|
||||||
|
|
||||||
|
COPY run.sh .
|
||||||
|
RUN chmod +x run.sh
|
||||||
|
|
||||||
|
CMD ["./run.sh"]
|
|
@ -43,8 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
|
||||||
|
|
||||||
## TODO:
|
## TODO:
|
||||||
|
|
||||||
- [ ] Improve model download
|
|
||||||
- Store "shardable" layers separately and layer by layer
|
|
||||||
- [ ] Add batching args to router CLI
|
- [ ] Add batching args to router CLI
|
||||||
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
|
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
|
||||||
- [ ] Add tests
|
- [ ] Add tests
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
[toolchain]
|
||||||
|
channel = "1.64.0"
|
||||||
|
components = ["rustfmt", "clippy"]
|
|
@ -83,7 +83,12 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
|
||||||
|
|
||||||
cached_batch = match batch_size {
|
cached_batch = match batch_size {
|
||||||
size if size > 16 => {
|
size if size > 16 => {
|
||||||
wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await
|
wrap_future(
|
||||||
|
client.generate_until_finished_with_cache(batches),
|
||||||
|
request_ids,
|
||||||
|
&db,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
|
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use std::net::SocketAddr;
|
|
||||||
use bloom_inference_client::ShardedClient;
|
use bloom_inference_client::ShardedClient;
|
||||||
|
use std::net::SocketAddr;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
||||||
|
|
||||||
server::run(sharded_client, tokenizer, addr).await;
|
server::run(sharded_client, tokenizer, addr).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use std::net::SocketAddr;
|
|
||||||
use axum::{Router, Json};
|
|
||||||
use axum::http::StatusCode;
|
|
||||||
use axum::extract::Extension;
|
|
||||||
use axum::routing::post;
|
|
||||||
use crate::{Batcher, ShardedClient, Validation};
|
use crate::{Batcher, ShardedClient, Validation};
|
||||||
|
use axum::extract::Extension;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::routing::post;
|
||||||
|
use axum::{Json, Router};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use std::net::SocketAddr;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
@ -60,6 +60,31 @@ pub(crate) struct GenerateRequest {
|
||||||
pub parameters: GenerateParameters,
|
pub parameters: GenerateParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
|
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
|
||||||
|
let output = state
|
||||||
|
.infer
|
||||||
|
.infer(
|
||||||
|
1,
|
||||||
|
GenerateRequest {
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
max_new_tokens: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
state: Extension<ServerState>,
|
state: Extension<ServerState>,
|
||||||
|
@ -67,14 +92,16 @@ async fn generate(
|
||||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let (input_length, validated_request) = match state.validation
|
let (input_length, validated_request) = match state
|
||||||
|
.validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
inputs: req.inputs.clone(),
|
inputs: req.inputs.clone(),
|
||||||
parameters: req.parameters.clone(),
|
parameters: req.parameters.clone(),
|
||||||
})
|
})
|
||||||
.await {
|
.await
|
||||||
|
{
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
|
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||||
};
|
};
|
||||||
|
|
||||||
let output = state.infer.infer(input_length, validated_request).await;
|
let output = state.infer.infer(input_length, validated_request).await;
|
||||||
|
@ -102,11 +129,7 @@ struct ServerState {
|
||||||
infer: Batcher,
|
infer: Batcher,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run(
|
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
||||||
client: ShardedClient,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
addr: SocketAddr,
|
|
||||||
) {
|
|
||||||
client.clear_cache().await.expect("Unable to clear cache");
|
client.clear_cache().await.expect("Unable to clear cache");
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
|
@ -114,13 +137,16 @@ pub async fn run(
|
||||||
|
|
||||||
let validation = Validation::new(tokenizer);
|
let validation = Validation::new(tokenizer);
|
||||||
|
|
||||||
let shared_state = ServerState {
|
let shared_state = ServerState { validation, infer };
|
||||||
validation,
|
|
||||||
infer,
|
|
||||||
};
|
|
||||||
|
|
||||||
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
|
let app = Router::new()
|
||||||
|
.route("/generate", post(generate))
|
||||||
|
.layer(Extension(shared_state.clone()))
|
||||||
|
.route("/health", post(liveness))
|
||||||
|
.layer(Extension(shared_state.clone()));
|
||||||
|
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
.serve(app.into_make_service()).await.unwrap();
|
.serve(app.into_make_service())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
|
||||||
|
$server_cmd &
|
||||||
|
|
||||||
|
FILE=/tmp/bloom-inference-0
|
||||||
|
|
||||||
|
while :
|
||||||
|
do
|
||||||
|
if test -S "$FILE"; then
|
||||||
|
echo "Text Generation Python gRPC server started"
|
||||||
|
break
|
||||||
|
else
|
||||||
|
echo "Waiting for Text Generation Python gRPC server to start"
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
exec "bloom-inference"
|
|
@ -220,12 +220,14 @@ class BLOOM:
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.device = torch.device("cuda")
|
self.device = torch.device("cuda")
|
||||||
|
dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||||
self.model = (
|
self.model = (
|
||||||
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
|
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
|
||||||
)
|
)
|
||||||
self.num_heads = self.model.base_model.num_heads
|
self.num_heads = self.model.base_model.num_heads
|
||||||
|
|
||||||
|
@ -427,7 +429,8 @@ class BLOOMSharded(BLOOM):
|
||||||
if do_transpose:
|
if do_transpose:
|
||||||
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
|
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
|
||||||
|
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
model.tie_weights()
|
||||||
self.model = model.to(self.device).eval()
|
self.model = model.to(self.device).eval()
|
||||||
self.num_heads = config.n_head // self.process_group.size()
|
self.num_heads = config.n_head // self.process_group.size()
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
|
@ -1,18 +1,62 @@
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import BinaryIO
|
||||||
|
from joblib import Parallel, delayed
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
MODEL_NAME = "bigscience/bloom"
|
from huggingface_hub import hf_hub_url
|
||||||
|
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
|
||||||
|
|
||||||
|
|
||||||
def match_suffix(text, suffix):
|
def match_suffix(text, suffix):
|
||||||
return text[-len(suffix) :] == suffix
|
return text[-len(suffix):] == suffix
|
||||||
|
|
||||||
|
|
||||||
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
def http_get(
|
||||||
|
url: str,
|
||||||
|
temp_file: BinaryIO,
|
||||||
|
*,
|
||||||
|
timeout=10.0,
|
||||||
|
max_retries=0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
|
||||||
|
"""
|
||||||
|
r = _request_wrapper(
|
||||||
|
method="GET",
|
||||||
|
url=url,
|
||||||
|
stream=True,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
|
hf_raise_for_status(r)
|
||||||
|
for chunk in r.iter_content(chunk_size=1024):
|
||||||
|
if chunk: # filter out keep-alive new chunks
|
||||||
|
temp_file.write(chunk)
|
||||||
|
|
||||||
|
|
||||||
|
def cache_download_url(url: str, root_dir: Path):
|
||||||
|
filename = root_dir / url.split("/")[-1]
|
||||||
|
|
||||||
|
if not filename.exists():
|
||||||
|
temp_file_manager = partial(
|
||||||
|
tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
|
||||||
|
)
|
||||||
|
with temp_file_manager() as temp_file:
|
||||||
|
http_get(url, temp_file)
|
||||||
|
|
||||||
|
os.replace(temp_file.name, filename)
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
|
||||||
save_paths = [
|
save_paths = [
|
||||||
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
||||||
for tp_rank in range(tp_world_size)
|
for tp_rank in range(tp_world_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -20,9 +64,27 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||||
print("Weights are already prepared")
|
print("Weights are already prepared")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cache_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
if model_name == "bigscience/bloom-560m":
|
||||||
|
url = hf_hub_url(model_name, filename="pytorch_model.bin")
|
||||||
|
cache_download_url(url, cache_path)
|
||||||
|
elif model_name == "bigscience/bloom":
|
||||||
|
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
|
||||||
|
index_path = cache_download_url(url, cache_path)
|
||||||
|
with index_path.open("r") as f:
|
||||||
|
index = json.load(f)
|
||||||
|
|
||||||
|
# Get unique file names
|
||||||
|
weight_files = list(set([filename for filename in index["weight_map"].values()]))
|
||||||
|
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
|
||||||
|
|
||||||
|
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model name: {model_name}")
|
||||||
|
|
||||||
shards_state_dicts = [{} for _ in range(tp_world_size)]
|
shards_state_dicts = [{} for _ in range(tp_world_size)]
|
||||||
|
|
||||||
for weight_path in tqdm(hub_path.glob("*.bin")):
|
for weight_path in tqdm(Path(cache_path).glob("*.bin")):
|
||||||
state_dict = torch.load(weight_path, map_location="cpu")
|
state_dict = torch.load(weight_path, map_location="cpu")
|
||||||
|
|
||||||
keys = list(state_dict.keys())
|
keys = list(state_dict.keys())
|
||||||
|
@ -36,7 +98,6 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||||
"mlp.dense_h_to_4h.weight",
|
"mlp.dense_h_to_4h.weight",
|
||||||
"mlp.dense_h_to_4h.bias",
|
"mlp.dense_h_to_4h.bias",
|
||||||
"word_embeddings.weight",
|
"word_embeddings.weight",
|
||||||
"lm_head.weight",
|
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
output_size = state.shape[0]
|
output_size = state.shape[0]
|
||||||
|
@ -44,20 +105,25 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||||
block_size = output_size // tp_world_size
|
block_size = output_size // tp_world_size
|
||||||
sharded_weights = torch.split(state, block_size, dim=0)
|
sharded_weights = torch.split(state, block_size, dim=0)
|
||||||
assert len(sharded_weights) == tp_world_size
|
assert len(sharded_weights) == tp_world_size
|
||||||
|
|
||||||
|
for tp_rank, shard in enumerate(sharded_weights):
|
||||||
|
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||||
|
|
||||||
|
elif match_suffix(state_name, "lm_head.weight"):
|
||||||
|
output_size = state.shape[0]
|
||||||
|
assert output_size % tp_world_size == 0
|
||||||
|
block_size = output_size // tp_world_size
|
||||||
|
sharded_weights = torch.split(state, block_size, dim=0)
|
||||||
|
assert len(sharded_weights) == tp_world_size
|
||||||
|
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
for tp_rank, shard in enumerate(sharded_weights):
|
||||||
assert shard.shape[0] == block_size
|
|
||||||
if match_suffix(state_name, "lm_head.weight"):
|
|
||||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||||
else:
|
|
||||||
shards_state_dicts[tp_rank][
|
|
||||||
"transformer." + state_name
|
|
||||||
] = shard.detach().clone()
|
|
||||||
elif any(
|
elif any(
|
||||||
match_suffix(state_name, candidate)
|
match_suffix(state_name, candidate)
|
||||||
for candidate in [
|
for candidate in [
|
||||||
"self_attention.dense.weight",
|
"self_attention.dense.weight",
|
||||||
"mlp.dense_4h_to_h.weight",
|
"mlp.dense_4h_to_h.weight",
|
||||||
"lm_head.weight",
|
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
input_size = state.shape[1]
|
input_size = state.shape[1]
|
||||||
|
@ -66,13 +132,8 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||||
sharded_weights = torch.split(state, block_size, dim=1)
|
sharded_weights = torch.split(state, block_size, dim=1)
|
||||||
assert len(sharded_weights) == tp_world_size
|
assert len(sharded_weights) == tp_world_size
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
for tp_rank, shard in enumerate(sharded_weights):
|
||||||
assert shard.shape[1] == block_size
|
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||||
if match_suffix(state_name, "lm_head.weight"):
|
|
||||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
|
||||||
else:
|
|
||||||
shards_state_dicts[tp_rank][
|
|
||||||
"transformer." + state_name
|
|
||||||
] = shard.detach().clone()
|
|
||||||
elif any(
|
elif any(
|
||||||
match_suffix(state_name, candidate)
|
match_suffix(state_name, candidate)
|
||||||
for candidate in [
|
for candidate in [
|
||||||
|
@ -80,22 +141,18 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||||
"mlp.dense_4h_to_h.bias",
|
"mlp.dense_4h_to_h.bias",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
shards_state_dicts[0][
|
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
|
||||||
"transformer." + state_name
|
|
||||||
] = state.detach().clone()
|
|
||||||
for tp_rank in range(1, tp_world_size):
|
for tp_rank in range(1, tp_world_size):
|
||||||
shards_state_dicts[tp_rank][
|
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
|
||||||
"transformer." + state_name
|
|
||||||
] = torch.zeros_like(state)
|
|
||||||
else:
|
else:
|
||||||
# We duplicate parameters across tp ranks
|
# We duplicate parameters across tp ranks
|
||||||
for tp_rank in range(tp_world_size):
|
for tp_rank in range(tp_world_size):
|
||||||
shards_state_dicts[tp_rank][
|
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
|
||||||
"transformer." + state_name
|
|
||||||
] = state.detach().clone()
|
|
||||||
|
|
||||||
del state_dict[state_name] # delete key from state_dict
|
del state_dict[state_name] # delete key from state_dict
|
||||||
del state # delete tensor
|
del state # delete tensor
|
||||||
|
del state_dict
|
||||||
|
|
||||||
# we save state_dict
|
# we save state_dict
|
||||||
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
||||||
|
@ -116,9 +173,10 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument("--hub-path", required=True, type=str)
|
parser.add_argument("--model-name", required=True, type=str)
|
||||||
|
parser.add_argument("--cache-path", required=True, type=str)
|
||||||
parser.add_argument("--save-path", required=True, type=str)
|
parser.add_argument("--save-path", required=True, type=str)
|
||||||
parser.add_argument("--world-size", required=True, type=int)
|
parser.add_argument("--world-size", required=True, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
|
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
|
||||||
|
|
|
@ -1,102 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
def match_suffix(text, suffix):
|
|
||||||
return text[-len(suffix) :] == suffix
|
|
||||||
|
|
||||||
|
|
||||||
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
|
|
||||||
"""BLOOM specific sharding mechanism"""
|
|
||||||
save_paths = [
|
|
||||||
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
|
||||||
for tp_rank in range(tp_world_size)
|
|
||||||
]
|
|
||||||
if all(save_path.exists() for save_path in save_paths):
|
|
||||||
print("Loading already cached values")
|
|
||||||
return save_paths
|
|
||||||
|
|
||||||
model: nn.Module = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name, torch_dtype=dtype, local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
shards_state_dicts = [{} for _ in range(tp_world_size)]
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
keys = list(state_dict.keys())
|
|
||||||
for state_name in keys:
|
|
||||||
print(state_name)
|
|
||||||
state = state_dict[state_name]
|
|
||||||
if any(
|
|
||||||
match_suffix(state_name, candidate)
|
|
||||||
for candidate in [
|
|
||||||
"self_attention.query_key_value.weight",
|
|
||||||
"self_attention.query_key_value.bias",
|
|
||||||
"mlp.dense_h_to_4h.weight",
|
|
||||||
"mlp.dense_h_to_4h.bias",
|
|
||||||
"transformer.word_embeddings.weight",
|
|
||||||
"lm_head.weight",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
output_size = state.shape[0]
|
|
||||||
assert output_size % tp_world_size == 0
|
|
||||||
block_size = output_size // tp_world_size
|
|
||||||
sharded_weights = torch.split(state, block_size, dim=0)
|
|
||||||
assert len(sharded_weights) == tp_world_size
|
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
|
||||||
assert shard.shape[0] == block_size
|
|
||||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
|
||||||
elif any(
|
|
||||||
match_suffix(state_name, candidate)
|
|
||||||
for candidate in [
|
|
||||||
"self_attention.dense.weight",
|
|
||||||
"mlp.dense_4h_to_h.weight",
|
|
||||||
"lm_head.weight",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
input_size = state.shape[1]
|
|
||||||
assert input_size % tp_world_size == 0
|
|
||||||
block_size = input_size // tp_world_size
|
|
||||||
sharded_weights = torch.split(state, block_size, dim=1)
|
|
||||||
assert len(sharded_weights) == tp_world_size
|
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
|
||||||
assert shard.shape[1] == block_size
|
|
||||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
|
||||||
elif any(
|
|
||||||
match_suffix(state_name, candidate)
|
|
||||||
for candidate in [
|
|
||||||
"self_attention.dense.bias",
|
|
||||||
"mlp.dense_4h_to_h.bias",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
shards_state_dicts[0][state_name] = state.detach().clone()
|
|
||||||
for tp_rank in range(1, tp_world_size):
|
|
||||||
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
|
|
||||||
else:
|
|
||||||
# We duplicate parameters across tp ranks
|
|
||||||
for tp_rank in range(tp_world_size):
|
|
||||||
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
|
|
||||||
|
|
||||||
del state_dict[state_name] # delete key from state_dict
|
|
||||||
del state # delete tensor
|
|
||||||
|
|
||||||
# we save state_dict
|
|
||||||
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
|
||||||
zip(save_paths, shards_state_dicts)
|
|
||||||
):
|
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
torch.save(shard_state_dict, save_path)
|
|
||||||
save_paths.append(save_path)
|
|
||||||
|
|
||||||
return save_paths
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model_name = "bigscience/bloom"
|
|
||||||
save_path = Path("/data/shards")
|
|
||||||
tp_world_size = 8
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)
|
|
|
@ -80,6 +80,14 @@ grpcio = ">=1.49.1"
|
||||||
protobuf = ">=4.21.3,<5.0dev"
|
protobuf = ">=4.21.3,<5.0dev"
|
||||||
setuptools = "*"
|
setuptools = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "joblib"
|
||||||
|
version = "1.2.0"
|
||||||
|
description = "Lightweight pipelining with Python functions"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numpy"
|
name = "numpy"
|
||||||
version = "1.23.3"
|
version = "1.23.3"
|
||||||
|
@ -197,7 +205,7 @@ python-versions = ">=3.7"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
|
content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
accelerate = [
|
accelerate = [
|
||||||
|
@ -310,6 +318,10 @@ grpcio-tools = [
|
||||||
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
|
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
|
||||||
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
|
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
|
||||||
]
|
]
|
||||||
|
joblib = [
|
||||||
|
{file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
|
||||||
|
{file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
|
||||||
|
]
|
||||||
numpy = [
|
numpy = [
|
||||||
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
|
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
|
||||||
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},
|
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},
|
||||||
|
|
|
@ -12,6 +12,7 @@ torch = "^1.12.1"
|
||||||
typer = "^0.6.1"
|
typer = "^0.6.1"
|
||||||
grpcio-reflection = "^1.49.1"
|
grpcio-reflection = "^1.49.1"
|
||||||
accelerate = "^0.12.0"
|
accelerate = "^0.12.0"
|
||||||
|
joblib = "^1.2.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
grpcio-tools = "^1.49.1"
|
grpcio-tools = "^1.49.1"
|
||||||
|
|
Loading…
Reference in New Issue