feat: Docker image

This commit is contained in:
Olivier Dehaene 2022-10-14 15:56:21 +02:00
parent 39df4d9975
commit bf99afe916
13 changed files with 265 additions and 180 deletions

1
.dockerignore Normal file
View File

@ -0,0 +1 @@
router/target

59
Dockerfile Normal file
View File

@ -0,0 +1,59 @@
FROM rust:1.64 as builder
WORKDIR /usr/src
COPY proto proto
COPY router router
WORKDIR /usr/src/router
RUN cargo install --path .
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
MODEL_BASE_PATH=/var/azureml-model \
MODEL_NAME=bigscience/bloom \
NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y
# Install specific version of torch
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
# Install specific version of transformers
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
/opt/miniconda/envs/text-generation/bin/python setup.py install
WORKDIR /usr/src
# Install server
COPY server server
RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
COPY run.sh .
RUN chmod +x run.sh
CMD ["./run.sh"]

View File

@ -43,8 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
## TODO: ## TODO:
- [ ] Improve model download
- Store "shardable" layers separately and layer by layer
- [ ] Add batching args to router CLI - [ ] Add batching args to router CLI
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated - [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
- [ ] Add tests - [ ] Add tests

View File

@ -0,0 +1,3 @@
[toolchain]
channel = "1.64.0"
components = ["rustfmt", "clippy"]

View File

@ -83,7 +83,12 @@ async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
cached_batch = match batch_size { cached_batch = match batch_size {
size if size > 16 => { size if size > 16 => {
wrap_future(client.generate_until_finished_with_cache(batches), request_ids, &db).await wrap_future(
client.generate_until_finished_with_cache(batches),
request_ids,
&db,
)
.await
} }
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await, _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
}; };

View File

@ -1,5 +1,5 @@
use std::net::SocketAddr;
use bloom_inference_client::ShardedClient; use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@ -37,7 +37,7 @@ fn main() -> Result<(), std::io::Error> {
.expect("Unable to clear cache"); .expect("Unable to clear cache");
tracing::info!("Connected"); tracing::info!("Connected");
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
server::run(sharded_client, tokenizer, addr).await; server::run(sharded_client, tokenizer, addr).await;
Ok(()) Ok(())

View File

@ -1,10 +1,10 @@
use std::net::SocketAddr;
use axum::{Router, Json};
use axum::http::StatusCode;
use axum::extract::Extension;
use axum::routing::post;
use crate::{Batcher, ShardedClient, Validation}; use crate::{Batcher, ShardedClient, Validation};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::post;
use axum::{Json, Router};
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::time::Instant; use tokio::time::Instant;
use tracing::instrument; use tracing::instrument;
@ -60,6 +60,31 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
} }
#[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
let output = state
.infer
.infer(
1,
GenerateRequest {
inputs: "liveness".to_string(),
parameters: GenerateParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
do_sample: false,
max_new_tokens: 1,
},
},
)
.await;
match output {
Ok(_) => Ok(()),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(state), fields(time, time_per_token))]
async fn generate( async fn generate(
state: Extension<ServerState>, state: Extension<ServerState>,
@ -67,14 +92,16 @@ async fn generate(
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let start = Instant::now(); let start = Instant::now();
let (input_length, validated_request) = match state.validation let (input_length, validated_request) = match state
.validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: req.inputs.clone(), inputs: req.inputs.clone(),
parameters: req.parameters.clone(), parameters: req.parameters.clone(),
}) })
.await { .await
{
Ok(result) => result, Ok(result) => result,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR) Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
}; };
let output = state.infer.infer(input_length, validated_request).await; let output = state.infer.infer(input_length, validated_request).await;
@ -102,11 +129,7 @@ struct ServerState {
infer: Batcher, infer: Batcher,
} }
pub async fn run( pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
client: ShardedClient,
tokenizer: Tokenizer,
addr: SocketAddr,
) {
client.clear_cache().await.expect("Unable to clear cache"); client.clear_cache().await.expect("Unable to clear cache");
tracing::info!("Connected"); tracing::info!("Connected");
@ -114,13 +137,16 @@ pub async fn run(
let validation = Validation::new(tokenizer); let validation = Validation::new(tokenizer);
let shared_state = ServerState { let shared_state = ServerState { validation, infer };
validation,
infer,
};
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state)); let app = Router::new()
.route("/generate", post(generate))
.layer(Extension(shared_state.clone()))
.route("/health", post(liveness))
.layer(Extension(shared_state.clone()));
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(app.into_make_service()).await.unwrap(); .serve(app.into_make_service())
.await
.unwrap();
} }

21
run.sh Executable file
View File

@ -0,0 +1,21 @@
#!/usr/bin/env bash
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
$server_cmd &
FILE=/tmp/bloom-inference-0
while :
do
if test -S "$FILE"; then
echo "Text Generation Python gRPC server started"
break
else
echo "Waiting for Text Generation Python gRPC server to start"
sleep 5
fi
done
sleep 1
exec "bloom-inference"

View File

@ -220,12 +220,14 @@ class BLOOM:
def __init__(self, model_name: str): def __init__(self, model_name: str):
if torch.cuda.is_available(): if torch.cuda.is_available():
self.device = torch.device("cuda") self.device = torch.device("cuda")
dtype = torch.bfloat16
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = ( self.model = (
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device) AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
) )
self.num_heads = self.model.base_model.num_heads self.num_heads = self.model.base_model.num_heads
@ -427,7 +429,8 @@ class BLOOMSharded(BLOOM):
if do_transpose: if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous() state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict) model.load_state_dict(state_dict, strict=False)
model.tie_weights()
self.model = model.to(self.device).eval() self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size() self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@ -1,18 +1,62 @@
import torch import torch
import os
import tempfile
import json
from typing import BinaryIO
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
MODEL_NAME = "bigscience/bloom" from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix): def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix return text[-len(suffix):] == suffix
def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int): def http_get(
url: str,
temp_file: BinaryIO,
*,
timeout=10.0,
max_retries=0,
):
"""
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
"""
r = _request_wrapper(
method="GET",
url=url,
stream=True,
timeout=timeout,
max_retries=max_retries,
)
hf_raise_for_status(r)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
def cache_download_url(url: str, root_dir: Path):
filename = root_dir / url.split("/")[-1]
if not filename.exists():
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
)
with temp_file_manager() as temp_file:
http_get(url, temp_file)
os.replace(temp_file.name, filename)
return filename
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
save_paths = [ save_paths = [
save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty" save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size) for tp_rank in range(tp_world_size)
] ]
@ -20,9 +64,27 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
print("Weights are already prepared") print("Weights are already prepared")
return return
cache_path.mkdir(parents=True, exist_ok=True)
if model_name == "bigscience/bloom-560m":
url = hf_hub_url(model_name, filename="pytorch_model.bin")
cache_download_url(url, cache_path)
elif model_name == "bigscience/bloom":
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
index_path = cache_download_url(url, cache_path)
with index_path.open("r") as f:
index = json.load(f)
# Get unique file names
weight_files = list(set([filename for filename in index["weight_map"].values()]))
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
else:
raise ValueError(f"Unknown model name: {model_name}")
shards_state_dicts = [{} for _ in range(tp_world_size)] shards_state_dicts = [{} for _ in range(tp_world_size)]
for weight_path in tqdm(hub_path.glob("*.bin")): for weight_path in tqdm(Path(cache_path).glob("*.bin")):
state_dict = torch.load(weight_path, map_location="cpu") state_dict = torch.load(weight_path, map_location="cpu")
keys = list(state_dict.keys()) keys = list(state_dict.keys())
@ -36,7 +98,6 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
"mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias", "mlp.dense_h_to_4h.bias",
"word_embeddings.weight", "word_embeddings.weight",
"lm_head.weight",
] ]
): ):
output_size = state.shape[0] output_size = state.shape[0]
@ -44,20 +105,25 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
block_size = output_size // tp_world_size block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0) sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
elif match_suffix(state_name, "lm_head.weight"):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights): for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone() shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any( elif any(
match_suffix(state_name, candidate) match_suffix(state_name, candidate)
for candidate in [ for candidate in [
"self_attention.dense.weight", "self_attention.dense.weight",
"mlp.dense_4h_to_h.weight", "mlp.dense_4h_to_h.weight",
"lm_head.weight",
] ]
): ):
input_size = state.shape[1] input_size = state.shape[1]
@ -66,13 +132,8 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
sharded_weights = torch.split(state, block_size, dim=1) sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights): for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
if match_suffix(state_name, "lm_head.weight"):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
else:
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any( elif any(
match_suffix(state_name, candidate) match_suffix(state_name, candidate)
for candidate in [ for candidate in [
@ -80,22 +141,18 @@ def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
"mlp.dense_4h_to_h.bias", "mlp.dense_4h_to_h.bias",
] ]
): ):
shards_state_dicts[0][ shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
"transformer." + state_name
] = state.detach().clone()
for tp_rank in range(1, tp_world_size): for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][ shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
"transformer." + state_name
] = torch.zeros_like(state)
else: else:
# We duplicate parameters across tp ranks # We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size): for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][ shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
"transformer." + state_name
] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict del state_dict[state_name] # delete key from state_dict
del state # delete tensor del state # delete tensor
del state_dict
# we save state_dict # we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate( for tp_rank, (save_path, shard_state_dict) in enumerate(
@ -116,9 +173,10 @@ if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--hub-path", required=True, type=str) parser.add_argument("--model-name", required=True, type=str)
parser.add_argument("--cache-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str) parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int) parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args() args = parser.parse_args()
prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size) prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)

View File

@ -1,102 +0,0 @@
from pathlib import Path
import torch
from torch import nn
from transformers import AutoModelForCausalLM
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
"""BLOOM specific sharding mechanism"""
save_paths = [
path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Loading already cached values")
return save_paths
model: nn.Module = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, local_files_only=True
)
shards_state_dicts = [{} for _ in range(tp_world_size)]
state_dict = model.state_dict()
keys = list(state_dict.keys())
for state_name in keys:
print(state_name)
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"transformer.word_embeddings.weight",
"lm_head.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[0] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
"lm_head.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
assert shard.shape[1] == block_size
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][state_name] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][state_name] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(shard_state_dict, save_path)
save_paths.append(save_path)
return save_paths
if __name__ == "__main__":
model_name = "bigscience/bloom"
save_path = Path("/data/shards")
tp_world_size = 8
dtype = torch.bfloat16
shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)

14
server/poetry.lock generated
View File

@ -80,6 +80,14 @@ grpcio = ">=1.49.1"
protobuf = ">=4.21.3,<5.0dev" protobuf = ">=4.21.3,<5.0dev"
setuptools = "*" setuptools = "*"
[[package]]
name = "joblib"
version = "1.2.0"
description = "Lightweight pipelining with Python functions"
category = "main"
optional = false
python-versions = ">=3.7"
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "1.23.3" version = "1.23.3"
@ -197,7 +205,7 @@ python-versions = ">=3.7"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "cedd0aebeb3731e2bbddf017a2ee6074c285866354272f8dfe930e9606437a25" content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
[metadata.files] [metadata.files]
accelerate = [ accelerate = [
@ -310,6 +318,10 @@ grpcio-tools = [
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"}, {file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"},
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"}, {file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"},
] ]
joblib = [
{file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
{file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
]
numpy = [ numpy = [
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"}, {file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"},
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"}, {file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"},

View File

@ -12,6 +12,7 @@ torch = "^1.12.1"
typer = "^0.6.1" typer = "^0.6.1"
grpcio-reflection = "^1.49.1" grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0" accelerate = "^0.12.0"
joblib = "^1.2.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.49.1" grpcio-tools = "^1.49.1"