feat: Docker image
This commit is contained in:
parent
39df4d9975
commit
bf99afe916
|
@ -0,0 +1 @@
|
|||
router/target
|
|
@ -0,0 +1,59 @@
|
|||
FROM rust:1.64 as builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY proto proto
|
||||
COPY router router
|
||||
|
||||
WORKDIR /usr/src/router
|
||||
|
||||
RUN cargo install --path .
|
||||
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
|
||||
ENV LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
DEBIAN_FRONTEND=noninteractive \
|
||||
MODEL_BASE_PATH=/var/azureml-model \
|
||||
MODEL_NAME=bigscience/bloom \
|
||||
NUM_GPUS=8 \
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
CUDA_HOME=/usr/local/cuda \
|
||||
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
|
||||
CONDA_DEFAULT_ENV=text-generation \
|
||||
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN cd ~ && \
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
||||
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
|
||||
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
|
||||
conda create -n text-generation python=3.9 -y
|
||||
|
||||
# Install specific version of torch
|
||||
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
|
||||
|
||||
# Install specific version of transformers
|
||||
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
|
||||
/opt/miniconda/envs/text-generation/bin/python setup.py install
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
# Install server
|
||||
COPY server server
|
||||
RUN cd server && \
|
||||
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
|
||||
|
||||
# Install router
|
||||
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
|
||||
|
||||
COPY run.sh .
|
||||
RUN chmod +x run.sh
|
||||
|
||||
CMD ["./run.sh"]
|
|
@ -43,8 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
|
|||
|
||||
## TODO:
|
||||
|
||||
- [ ] Improve model download
|
||||
- Store "shardable" layers separately and layer by layer
|
||||
- [ ] Add batching args to router CLI
|
||||
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
|
||||
- [ ] Add tests
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
[toolchain]
|
||||
channel = "1.64.0"
|
||||
components = ["rustfmt", "clippy"]
|
|
@ -83,7 +83,12 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
|
|||
|
||||
cached_batch = match batch_size {
|
||||
size if size > 16 => {
|
||||
wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await
|
||||
wrap_future(
|
||||
client.generate_until_finished_with_cache(batches),
|
||||
request_ids,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::net::SocketAddr;
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
|
@ -37,7 +37,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
||||
|
||||
server::run(sharded_client, tokenizer, addr).await;
|
||||
Ok(())
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use std::net::SocketAddr;
|
||||
use axum::{Router, Json};
|
||||
use axum::http::StatusCode;
|
||||
use axum::extract::Extension;
|
||||
use axum::routing::post;
|
||||
use crate::{Batcher, ShardedClient, Validation};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::StatusCode;
|
||||
use axum::routing::post;
|
||||
use axum::{Json, Router};
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::time::Instant;
|
||||
use tracing::instrument;
|
||||
|
@ -60,6 +60,31 @@ pub(crate) struct GenerateRequest {
|
|||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
|
||||
let output = state
|
||||
.infer
|
||||
.infer(
|
||||
1,
|
||||
GenerateRequest {
|
||||
inputs: "liveness".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
do_sample: false,
|
||||
max_new_tokens: 1,
|
||||
},
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
match output {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn generate(
|
||||
state: Extension<ServerState>,
|
||||
|
@ -67,14 +92,16 @@ async fn generate(
|
|||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let start = Instant::now();
|
||||
|
||||
let (input_length, validated_request) = match state.validation
|
||||
let (input_length, validated_request) = match state
|
||||
.validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: req.inputs.clone(),
|
||||
parameters: req.parameters.clone(),
|
||||
})
|
||||
.await {
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
};
|
||||
|
||||
let output = state.infer.infer(input_length, validated_request).await;
|
||||
|
@ -102,11 +129,7 @@ struct ServerState {
|
|||
infer: Batcher,
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
client: ShardedClient,
|
||||
tokenizer: Tokenizer,
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
||||
client.clear_cache().await.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
|
@ -114,13 +137,16 @@ pub async fn run(
|
|||
|
||||
let validation = Validation::new(tokenizer);
|
||||
|
||||
let shared_state = ServerState {
|
||||
validation,
|
||||
infer,
|
||||
};
|
||||
let shared_state = ServerState { validation, infer };
|
||||
|
||||
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
|
||||
let app = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.layer(Extension(shared_state.clone()))
|
||||
.route("/health", post(liveness))
|
||||
.layer(Extension(shared_state.clone()));
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service()).await.unwrap();
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
|
||||
$server_cmd &
|
||||
|
||||
FILE=/tmp/bloom-inference-0
|
||||
|
||||
while :
|
||||
do
|
||||
if test -S "$FILE"; then
|
||||
echo "Text Generation Python gRPC server started"
|
||||
break
|
||||
else
|
||||
echo "Waiting for Text Generation Python gRPC server to start"
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
|
||||
sleep 1
|
||||
|
||||
exec "bloom-inference"
|
|
@ -220,12 +220,14 @@ class BLOOM:
|
|||
def __init__(self, model_name: str):
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
self.model = (
|
||||
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device)
|
||||
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
|
||||
)
|
||||
self.num_heads = self.model.base_model.num_heads
|
||||
|
||||
|
@ -427,7 +429,8 @@ class BLOOMSharded(BLOOM):
|
|||
if do_transpose:
|
||||
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.tie_weights()
|
||||
self.model = model.to(self.device).eval()
|
||||
self.num_heads = config.n_head // self.process_group.size()
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
|
|
@ -1,18 +1,62 @@
|
|||
import torch
|
||||
import os
|
||||
import tempfile
|
||||
import json
|
||||
|
||||
from typing import BinaryIO
|
||||
from joblib import Parallel, delayed
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
MODEL_NAME = "bigscience/bloom"
|
||||
from huggingface_hub import hf_hub_url
|
||||
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
|
||||
|
||||
|
||||
def match_suffix(text, suffix):
|
||||
return text[-len(suffix) :] == suffix
|
||||
return text[-len(suffix):] == suffix
|
||||
|
||||
|
||||
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
||||
def http_get(
|
||||
url: str,
|
||||
temp_file: BinaryIO,
|
||||
*,
|
||||
timeout=10.0,
|
||||
max_retries=0,
|
||||
):
|
||||
"""
|
||||
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
|
||||
"""
|
||||
r = _request_wrapper(
|
||||
method="GET",
|
||||
url=url,
|
||||
stream=True,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
hf_raise_for_status(r)
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
temp_file.write(chunk)
|
||||
|
||||
|
||||
def cache_download_url(url: str, root_dir: Path):
|
||||
filename = root_dir / url.split("/")[-1]
|
||||
|
||||
if not filename.exists():
|
||||
temp_file_manager = partial(
|
||||
tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
|
||||
)
|
||||
with temp_file_manager() as temp_file:
|
||||
http_get(url, temp_file)
|
||||
|
||||
os.replace(temp_file.name, filename)
|
||||
return filename
|
||||
|
||||
|
||||
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
|
||||
save_paths = [
|
||||
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
||||
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
||||
for tp_rank in range(tp_world_size)
|
||||
]
|
||||
|
||||
|
@ -20,45 +64,67 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
|||
print("Weights are already prepared")
|
||||
return
|
||||
|
||||
cache_path.mkdir(parents=True, exist_ok=True)
|
||||
if model_name == "bigscience/bloom-560m":
|
||||
url = hf_hub_url(model_name, filename="pytorch_model.bin")
|
||||
cache_download_url(url, cache_path)
|
||||
elif model_name == "bigscience/bloom":
|
||||
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
|
||||
index_path = cache_download_url(url, cache_path)
|
||||
with index_path.open("r") as f:
|
||||
index = json.load(f)
|
||||
|
||||
# Get unique file names
|
||||
weight_files = list(set([filename for filename in index["weight_map"].values()]))
|
||||
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
|
||||
|
||||
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
|
||||
else:
|
||||
raise ValueError(f"Unknown model name: {model_name}")
|
||||
|
||||
shards_state_dicts = [{} for _ in range(tp_world_size)]
|
||||
|
||||
for weight_path in tqdm(hub_path.glob("*.bin")):
|
||||
for weight_path in tqdm(Path(cache_path).glob("*.bin")):
|
||||
state_dict = torch.load(weight_path, map_location="cpu")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
for state_name in keys:
|
||||
state = state_dict[state_name]
|
||||
if any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.query_key_value.weight",
|
||||
"self_attention.query_key_value.bias",
|
||||
"mlp.dense_h_to_4h.weight",
|
||||
"mlp.dense_h_to_4h.bias",
|
||||
"word_embeddings.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.query_key_value.weight",
|
||||
"self_attention.query_key_value.bias",
|
||||
"mlp.dense_h_to_4h.weight",
|
||||
"mlp.dense_h_to_4h.bias",
|
||||
"word_embeddings.weight",
|
||||
]
|
||||
):
|
||||
output_size = state.shape[0]
|
||||
assert output_size % tp_world_size == 0
|
||||
block_size = output_size // tp_world_size
|
||||
sharded_weights = torch.split(state, block_size, dim=0)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
assert shard.shape[0] == block_size
|
||||
if match_suffix(state_name, "lm_head.weight"):
|
||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||
else:
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = shard.detach().clone()
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||
|
||||
elif match_suffix(state_name, "lm_head.weight"):
|
||||
output_size = state.shape[0]
|
||||
assert output_size % tp_world_size == 0
|
||||
block_size = output_size // tp_world_size
|
||||
sharded_weights = torch.split(state, block_size, dim=0)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||
|
||||
elif any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.weight",
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.weight",
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
]
|
||||
):
|
||||
input_size = state.shape[1]
|
||||
assert input_size % tp_world_size == 0
|
||||
|
@ -66,40 +132,31 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
|
|||
sharded_weights = torch.split(state, block_size, dim=1)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
assert shard.shape[1] == block_size
|
||||
if match_suffix(state_name, "lm_head.weight"):
|
||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||
else:
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = shard.detach().clone()
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||
|
||||
elif any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
]
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
]
|
||||
):
|
||||
shards_state_dicts[0][
|
||||
"transformer." + state_name
|
||||
] = state.detach().clone()
|
||||
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
|
||||
for tp_rank in range(1, tp_world_size):
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = torch.zeros_like(state)
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
|
||||
|
||||
else:
|
||||
# We duplicate parameters across tp ranks
|
||||
for tp_rank in range(tp_world_size):
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = state.detach().clone()
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
|
||||
|
||||
del state_dict[state_name] # delete key from state_dict
|
||||
del state # delete tensor
|
||||
del state_dict
|
||||
|
||||
# we save state_dict
|
||||
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
||||
zip(save_paths, shards_state_dicts)
|
||||
zip(save_paths, shards_state_dicts)
|
||||
):
|
||||
save_paths.append(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -116,9 +173,10 @@ if __name__ == "__main__":
|
|||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument("--hub-path", required=True, type=str)
|
||||
parser.add_argument("--model-name", required=True, type=str)
|
||||
parser.add_argument("--cache-path", required=True, type=str)
|
||||
parser.add_argument("--save-path", required=True, type=str)
|
||||
parser.add_argument("--world-size", required=True, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
|
||||
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
|
||||
|
|
|
@ -1,102 +0,0 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def match_suffix(text, suffix):
|
||||
return text[-len(suffix) :] == suffix
|
||||
|
||||
|
||||
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
|
||||
"""BLOOM specific sharding mechanism"""
|
||||
save_paths = [
|
||||
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
||||
for tp_rank in range(tp_world_size)
|
||||
]
|
||||
if all(save_path.exists() for save_path in save_paths):
|
||||
print("Loading already cached values")
|
||||
return save_paths
|
||||
|
||||
model: nn.Module = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=dtype, local_files_only=True
|
||||
)
|
||||
|
||||
shards_state_dicts = [{} for _ in range(tp_world_size)]
|
||||
state_dict = model.state_dict()
|
||||
keys = list(state_dict.keys())
|
||||
for state_name in keys:
|
||||
print(state_name)
|
||||
state = state_dict[state_name]
|
||||
if any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.query_key_value.weight",
|
||||
"self_attention.query_key_value.bias",
|
||||
"mlp.dense_h_to_4h.weight",
|
||||
"mlp.dense_h_to_4h.bias",
|
||||
"transformer.word_embeddings.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
):
|
||||
output_size = state.shape[0]
|
||||
assert output_size % tp_world_size == 0
|
||||
block_size = output_size // tp_world_size
|
||||
sharded_weights = torch.split(state, block_size, dim=0)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
assert shard.shape[0] == block_size
|
||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||
elif any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.weight",
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
):
|
||||
input_size = state.shape[1]
|
||||
assert input_size % tp_world_size == 0
|
||||
block_size = input_size // tp_world_size
|
||||
sharded_weights = torch.split(state, block_size, dim=1)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
assert shard.shape[1] == block_size
|
||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||
elif any(
|
||||
match_suffix(state_name, candidate)
|
||||
for candidate in [
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
]
|
||||
):
|
||||
shards_state_dicts[0][state_name] = state.detach().clone()
|
||||
for tp_rank in range(1, tp_world_size):
|
||||
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
|
||||
else:
|
||||
# We duplicate parameters across tp ranks
|
||||
for tp_rank in range(tp_world_size):
|
||||
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
|
||||
|
||||
del state_dict[state_name] # delete key from state_dict
|
||||
del state # delete tensor
|
||||
|
||||
# we save state_dict
|
||||
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
||||
zip(save_paths, shards_state_dicts)
|
||||
):
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(shard_state_dict, save_path)
|
||||
save_paths.append(save_path)
|
||||
|
||||
return save_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_name = "bigscience/bloom"
|
||||
save_path = Path("/data/shards")
|
||||
tp_world_size = 8
|
||||
dtype = torch.bfloat16
|
||||
|
||||
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)
|
|
@ -80,6 +80,14 @@ grpcio = ">=1.49.1"
|
|||
protobuf = ">=4.21.3,<5.0dev"
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.2.0"
|
||||
description = "Lightweight pipelining with Python functions"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.23.3"
|
||||
|
@ -197,7 +205,7 @@ python-versions = ">=3.7"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25"
|
||||
content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
|
@ -310,6 +318,10 @@ grpcio-tools = [
|
|||
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
|
||||
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
|
||||
]
|
||||
joblib = [
|
||||
{file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
|
||||
{file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
|
||||
]
|
||||
numpy = [
|
||||
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
|
||||
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},
|
||||
|
|
|
@ -12,6 +12,7 @@ torch = "^1.12.1"
|
|||
typer = "^0.6.1"
|
||||
grpcio-reflection = "^1.49.1"
|
||||
accelerate = "^0.12.0"
|
||||
joblib = "^1.2.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
grpcio-tools = "^1.49.1"
|
||||
|
|
Loading…
Reference in New Issue