feat(server): Use safetensors

Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
This commit is contained in:
Nicolas Patry 2022-10-22 20:00:15 +02:00 committed by GitHub
parent be8827fe41
commit c8ce9b2515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 418 additions and 499 deletions

View File

@ -4,6 +4,9 @@ members = [
"router/client", "router/client",
"launcher" "launcher"
] ]
exclude = [
"server/safetensors",
]
[profile.release] [profile.release]
debug = 1 debug = 1

View File

@ -36,10 +36,10 @@ ENV LANG=C.UTF-8 \
SHELL ["/bin/bash", "-c"] SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/* RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \ RUN cd ~ && \
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x Miniconda3-latest-Linux-x86_64.sh && \ chmod +x Miniconda3-latest-Linux-x86_64.sh && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \ bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y conda create -n text-generation python=3.9 -y
@ -54,13 +54,22 @@ RUN cd server && make install-torch
# Install specific version of transformers # Install specific version of transformers
RUN cd server && make install-transformers RUN cd server && make install-transformers
# Install specific version of safetensors
# FIXME: This is a temporary fix while we wait for a new release
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cd server && make install-safetensors
# Install server # Install server
COPY proto proto
COPY server server COPY server server
RUN cd server && \ RUN cd server && \
make gen-server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router # Install router
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS

View File

@ -1,5 +1,5 @@
install-server: install-server:
cd server && make pip-install cd server && make install
install-router: install-router:
cd router && cargo install --path . cd router && cargo install --path .
@ -7,13 +7,16 @@ install-router:
install-launcher: install-launcher:
cd launcher && cargo install --path . cd launcher && cargo install --path .
install: install: install-server install-router install-launcher
make install-server
make install-router server-dev:
make install-launcher cd server && make run-dev
router-dev:
cd router && cargo run
run-bloom-560m: run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2 text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
run-bloom: run-bloom:
text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8 text-generation-launcher --model-name bigscience/bloom --num-shard 8

View File

@ -41,11 +41,21 @@ make run-bloom-560m
```shell ```shell
curl 127.0.0.1:3000/generate \ curl 127.0.0.1:3000/generate \
-v \
-X POST \ -X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \ -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json' -H 'Content-Type: application/json'
``` ```
## Develop
```shell
make server-dev
make router-dev
```
## TODO: ## TODO:
- [ ] Add tests for the `server/model` logic - [ ] Add tests for the `server/model` logic
- [ ] Backport custom CUDA kernels to Transformers
- [ ] Install safetensors with pip

View File

@ -1,4 +1,5 @@
use clap::Parser; use clap::Parser;
use std::env;
use std::io::{BufRead, BufReader, Read}; use std::io::{BufRead, BufReader, Read};
use std::path::Path; use std::path::Path;
use std::process::ExitCode; use std::process::ExitCode;
@ -20,8 +21,6 @@ struct Args {
model_name: String, model_name: String,
#[clap(long, env)] #[clap(long, env)]
num_shard: Option<usize>, num_shard: Option<usize>,
#[clap(long, env)]
shard_directory: Option<String>,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
@ -47,7 +46,6 @@ fn main() -> ExitCode {
let Args { let Args {
model_name, model_name,
num_shard, num_shard,
shard_directory,
max_concurrent_requests, max_concurrent_requests,
max_input_length, max_input_length,
max_batch_size, max_batch_size,
@ -82,7 +80,6 @@ fn main() -> ExitCode {
for rank in 0..num_shard { for rank in 0..num_shard {
let model_name = model_name.clone(); let model_name = model_name.clone();
let uds_path = shard_uds_path.clone(); let uds_path = shard_uds_path.clone();
let shard_directory = shard_directory.clone();
let master_addr = master_addr.clone(); let master_addr = master_addr.clone();
let status_sender = status_sender.clone(); let status_sender = status_sender.clone();
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
@ -91,7 +88,6 @@ fn main() -> ExitCode {
shard_manager( shard_manager(
model_name, model_name,
uds_path, uds_path,
shard_directory,
rank, rank,
num_shard, num_shard,
master_addr, master_addr,
@ -237,7 +233,6 @@ enum ShardStatus {
fn shard_manager( fn shard_manager(
model_name: String, model_name: String,
uds_path: String, uds_path: String,
shard_directory: Option<String>,
rank: usize, rank: usize,
world_size: usize, world_size: usize,
master_addr: String, master_addr: String,
@ -265,10 +260,27 @@ fn shard_manager(
shard_argv.push("--sharded".to_string()); shard_argv.push("--sharded".to_string());
} }
if let Some(shard_directory) = shard_directory { let mut env = vec![
shard_argv.push("--shard-directory".to_string()); ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
shard_argv.push(shard_directory); (
} "WORLD_SIZE".parse().unwrap(),
world_size.to_string().parse().unwrap(),
),
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
];
// If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
// Useful when running inside a docker container
if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
env.push((
"HUGGINGFACE_HUB_CACHE".parse().unwrap(),
huggingface_hub_cache.parse().unwrap(),
));
};
// Start process // Start process
tracing::info!("Starting shard {}", rank); tracing::info!("Starting shard {}", rank);
@ -280,18 +292,7 @@ fn shard_manager(
// Needed for the shutdown procedure // Needed for the shutdown procedure
setpgid: true, setpgid: true,
// NCCL env vars // NCCL env vars
env: Some(vec![ env: Some(env),
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
(
"WORLD_SIZE".parse().unwrap(),
world_size.to_string().parse().unwrap(),
),
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
]),
..Default::default() ..Default::default()
}, },
) { ) {
@ -312,6 +313,7 @@ fn shard_manager(
let mut ready = false; let mut ready = false;
let start_time = Instant::now(); let start_time = Instant::now();
let mut wait_time = Instant::now();
loop { loop {
// Process exited // Process exited
if p.poll().is_some() { if p.poll().is_some() {
@ -337,9 +339,12 @@ fn shard_manager(
status_sender.send(ShardStatus::Ready).unwrap(); status_sender.send(ShardStatus::Ready).unwrap();
ready = true; ready = true;
} else if !ready { } else if !ready {
tracing::info!("Waiting for shard {} to be ready...", rank); if wait_time.elapsed() > Duration::from_secs(10) {
tracing::info!("Waiting for shard {} to be ready...", rank);
wait_time = Instant::now();
}
} }
sleep(Duration::from_secs(5)); sleep(Duration::from_millis(100));
} }
} }

3
server/.gitignore vendored
View File

@ -153,3 +153,6 @@ dmypy.json
# Cython debug symbols # Cython debug symbols
cython_debug/ cython_debug/
transformers
safetensors

View File

@ -1,4 +1,6 @@
gen-server: gen-server:
# Compile protos
pip install grpcio-tools==1.49.1 --no-cache-dir
mkdir bloom_inference/pb || true mkdir bloom_inference/pb || true
python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
@ -7,25 +9,31 @@ gen-server:
install-transformers: install-transformers:
# Install specific version of transformers # Install specific version of transformers
rm transformers || true rm transformers || true
wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip rm transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 || true
curl -L -O https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
cd transformers && python setup.py install cd transformers && python setup.py install
install-safetensors:
# Install specific version of safetensors
pip install setuptools_rust
rm safetensors || true
rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true
curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors
cd safetensors/bindings/python && python setup.py develop
install-torch: install-torch:
# Install specific version of torch # Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
pip-install: install: gen-server install-torch install-transformers install-safetensors
pip install grpcio-tools pip install pip --upgrade
make gen-server pip install -e . --no-cache-dir
make install-torch
make install-transformers
pip install .
install: run-dev:
poetry install python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
make gen-server
make install-torch
make install-transformers

View File

@ -2,14 +2,14 @@
A Python gRPC server for BLOOM Inference A Python gRPC server for BLOOM Inference
## Local Install (with poetry) ## Install
```shell ```shell
make install make install
``` ```
## Local Install (with pip) ## Run
```shell ```shell
make pip-install make run-dev
``` ```

View File

@ -2,9 +2,8 @@ import os
import typer import typer
from pathlib import Path from pathlib import Path
from typing import Optional
from bloom_inference import prepare_weights, server from bloom_inference import server, utils
app = typer.Typer() app = typer.Typer()
@ -13,13 +12,9 @@ app = typer.Typer()
def serve( def serve(
model_name: str, model_name: str,
sharded: bool = False, sharded: bool = False,
shard_directory: Optional[Path] = None,
uds_path: Path = "/tmp/bloom-inference", uds_path: Path = "/tmp/bloom-inference",
): ):
if sharded: if sharded:
assert (
shard_directory is not None
), "shard_directory must be set when sharded is True"
assert ( assert (
os.getenv("RANK", None) is not None os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True" ), "RANK must be set when sharded is True"
@ -33,19 +28,14 @@ def serve(
os.getenv("MASTER_PORT", None) is not None os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True" ), "MASTER_PORT must be set when sharded is True"
server.serve(model_name, sharded, uds_path, shard_directory) server.serve(model_name, sharded, uds_path)
@app.command() @app.command()
def prepare_weights( def download_weights(
model_name: str, model_name: str,
shard_directory: Path,
cache_directory: Path,
num_shard: int = 1,
): ):
prepare_weights.prepare_weights( utils.download_weights(model_name)
model_name, cache_directory, shard_directory, num_shard
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,19 +2,24 @@ import torch
import torch.distributed import torch.distributed
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict from typing import List, Tuple, Optional, Dict
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_utils import no_init_weights from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from bloom_inference.pb import generate_pb2 from bloom_inference.pb import generate_pb2
from bloom_inference.prepare_weights import prepare_weights, match_suffix
from bloom_inference.utils import ( from bloom_inference.utils import (
StoppingCriteria, StoppingCriteria,
NextTokenChooser, NextTokenChooser,
initialize_torch_distributed, initialize_torch_distributed,
set_default_dtype, weight_files,
download_weights
) )
torch.manual_seed(0) torch.manual_seed(0)
@ -63,7 +68,7 @@ class Batch:
) )
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device) input_ids = tokenizer(inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(device)
all_input_ids = input_ids["input_ids"].unsqueeze(-1) all_input_ids = input_ids["input_ids"].unsqueeze(-1)
return cls( return cls(
@ -369,7 +374,7 @@ class BLOOM:
class BLOOMSharded(BLOOM): class BLOOMSharded(BLOOM):
def __init__(self, model_name: str, shard_directory: Path): def __init__(self, model_name: str):
super(BLOOM, self).__init__() super(BLOOM, self).__init__()
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
@ -382,30 +387,11 @@ class BLOOMSharded(BLOOM):
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# shard state_dict
if self.master:
# TODO @thomasw21 do some caching
shard_state_dict_paths = prepare_weights(
model_name,
shard_directory / "cache",
shard_directory,
tp_world_size=self.world_size,
)
shard_state_dict_paths = [
str(path.absolute()) for path in shard_state_dict_paths
]
else:
shard_state_dict_paths = [None] * self.world_size
torch.distributed.broadcast_object_list(
shard_state_dict_paths, src=0, group=self.process_group
)
shard_state_dict_path = shard_state_dict_paths[self.rank]
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_name, slow_but_exact=False, tp_parallel=True model_name, slow_but_exact=False, tp_parallel=True
) )
config.pad_token_id = 3 config.pad_token_id = 3
self.num_heads = config.n_head // self.process_group.size()
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later. # in PyTorch 1.12 and later.
@ -414,37 +400,88 @@ class BLOOMSharded(BLOOM):
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
with set_default_dtype(dtype): # Only download weights for small models
with no_init_weights(): if self.master and model_name == "bigscience/bloom-560m":
# we can probably set the device to `meta` here? download_weights(model_name)
model = AutoModelForCausalLM.from_config(config).to(dtype)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
# print_rank_0(f"Initialized model") filenames = weight_files(model_name)
state_dict = torch.load(shard_state_dict_path)
# TODO @thomasw21: HACK in order to transpose all weight prior
for key in state_dict.keys():
do_transpose = False
if not match_suffix(key, "weight"):
continue
for potential_suffix in [ with init_empty_weights():
"self_attention.query_key_value.weight", model = AutoModelForCausalLM.from_config(config)
"self_attention.dense.weight",
"dense_h_to_4h.weight",
"dense_4h_to_h.weight",
]:
if match_suffix(key, potential_suffix):
do_transpose = True
if do_transpose:
state_dict[key] = state_dict[key].transpose(1, 0).contiguous()
model.load_state_dict(state_dict, strict=False)
model.tie_weights()
self.model = model.to(self.device).eval()
self.num_heads = config.n_head // self.process_group.size()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
device=self.device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
@staticmethod
def load_weights(
model, filenames: List[str], device: torch.device, rank: int, world_size: int
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f:
for name in f.keys():
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
if param_name == "weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
tensor = tensor.transpose(1, 0)
else:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
tensor = tensor.transpose(1, 0)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward(self, input_ids, attention_mask, past_key_values: Optional = None): def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
outputs = self.model.forward( outputs = self.model.forward(

View File

@ -1,185 +0,0 @@
import torch
import os
import tempfile
import json
from typing import BinaryIO
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path
from tqdm import tqdm
from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix):
return text[-len(suffix) :] == suffix
def http_get(
url: str,
temp_file: BinaryIO,
*,
timeout=10.0,
max_retries=0,
):
"""
Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
"""
r = _request_wrapper(
method="GET",
url=url,
stream=True,
timeout=timeout,
max_retries=max_retries,
)
hf_raise_for_status(r)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
def cache_download_url(url: str, root_dir: Path):
filename = root_dir / url.split("/")[-1]
if not filename.exists():
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
)
with temp_file_manager() as temp_file:
http_get(url, temp_file)
os.replace(temp_file.name, filename)
return filename
def prepare_weights(
model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
):
save_paths = [
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size)
]
if all(save_path.exists() for save_path in save_paths):
print("Weights are already prepared")
return save_paths
cache_path.mkdir(parents=True, exist_ok=True)
if model_name == "bigscience/bloom-560m":
url = hf_hub_url(model_name, filename="pytorch_model.bin")
cache_download_url(url, cache_path)
elif model_name == "bigscience/bloom":
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
index_path = cache_download_url(url, cache_path)
with index_path.open("r") as f:
index = json.load(f)
# Get unique file names
weight_files = list(
set([filename for filename in index["weight_map"].values()])
)
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
Parallel(n_jobs=5)(
delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
)
else:
raise ValueError(f"Unknown model name: {model_name}")
shards_state_dicts = [{} for _ in range(tp_world_size)]
for weight_path in tqdm(Path(cache_path).glob("*.bin")):
state_dict = torch.load(weight_path, map_location="cpu")
keys = list(state_dict.keys())
for state_name in keys:
state = state_dict[state_name]
if any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.query_key_value.weight",
"self_attention.query_key_value.bias",
"mlp.dense_h_to_4h.weight",
"mlp.dense_h_to_4h.bias",
"word_embeddings.weight",
]
):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif match_suffix(state_name, "lm_head.weight"):
output_size = state.shape[0]
assert output_size % tp_world_size == 0
block_size = output_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=0)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.weight",
"mlp.dense_4h_to_h.weight",
]
):
input_size = state.shape[1]
assert input_size % tp_world_size == 0
block_size = input_size // tp_world_size
sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any(
match_suffix(state_name, candidate)
for candidate in [
"self_attention.dense.bias",
"mlp.dense_4h_to_h.bias",
]
):
shards_state_dicts[0][
"transformer." + state_name
] = state.detach().clone()
for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = torch.zeros_like(state)
else:
# We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank][
"transformer." + state_name
] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict
del state # delete tensor
del state_dict
# we save state_dict
for tp_rank, (save_path, shard_state_dict) in enumerate(
zip(save_paths, shards_state_dicts)
):
save_paths.append(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
if save_path.exists():
print(f"Skipping {save_path} as it already exists")
else:
torch.save(shard_state_dict, save_path)
return save_paths

View File

@ -69,18 +69,14 @@ def serve(
model_name: str, model_name: str,
sharded: bool, sharded: bool,
uds_path: Path, uds_path: Path,
shard_directory: Optional[Path] = None,
): ):
async def serve_inner( async def serve_inner(
model_name: str, model_name: str,
sharded: bool = False, sharded: bool = False,
shard_directory: Optional[Path] = None,
): ):
unix_socket_template = "unix://{}-{}" unix_socket_template = "unix://{}-{}"
if sharded: if sharded:
if shard_directory is None: model = BLOOMSharded(model_name)
raise ValueError("shard_directory must be set when sharded is True")
model = BLOOMSharded(model_name, shard_directory)
server_urls = [ server_urls = [
unix_socket_template.format(uds_path, rank) unix_socket_template.format(uds_path, rank)
for rank in range(model.world_size) for rank in range(model.world_size)
@ -109,4 +105,4 @@ def serve(
print("Signal received. Shutting down") print("Signal received. Shutting down")
await server.stop(0) await server.stop(0)
asyncio.run(serve_inner(model_name, sharded, shard_directory)) asyncio.run(serve_inner(model_name, sharded))

View File

@ -1,9 +1,14 @@
import os import os
import contextlib
import torch import torch
import torch.distributed import torch.distributed
from datetime import timedelta from datetime import timedelta
from functools import partial
from joblib import Parallel, delayed
from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
from huggingface_hub.utils import LocalEntryNotFoundError
from tqdm import tqdm
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
LogitsProcessorList, LogitsProcessorList,
TemperatureLogitsWarper, TemperatureLogitsWarper,
@ -87,11 +92,42 @@ def initialize_torch_distributed():
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
@contextlib.contextmanager def weight_hub_files(model_name):
def set_default_dtype(dtype): """Get the safetensors filenames on the hub"""
saved_dtype = torch.get_default_dtype() api = HfApi()
torch.set_default_dtype(dtype) info = api.model_info(model_name)
try: filenames = [
yield s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
finally: ]
torch.set_default_dtype(saved_dtype) return filenames
def weight_files(model_name):
"""Get the local safetensors filenames"""
filenames = weight_hub_files(model_name)
files = []
for filename in filenames:
cache_file = try_to_load_from_cache(model_name, filename=filename)
if cache_file is None:
raise LocalEntryNotFoundError(
f"File {filename} of model {model_name} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `bloom-inference-server download-weights {model_name}` first."
)
files.append(cache_file)
return files
def download_weights(model_name):
"""Download the safetensors files from the hub"""
filenames = weight_hub_files(model_name)
download_function = partial(
hf_hub_download, repo_id=model_name, local_files_only=False
)
# FIXME: fix the overlapping progress bars
files = Parallel(n_jobs=5)(
delayed(download_function)(filename=filename) for filename in tqdm(filenames)
)
return files

370
server/poetry.lock generated
View File

@ -43,7 +43,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]] [[package]]
name = "grpcio" name = "grpcio"
version = "1.49.1" version = "1.50.0"
description = "HTTP/2-based RPC framework" description = "HTTP/2-based RPC framework"
category = "main" category = "main"
optional = false optional = false
@ -53,31 +53,31 @@ python-versions = ">=3.7"
six = ">=1.5.2" six = ">=1.5.2"
[package.extras] [package.extras]
protobuf = ["grpcio-tools (>=1.49.1)"] protobuf = ["grpcio-tools (>=1.50.0)"]
[[package]] [[package]]
name = "grpcio-reflection" name = "grpcio-reflection"
version = "1.49.1" version = "1.50.0"
description = "Standard Protobuf Reflection Service for gRPC" description = "Standard Protobuf Reflection Service for gRPC"
category = "main" category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[package.dependencies] [package.dependencies]
grpcio = ">=1.49.1" grpcio = ">=1.50.0"
protobuf = ">=4.21.3" protobuf = ">=4.21.6"
[[package]] [[package]]
name = "grpcio-tools" name = "grpcio-tools"
version = "1.49.1" version = "1.50.0"
description = "Protobuf code generator for gRPC" description = "Protobuf code generator for gRPC"
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
[package.dependencies] [package.dependencies]
grpcio = ">=1.49.1" grpcio = ">=1.50.0"
protobuf = ">=4.21.3,<5.0dev" protobuf = ">=4.21.6,<5.0dev"
setuptools = "*" setuptools = "*"
[[package]] [[package]]
@ -90,7 +90,7 @@ python-versions = ">=3.7"
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "1.23.3" version = "1.23.4"
description = "NumPy is the fundamental package for array computing with Python." description = "NumPy is the fundamental package for array computing with Python."
category = "main" category = "main"
optional = false optional = false
@ -109,7 +109,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "4.21.7" version = "4.21.8"
description = "" description = ""
category = "main" category = "main"
optional = false optional = false
@ -117,7 +117,7 @@ python-versions = ">=3.7"
[[package]] [[package]]
name = "psutil" name = "psutil"
version = "5.9.2" version = "5.9.3"
description = "Cross-platform lib for process and system monitoring in Python." description = "Cross-platform lib for process and system monitoring in Python."
category = "main" category = "main"
optional = false optional = false
@ -147,7 +147,7 @@ python-versions = ">=3.6"
[[package]] [[package]]
name = "setuptools" name = "setuptools"
version = "65.4.1" version = "65.5.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages" description = "Easily download, build, install, upgrade, and uninstall Python packages"
category = "dev" category = "dev"
optional = false optional = false
@ -196,7 +196,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.3.0" version = "4.4.0"
description = "Backported and Experimental Type Hints for Python 3.7+" description = "Backported and Experimental Type Hints for Python 3.7+"
category = "main" category = "main"
optional = false optional = false
@ -221,190 +221,194 @@ colorama = [
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
] ]
grpcio = [ grpcio = [
{file = "grpcio-1.49.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:fd86040232e805b8e6378b2348c928490ee595b058ce9aaa27ed8e4b0f172b20"}, {file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"},
{file = "grpcio-1.49.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6fd0c9cede9552bf00f8c5791d257d5bf3790d7057b26c59df08be5e7a1e021d"}, {file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"},
{file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:d0d402e158d4e84e49c158cb5204119d55e1baf363ee98d6cb5dce321c3a065d"}, {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:4b123fbb7a777a2fedec684ca0b723d85e1d2379b6032a9a9b7851829ed3ca9a"},
{file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:822ceec743d42a627e64ea266059a62d214c5a3cdfcd0d7fe2b7a8e4e82527c7"}, {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2f77a90ba7b85bfb31329f8eab9d9540da2cf8a302128fb1241d7ea239a5469"},
{file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2106d9c16527f0a85e2eea6e6b91a74fc99579c60dd810d8690843ea02bc0f5f"}, {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eea18a878cffc804506d39c6682d71f6b42ec1c151d21865a95fae743fda500"},
{file = "grpcio-1.49.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:52dd02b7e7868233c571b49bc38ebd347c3bb1ff8907bb0cb74cb5f00c790afc"}, {file = "grpcio-1.50.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b71916fa8f9eb2abd93151fafe12e18cebb302686b924bd4ec39266211da525"},
{file = "grpcio-1.49.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:120fecba2ec5d14b5a15d11063b39783fda8dc8d24addd83196acb6582cabd9b"}, {file = "grpcio-1.50.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:95ce51f7a09491fb3da8cf3935005bff19983b77c4e9437ef77235d787b06842"},
{file = "grpcio-1.49.1-cp310-cp310-win32.whl", hash = "sha256:f1a3b88e3c53c1a6e6bed635ec1bbb92201bb6a1f2db186179f7f3f244829788"}, {file = "grpcio-1.50.0-cp310-cp310-win32.whl", hash = "sha256:f7025930039a011ed7d7e7ef95a1cb5f516e23c5a6ecc7947259b67bea8e06ca"},
{file = "grpcio-1.49.1-cp310-cp310-win_amd64.whl", hash = "sha256:a7d0017b92d3850abea87c1bdec6ea41104e71c77bca44c3e17f175c6700af62"}, {file = "grpcio-1.50.0-cp310-cp310-win_amd64.whl", hash = "sha256:05f7c248e440f538aaad13eee78ef35f0541e73498dd6f832fe284542ac4b298"},
{file = "grpcio-1.49.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:9fb17ff8c0d56099ac6ebfa84f670c5a62228d6b5c695cf21c02160c2ac1446b"}, {file = "grpcio-1.50.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:ca8a2254ab88482936ce941485c1c20cdeaef0efa71a61dbad171ab6758ec998"},
{file = "grpcio-1.49.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:075f2d06e3db6b48a2157a1bcd52d6cbdca980dd18988fe6afdb41795d51625f"}, {file = "grpcio-1.50.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:3b611b3de3dfd2c47549ca01abfa9bbb95937eb0ea546ea1d762a335739887be"},
{file = "grpcio-1.49.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46d93a1b4572b461a227f1db6b8d35a88952db1c47e5fadcf8b8a2f0e1dd9201"}, {file = "grpcio-1.50.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a4cd8cb09d1bc70b3ea37802be484c5ae5a576108bad14728f2516279165dd7"},
{file = "grpcio-1.49.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc79b2b37d779ac42341ddef40ad5bf0966a64af412c89fc2b062e3ddabb093f"}, {file = "grpcio-1.50.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:156f8009e36780fab48c979c5605eda646065d4695deea4cfcbcfdd06627ddb6"},
{file = "grpcio-1.49.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5f8b3a971c7820ea9878f3fd70086240a36aeee15d1b7e9ecbc2743b0e785568"}, {file = "grpcio-1.50.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de411d2b030134b642c092e986d21aefb9d26a28bf5a18c47dd08ded411a3bc5"},
{file = "grpcio-1.49.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49b301740cf5bc8fed4fee4c877570189ae3951432d79fa8e524b09353659811"}, {file = "grpcio-1.50.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d144ad10eeca4c1d1ce930faa105899f86f5d99cecfe0d7224f3c4c76265c15e"},
{file = "grpcio-1.49.1-cp311-cp311-win32.whl", hash = "sha256:1c66a25afc6c71d357867b341da594a5587db5849b48f4b7d5908d236bb62ede"}, {file = "grpcio-1.50.0-cp311-cp311-win32.whl", hash = "sha256:92d7635d1059d40d2ec29c8bf5ec58900120b3ce5150ef7414119430a4b2dd5c"},
{file = "grpcio-1.49.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b6c3a95d27846f4145d6967899b3ab25fffc6ae99544415e1adcacef84842d2"}, {file = "grpcio-1.50.0-cp311-cp311-win_amd64.whl", hash = "sha256:ce8513aee0af9c159319692bfbf488b718d1793d764798c3d5cff827a09e25ef"},
{file = "grpcio-1.49.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:1cc400c8a2173d1c042997d98a9563e12d9bb3fb6ad36b7f355bc77c7663b8af"}, {file = "grpcio-1.50.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:8e8999a097ad89b30d584c034929f7c0be280cd7851ac23e9067111167dcbf55"},
{file = "grpcio-1.49.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:34f736bd4d0deae90015c0e383885b431444fe6b6c591dea288173df20603146"}, {file = "grpcio-1.50.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:a50a1be449b9e238b9bd43d3857d40edf65df9416dea988929891d92a9f8a778"},
{file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:196082b9c89ebf0961dcd77cb114bed8171964c8e3063b9da2fb33536a6938ed"}, {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:cf151f97f5f381163912e8952eb5b3afe89dec9ed723d1561d59cabf1e219a35"},
{file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c9f89c42749890618cd3c2464e1fbf88446e3d2f67f1e334c8e5db2f3272bbd"}, {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a23d47f2fc7111869f0ff547f771733661ff2818562b04b9ed674fa208e261f4"},
{file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64419cb8a5b612cdb1550c2fd4acbb7d4fb263556cf4625f25522337e461509e"}, {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84d04dec64cc4ed726d07c5d17b73c343c8ddcd6b59c7199c801d6bbb9d9ed1"},
{file = "grpcio-1.49.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8a5272061826e6164f96e3255405ef6f73b88fd3e8bef464c7d061af8585ac62"}, {file = "grpcio-1.50.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:67dd41a31f6fc5c7db097a5c14a3fa588af54736ffc174af4411d34c4f306f68"},
{file = "grpcio-1.49.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ea9d0172445241ad7cb49577314e39d0af2c5267395b3561d7ced5d70458a9f3"}, {file = "grpcio-1.50.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d4c8e73bf20fb53fe5a7318e768b9734cf122fe671fcce75654b98ba12dfb75"},
{file = "grpcio-1.49.1-cp37-cp37m-win32.whl", hash = "sha256:2070e87d95991473244c72d96d13596c751cb35558e11f5df5414981e7ed2492"}, {file = "grpcio-1.50.0-cp37-cp37m-win32.whl", hash = "sha256:7489dbb901f4fdf7aec8d3753eadd40839c9085967737606d2c35b43074eea24"},
{file = "grpcio-1.49.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fcedcab49baaa9db4a2d240ac81f2d57eb0052b1c6a9501b46b8ae912720fbf"}, {file = "grpcio-1.50.0-cp37-cp37m-win_amd64.whl", hash = "sha256:531f8b46f3d3db91d9ef285191825d108090856b3bc86a75b7c3930f16ce432f"},
{file = "grpcio-1.49.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:afbb3475cf7f4f7d380c2ca37ee826e51974f3e2665613996a91d6a58583a534"}, {file = "grpcio-1.50.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:d534d169673dd5e6e12fb57cc67664c2641361e1a0885545495e65a7b761b0f4"},
{file = "grpcio-1.49.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a4f9ba141380abde6c3adc1727f21529137a2552002243fa87c41a07e528245c"}, {file = "grpcio-1.50.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:1d8d02dbb616c0a9260ce587eb751c9c7dc689bc39efa6a88cc4fa3e9c138a7b"},
{file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:cf0a1fb18a7204b9c44623dfbd1465b363236ce70c7a4ed30402f9f60d8b743b"}, {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:baab51dcc4f2aecabf4ed1e2f57bceab240987c8b03533f1cef90890e6502067"},
{file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17bb6fe72784b630728c6cff9c9d10ccc3b6d04e85da6e0a7b27fb1d135fac62"}, {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40838061e24f960b853d7bce85086c8e1b81c6342b1f4c47ff0edd44bbae2722"},
{file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18305d5a082d1593b005a895c10041f833b16788e88b02bb81061f5ebcc465df"}, {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:931e746d0f75b2a5cff0a1197d21827a3a2f400c06bace036762110f19d3d507"},
{file = "grpcio-1.49.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b6a1b39e59ac5a3067794a0e498911cf2e37e4b19ee9e9977dc5e7051714f13f"}, {file = "grpcio-1.50.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:15f9e6d7f564e8f0776770e6ef32dac172c6f9960c478616c366862933fa08b4"},
{file = "grpcio-1.49.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0e20d59aafc086b1cc68400463bddda6e41d3e5ed30851d1e2e0f6a2e7e342d3"}, {file = "grpcio-1.50.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4c23e54f58e016761b576976da6a34d876420b993f45f66a2bfb00363ecc1f9"},
{file = "grpcio-1.49.1-cp38-cp38-win32.whl", hash = "sha256:e1e83233d4680863a421f3ee4a7a9b80d33cd27ee9ed7593bc93f6128302d3f2"}, {file = "grpcio-1.50.0-cp38-cp38-win32.whl", hash = "sha256:3e4244c09cc1b65c286d709658c061f12c61c814be0b7030a2d9966ff02611e0"},
{file = "grpcio-1.49.1-cp38-cp38-win_amd64.whl", hash = "sha256:221d42c654d2a41fa31323216279c73ed17d92f533bc140a3390cc1bd78bf63c"}, {file = "grpcio-1.50.0-cp38-cp38-win_amd64.whl", hash = "sha256:8e69aa4e9b7f065f01d3fdcecbe0397895a772d99954bb82eefbb1682d274518"},
{file = "grpcio-1.49.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:fa9e6e61391e99708ac87fc3436f6b7b9c6b845dc4639b406e5e61901e1aacde"}, {file = "grpcio-1.50.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:af98d49e56605a2912cf330b4627e5286243242706c3a9fa0bcec6e6f68646fc"},
{file = "grpcio-1.49.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:9b449e966ef518ce9c860d21f8afe0b0f055220d95bc710301752ac1db96dd6a"}, {file = "grpcio-1.50.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:080b66253f29e1646ac53ef288c12944b131a2829488ac3bac8f52abb4413c0d"},
{file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:aa34d2ad9f24e47fa9a3172801c676e4037d862247e39030165fe83821a7aafd"}, {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:ab5d0e3590f0a16cb88de4a3fa78d10eb66a84ca80901eb2c17c1d2c308c230f"},
{file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5207f4eed1b775d264fcfe379d8541e1c43b878f2b63c0698f8f5c56c40f3d68"}, {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb11464f480e6103c59d558a3875bd84eed6723f0921290325ebe97262ae1347"},
{file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b24a74651438d45619ac67004638856f76cc13d78b7478f2457754cbcb1c8ad"}, {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e07fe0d7ae395897981d16be61f0db9791f482f03fee7d1851fe20ddb4f69c03"},
{file = "grpcio-1.49.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fe763781669790dc8b9618e7e677c839c87eae6cf28b655ee1fa69ae04eea03f"}, {file = "grpcio-1.50.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d75061367a69808ab2e84c960e9dce54749bcc1e44ad3f85deee3a6c75b4ede9"},
{file = "grpcio-1.49.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2f2ff7ba0f8f431f32d4b4bc3a3713426949d3533b08466c4ff1b2b475932ca8"}, {file = "grpcio-1.50.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ae23daa7eda93c1c49a9ecc316e027ceb99adbad750fbd3a56fa9e4a2ffd5ae0"},
{file = "grpcio-1.49.1-cp39-cp39-win32.whl", hash = "sha256:08ff74aec8ff457a89b97152d36cb811dcc1d17cd5a92a65933524e363327394"}, {file = "grpcio-1.50.0-cp39-cp39-win32.whl", hash = "sha256:177afaa7dba3ab5bfc211a71b90da1b887d441df33732e94e26860b3321434d9"},
{file = "grpcio-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:274ffbb39717918c514b35176510ae9be06e1d93121e84d50b350861dcb9a705"}, {file = "grpcio-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:ea8ccf95e4c7e20419b7827aa5b6da6f02720270686ac63bd3493a651830235c"},
{file = "grpcio-1.49.1.tar.gz", hash = "sha256:d4725fc9ec8e8822906ae26bb26f5546891aa7fbc3443de970cc556d43a5c99f"}, {file = "grpcio-1.50.0.tar.gz", hash = "sha256:12b479839a5e753580b5e6053571de14006157f2ef9b71f38c56dc9b23b95ad6"},
] ]
grpcio-reflection = [ grpcio-reflection = [
{file = "grpcio-reflection-1.49.1.tar.gz", hash = "sha256:b755dfe61d5255a02fb8d0d845bd0027847dee68bf0763a2b286d664ed07ec4d"}, {file = "grpcio-reflection-1.50.0.tar.gz", hash = "sha256:d75ff6a12f70f28ba6465d5a29781cf2d9c740eb809711990f61699019af1f72"},
{file = "grpcio_reflection-1.49.1-py3-none-any.whl", hash = "sha256:70a325a83c1c1ab583d368711e5733cbef5e068ad2c17cbe77df6e47e0311d1f"}, {file = "grpcio_reflection-1.50.0-py3-none-any.whl", hash = "sha256:92fea4a9527a8bcc52b615b6a7af655a5a118ab0f57227cc814dc2192f8a7455"},
] ]
grpcio-tools = [ grpcio-tools = [
{file = "grpcio-tools-1.49.1.tar.gz", hash = "sha256:84cc64e5b46bad43d5d7bd2fd772b656eba0366961187a847e908e2cb735db91"}, {file = "grpcio-tools-1.50.0.tar.gz", hash = "sha256:88b75f2afd889c7c6939f58d76b58ab84de4723c7de882a1f8448af6632e256f"},
{file = "grpcio_tools-1.49.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:2dfb6c7ece84d46bd690b23d3e060d18115c8bc5047d2e8a33e6747ed323a348"}, {file = "grpcio_tools-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:fad4f44220847691605e4079ec67b6663064d5b1283bee5a89dd6f7726672d08"},
{file = "grpcio_tools-1.49.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:8f452a107c054a04db2570f7851a07f060313c6e841b0d394ce6030d598290e6"}, {file = "grpcio_tools-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4fe1cd201f0e0f601551c20bea2a8c1e109767ce341ac263e88d5a0acd0a124c"},
{file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:6a198871b582287213c4d70792bf275e1d7cf34eed1d019f534ddf4cd15ab039"}, {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:66113bc60db8ccb15e2d7b79efd5757bba33b15e653f20f473c776857bd4415d"},
{file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0cca67a7d0287bdc855d81fdd38dc949c4273273a74f832f9e520abe4f20bc6"}, {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71e98662f76d3bcf031f36f727ffee594c57006d5dce7b6d031d030d194f054d"},
{file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdaff4c89eecb37c247b93025410db68114d97fa093cbb028e9bd7cda5912473"}, {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1891aca6042fc403c1d492d03ce1c4bac0053c26132063b7b0171665a0aa6086"},
{file = "grpcio_tools-1.49.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bb8773118ad315db317d7b22b5ff75d649ca20931733281209e7cbd8c0fad53e"}, {file = "grpcio_tools-1.50.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:54528cec27238761055b28a34ea29392c8e90ac8d8966ee3ccfc9f906eb26fd3"},
{file = "grpcio_tools-1.49.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7cc5534023735b8a8f56760b7c533918f874ce5a9064d7c5456d2709ae2b31f9"}, {file = "grpcio_tools-1.50.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89010e8e0d158966af478d9b345d29d98e52d62464cb93ef192125e94fb904c1"},
{file = "grpcio_tools-1.49.1-cp310-cp310-win32.whl", hash = "sha256:d277642acbe305f5586f9597b78fb9970d6633eb9f89c61e429c92c296c37129"}, {file = "grpcio_tools-1.50.0-cp310-cp310-win32.whl", hash = "sha256:e42ac5aa09debbf7b72444af7f390713e85b8808557e920d9d8071f14127c0eb"},
{file = "grpcio_tools-1.49.1-cp310-cp310-win_amd64.whl", hash = "sha256:eed599cf08fc1a06c72492d3c5750c32f58de3750eddd984af1f257c14326701"}, {file = "grpcio_tools-1.50.0-cp310-cp310-win_amd64.whl", hash = "sha256:c876dbaa6897e45ed074df1562006ac9fdeb5c7235f623cda2c29894fe4f8fba"},
{file = "grpcio_tools-1.49.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:9e5c13809ab2f245398e8446c4c3b399a62d591db651e46806cccf52a700452e"}, {file = "grpcio_tools-1.50.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:3995d96635d6decaf9d5d2e071677782b08ea3186d7cb8ad93f6c2dd97ec9097"},
{file = "grpcio_tools-1.49.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:ab3d0ee9623720ee585fdf3753b3755d3144a4a8ae35bca8e3655fa2f41056be"}, {file = "grpcio_tools-1.50.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1736defd9fc54f942f500dd71555816487e52f07d25b0a58ff4089c375f0b5c3"},
{file = "grpcio_tools-1.49.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ba87e3512bc91d78bf9febcfb522eadda171d2d4ddaf886066b0f01aa4929ad"}, {file = "grpcio_tools-1.50.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fecbc5e773c1c1971641227d8fcd5de97533dbbc0bd7854b436ceb57243f4b2a"},
{file = "grpcio_tools-1.49.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e13b3643e7577a3ec13b79689eb4d7548890b1e104c04b9ed6557a3c3dd452"}, {file = "grpcio_tools-1.50.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8178a475de70e6ae0a7e3f6d7ec1415c9dd10680654dbea9a1caf79665b0ed4"},
{file = "grpcio_tools-1.49.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:324f67d9cb4b7058b6ce45352fb64c20cc1fa04c34d97ad44772cfe6a4ae0cf5"}, {file = "grpcio_tools-1.50.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d9b8b59a76f67270ce8a4fd7af005732356d5b4ebaffa4598343ea63677bad63"},
{file = "grpcio_tools-1.49.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a64bab81b220c50033f584f57978ebbea575f09c1ccee765cd5c462177988098"}, {file = "grpcio_tools-1.50.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ead51910d83123648b1e36e2de8483b66e9ba358ca1d8965005ba7713cbcb7db"},
{file = "grpcio_tools-1.49.1-cp311-cp311-win32.whl", hash = "sha256:f632d376f92f23e5931697a3acf1b38df7eb719774213d93c52e02acd2d529ac"}, {file = "grpcio_tools-1.50.0-cp311-cp311-win32.whl", hash = "sha256:4c928610ceda7890cf7ddd3cb4cc16fb81995063bf4fdb5171bba572d508079b"},
{file = "grpcio_tools-1.49.1-cp311-cp311-win_amd64.whl", hash = "sha256:28ff2b978d9509474928b9c096a0cce4eaa9c8f7046136aee1545f6211ed8126"}, {file = "grpcio_tools-1.50.0-cp311-cp311-win_amd64.whl", hash = "sha256:56c28f50b9c88cbfc62861de86a1f8b2454d3b48cff85a73d523ec22f130635d"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:46afd3cb7e555187451a5d283f108cdef397952a662cb48680afc615b158864a"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:675634367846fc906e8b6016830b3d361f9e788bce7820b4e6432d914cb0e25d"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:9284568b728e41fa8f7e9c2e7399545d605f75d8072ef0e9aa2a05655cb679eb"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:f6dbd9d86aa2fbeb8a104ae8edd23c55f182829db58cae3f28842024e08b1071"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:aa34442cf787732cb41f2aa6172007e24f480b8b9d3dc5166de80d63e9072ea4"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:2869fba92670d5be3730f1ad108cbacd3f9a7b7afa57187a3c39379513b3fde7"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8c9eb5a4250905414cd53a68caea3eb8f0c515aadb689e6e81b71ebe9ab5c6"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f211a8a6f159b496097cabbf4bdbb6d01db0fd7a8d516dd0e577ee2ddccded0"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab15db024051bf21feb21c29cb2c3ea0a2e4f5cf341d46ef76e17fcf6aaef164"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27f734e44631d755481a179ffd5f2659eaf1e9755cf3539e291d7a28d8373cc8"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:502084b622f758bef620a9107c2db9fcdf66d26c7e0e481d6bb87db4dc917d70"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9bc97b52d577406c240b0d6acf68ab5100c67c903792f1f9ce04d871ca83de53"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4085890b77c640085f82bf1e90a0ea166ce48000bc2f5180914b974783c9c0a8"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a9816a6c8d1b7adb65217d01624006704faf02b5daa15b6e145b63a5e559d791"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-win32.whl", hash = "sha256:da0edb984699769ce02e18e3392d54b59a7a3f93acd285a68043f5bde4fc028e"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-win32.whl", hash = "sha256:c5dde47bae615554349dd09185eaaeffd9a28e961de347cf0b3b98a853b43a6c"},
{file = "grpcio_tools-1.49.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9887cd622770271101a7dd1832845d64744c3f88fd11ccb2620394079197a42e"}, {file = "grpcio_tools-1.50.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e51133b78b67bd4c06e99bea9c631f29acd52af2b2a254926d192f277f9067b5"},
{file = "grpcio_tools-1.49.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:8440fe7dae6a40c279e3a24b82793735babd38ecbb0d07bb712ff9c8963185d9"}, {file = "grpcio_tools-1.50.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fec4336e067f2bdd45d74c2466dcd06f245e972d2139187d87cfccf847dbaf50"},
{file = "grpcio_tools-1.49.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:b5de2bb7dd6b6231da9b1556ade981513330b740e767f1d902c71ceee0a7d196"}, {file = "grpcio_tools-1.50.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d74672162474e2592c50708408b4ee145994451ee5171c9b8f580504b987ad3"},
{file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:1e6f06a763aea7836b63d9c117347f2bf7038008ceef72758815c9e09c5fb1fc"}, {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e6baae6fa9a88c1895193f9610b54d483753ad3e3f91f5d5a10a506afe342a43"},
{file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e31562f90120318c5395aabec0f2f69ad8c14b6676996b7730d9d2eaf9415d57"}, {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:927c8cbbfbf3bca46cd05bbabfd8cb6c99f5dbaeb860f2786f9286ce6894ad3d"},
{file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49ef9a4e389a618157a9daa9fafdfeeaef1ece9adda7f50f85db928f24d4b3e8"}, {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fe40a66d3ccde959f7290b881c78e8c420fd54289fede5c9119d993703e1a49"},
{file = "grpcio_tools-1.49.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b384cb8e8d9bcb55ee8f9b064374561c7a1a05d848249581403d36fc7060032f"}, {file = "grpcio_tools-1.50.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:86a9e9b2cc30797472bd421c75b42fe1062519ef5675243cb33d99de33af279c"},
{file = "grpcio_tools-1.49.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:73732f77943ac3e898879cbb29c27253aa3c47566b8a59780fd24c6a54de1b66"}, {file = "grpcio_tools-1.50.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c1646bfec6d707bcaafea69088770151ba291ede788f2cda5efdc09c1eac5c1c"},
{file = "grpcio_tools-1.49.1-cp38-cp38-win32.whl", hash = "sha256:b594b2745a5ba9e7a76ce561bc5ab40bc65bb44743c505529b1e4f12af29104d"}, {file = "grpcio_tools-1.50.0-cp38-cp38-win32.whl", hash = "sha256:dedfe2947ad4ccc8cb2f7d0cac5766dd623901f29e1953a2854126559c317885"},
{file = "grpcio_tools-1.49.1-cp38-cp38-win_amd64.whl", hash = "sha256:680fbc88f8709ddcabb88f86749f2d8e429160890cff2c70680880a6970d4eef"}, {file = "grpcio_tools-1.50.0-cp38-cp38-win_amd64.whl", hash = "sha256:73ba21fe500e2476d391cc92aa05545fdc71bdeedb5b9ec2b22e9a3ccc298a2f"},
{file = "grpcio_tools-1.49.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:e8c3869121860f6767eedb7d24fc54dfd71e737fdfbb26e1334684606f3274fd"}, {file = "grpcio_tools-1.50.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:23381ac6a19de0d3cb470f4d1fb450267014234a3923e3c43abb858049f5caf4"},
{file = "grpcio_tools-1.49.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:73e9d7c886ba10e20c97d1dab0ff961ba5800757ae5e31be21b1cda8130c52f8"}, {file = "grpcio_tools-1.50.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:c61730ae21533ef992a89df0321ebe59e344f9f0d1d54572b92d9468969c75b3"},
{file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:1760de2dd2c4f08de87b039043a4797f3c17193656e7e3eb84e92f0517083c0c"}, {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e5adcbec235dce997efacb44cca5cde24db47ad9c432fcbca875185d2095c55a"},
{file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd4b1e216dd04d9245ee8f4e601a1f98c25e6e417ea5cf8d825c50589a8b447e"}, {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c40f918d1fe0cb094bfb09ab7567a92f5368fa0b9935235680fe205ddc37518"},
{file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1c28751ab5955cae563d07677e799233f0fe1c0fc49d9cbd61ff1957e83617f"}, {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2313c8c73a5003d789fe05f4d01112ccf8330d3d141f5b77b514200d742193f6"},
{file = "grpcio_tools-1.49.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c24239c3ee9ed16314c14b4e24437b5079ebc344f343f33629a582f8699f583b"}, {file = "grpcio_tools-1.50.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:578f9137ebca0f152136b47b4895a6b4dabcf17e41f957b03637d1ae096ff67b"},
{file = "grpcio_tools-1.49.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:892d3dacf1942820f0b7a868a30e6fbcdf5bec08543b682c7274b0101cee632d"}, {file = "grpcio_tools-1.50.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d8aaff094d94d2119fffe7e0d08ca59d42076971f9d85857e99288e500ab31f0"},
{file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"}, {file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
{file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"}, {file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
] ]
joblib = [ joblib = [
{file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"},
{file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"},
] ]
numpy = [ numpy = [
{file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"}, {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
{file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"}, {file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
{file = "numpy-1.23.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ea3f98a0ffce3f8f57675eb9119f3f4edb81888b6874bc1953f91e0b1d4f440"}, {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c237129f0e732885c9a6076a537e974160482eab8f10db6292e92154d4c67d71"},
{file = "numpy-1.23.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004f0efcb2fe1c0bd6ae1fcfc69cc8b6bf2407e0f18be308612007a0762b4089"}, {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8365b942f9c1a7d0f0dc974747d99dd0a0cdfc5949a33119caf05cb314682d3"},
{file = "numpy-1.23.3-cp310-cp310-win32.whl", hash = "sha256:98dcbc02e39b1658dc4b4508442a560fe3ca5ca0d989f0df062534e5ca3a5c1a"}, {file = "numpy-1.23.4-cp310-cp310-win32.whl", hash = "sha256:2341f4ab6dba0834b685cce16dad5f9b6606ea8a00e6da154f5dbded70fdc4dd"},
{file = "numpy-1.23.3-cp310-cp310-win_amd64.whl", hash = "sha256:39a664e3d26ea854211867d20ebcc8023257c1800ae89773cbba9f9e97bae036"}, {file = "numpy-1.23.4-cp310-cp310-win_amd64.whl", hash = "sha256:d331afac87c92373826af83d2b2b435f57b17a5c74e6268b79355b970626e329"},
{file = "numpy-1.23.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1f27b5322ac4067e67c8f9378b41c746d8feac8bdd0e0ffede5324667b8a075c"}, {file = "numpy-1.23.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:488a66cb667359534bc70028d653ba1cf307bae88eab5929cd707c761ff037db"},
{file = "numpy-1.23.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ad3ec9a748a8943e6eb4358201f7e1c12ede35f510b1a2221b70af4bb64295c"}, {file = "numpy-1.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce03305dd694c4873b9429274fd41fc7eb4e0e4dea07e0af97a933b079a5814f"},
{file = "numpy-1.23.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdc9febce3e68b697d931941b263c59e0c74e8f18861f4064c1f712562903411"}, {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8981d9b5619569899666170c7c9748920f4a5005bf79c72c07d08c8a035757b0"},
{file = "numpy-1.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301c00cf5e60e08e04d842fc47df641d4a181e651c7135c50dc2762ffe293dbd"}, {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a70a7d3ce4c0e9284e92285cba91a4a3f5214d87ee0e95928f3614a256a1488"},
{file = "numpy-1.23.3-cp311-cp311-win32.whl", hash = "sha256:7cd1328e5bdf0dee621912f5833648e2daca72e3839ec1d6695e91089625f0b4"}, {file = "numpy-1.23.4-cp311-cp311-win32.whl", hash = "sha256:5e13030f8793e9ee42f9c7d5777465a560eb78fa7e11b1c053427f2ccab90c79"},
{file = "numpy-1.23.3-cp311-cp311-win_amd64.whl", hash = "sha256:8355fc10fd33a5a70981a5b8a0de51d10af3688d7a9e4a34fcc8fa0d7467bb7f"}, {file = "numpy-1.23.4-cp311-cp311-win_amd64.whl", hash = "sha256:7607b598217745cc40f751da38ffd03512d33ec06f3523fb0b5f82e09f6f676d"},
{file = "numpy-1.23.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc6e8da415f359b578b00bcfb1d08411c96e9a97f9e6c7adada554a0812a6cc6"}, {file = "numpy-1.23.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7ab46e4e7ec63c8a5e6dbf5c1b9e1c92ba23a7ebecc86c336cb7bf3bd2fb10e5"},
{file = "numpy-1.23.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:22d43376ee0acd547f3149b9ec12eec2f0ca4a6ab2f61753c5b29bb3e795ac4d"}, {file = "numpy-1.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8aae2fb3180940011b4862b2dd3756616841c53db9734b27bb93813cd79fce6"},
{file = "numpy-1.23.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a64403f634e5ffdcd85e0b12c08f04b3080d3e840aef118721021f9b48fc1460"}, {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c053d7557a8f022ec823196d242464b6955a7e7e5015b719e76003f63f82d0f"},
{file = "numpy-1.23.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efd9d3abe5774404becdb0748178b48a218f1d8c44e0375475732211ea47c67e"}, {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0882323e0ca4245eb0a3d0a74f88ce581cc33aedcfa396e415e5bba7bf05f68"},
{file = "numpy-1.23.3-cp38-cp38-win32.whl", hash = "sha256:f8c02ec3c4c4fcb718fdf89a6c6f709b14949408e8cf2a2be5bfa9c49548fd85"}, {file = "numpy-1.23.4-cp38-cp38-win32.whl", hash = "sha256:dada341ebb79619fe00a291185bba370c9803b1e1d7051610e01ed809ef3a4ba"},
{file = "numpy-1.23.3-cp38-cp38-win_amd64.whl", hash = "sha256:e868b0389c5ccfc092031a861d4e158ea164d8b7fdbb10e3b5689b4fc6498df6"}, {file = "numpy-1.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:0fe563fc8ed9dc4474cbf70742673fc4391d70f4363f917599a7fa99f042d5a8"},
{file = "numpy-1.23.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:09f6b7bdffe57fc61d869a22f506049825d707b288039d30f26a0d0d8ea05164"}, {file = "numpy-1.23.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c67b833dbccefe97cdd3f52798d430b9d3430396af7cdb2a0c32954c3ef73894"},
{file = "numpy-1.23.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8c79d7cf86d049d0c5089231a5bcd31edb03555bd93d81a16870aa98c6cfb79d"}, {file = "numpy-1.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f76025acc8e2114bb664294a07ede0727aa75d63a06d2fae96bf29a81747e4a7"},
{file = "numpy-1.23.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5d5420053bbb3dd64c30e58f9363d7a9c27444c3648e61460c1237f9ec3fa14"}, {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ac457b63ec8ded85d85c1e17d85efd3c2b0967ca39560b307a35a6703a4735"},
{file = "numpy-1.23.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5422d6a1ea9b15577a9432e26608c73a78faf0b9039437b075cf322c92e98e7"}, {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95de7dc7dc47a312f6feddd3da2500826defdccbc41608d0031276a24181a2c0"},
{file = "numpy-1.23.3-cp39-cp39-win32.whl", hash = "sha256:c1ba66c48b19cc9c2975c0d354f24058888cdc674bebadceb3cdc9ec403fb5d1"}, {file = "numpy-1.23.4-cp39-cp39-win32.whl", hash = "sha256:f2f390aa4da44454db40a1f0201401f9036e8d578a25f01a6e237cea238337ef"},
{file = "numpy-1.23.3-cp39-cp39-win_amd64.whl", hash = "sha256:78a63d2df1d947bd9d1b11d35564c2f9e4b57898aae4626638056ec1a231c40c"}, {file = "numpy-1.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:f260da502d7441a45695199b4e7fd8ca87db659ba1c78f2bbf31f934fe76ae0e"},
{file = "numpy-1.23.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17c0e467ade9bda685d5ac7f5fa729d8d3e76b23195471adae2d6a6941bd2c18"}, {file = "numpy-1.23.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61be02e3bf810b60ab74e81d6d0d36246dbfb644a462458bb53b595791251911"},
{file = "numpy-1.23.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91b8d6768a75247026e951dce3b2aac79dc7e78622fc148329135ba189813584"}, {file = "numpy-1.23.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:296d17aed51161dbad3c67ed6d164e51fcd18dbcd5dd4f9d0a9c6055dce30810"},
{file = "numpy-1.23.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:94c15ca4e52671a59219146ff584488907b1f9b3fc232622b47e2cf832e94fb8"}, {file = "numpy-1.23.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4d52914c88b4930dafb6c48ba5115a96cbab40f45740239d9f4159c4ba779962"},
{file = "numpy-1.23.3.tar.gz", hash = "sha256:51bf49c0cd1d52be0a240aa66f3458afc4b95d8993d2d04f0d91fa60c10af6cd"}, {file = "numpy-1.23.4.tar.gz", hash = "sha256:ed2cc92af0efad20198638c69bb0fc2870a58dabfba6eb722c933b48556c686c"},
] ]
packaging = [ packaging = [
{file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
] ]
protobuf = [ protobuf = [
{file = "protobuf-4.21.7-cp310-abi3-win32.whl", hash = "sha256:c7cb105d69a87416bd9023e64324e1c089593e6dae64d2536f06bcbe49cd97d8"}, {file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"},
{file = "protobuf-4.21.7-cp310-abi3-win_amd64.whl", hash = "sha256:3ec85328a35a16463c6f419dbce3c0fc42b3e904d966f17f48bae39597c7a543"}, {file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"},
{file = "protobuf-4.21.7-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:db9056b6a11cb5131036d734bcbf91ef3ef9235d6b681b2fc431cbfe5a7f2e56"}, {file = "protobuf-4.21.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bbececaf3cfea9ea65ebb7974e6242d310d2a7772a6f015477e0d79993af4511"},
{file = "protobuf-4.21.7-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:ca200645d6235ce0df3ccfdff1567acbab35c4db222a97357806e015f85b5744"}, {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:b02eabb9ebb1a089ed20626a90ad7a69cee6bcd62c227692466054b19c38dd1f"},
{file = "protobuf-4.21.7-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:b019c79e23a80735cc8a71b95f76a49a262f579d6b84fd20a0b82279f40e2cc1"}, {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:4761201b93e024bb70ee3a6a6425d61f3152ca851f403ba946fb0cde88872661"},
{file = "protobuf-4.21.7-cp37-cp37m-win32.whl", hash = "sha256:d3f89ccf7182293feba2de2739c8bf34fed1ed7c65a5cf987be00311acac57c1"}, {file = "protobuf-4.21.8-cp37-cp37m-win32.whl", hash = "sha256:f2d55ff22ec300c4d954d3b0d1eeb185681ec8ad4fbecff8a5aee6a1cdd345ba"},
{file = "protobuf-4.21.7-cp37-cp37m-win_amd64.whl", hash = "sha256:a74d96cd960b87b4b712797c741bb3ea3a913f5c2dc4b6cbe9c0f8360b75297d"}, {file = "protobuf-4.21.8-cp37-cp37m-win_amd64.whl", hash = "sha256:c5f94911dd8feb3cd3786fc90f7565c9aba7ce45d0f254afd625b9628f578c3f"},
{file = "protobuf-4.21.7-cp38-cp38-win32.whl", hash = "sha256:8e09d1916386eca1ef1353767b6efcebc0a6859ed7f73cb7fb974feba3184830"}, {file = "protobuf-4.21.8-cp38-cp38-win32.whl", hash = "sha256:b37b76efe84d539f16cba55ee0036a11ad91300333abd213849cbbbb284b878e"},
{file = "protobuf-4.21.7-cp38-cp38-win_amd64.whl", hash = "sha256:9e355f2a839d9930d83971b9f562395e13493f0e9211520f8913bd11efa53c02"}, {file = "protobuf-4.21.8-cp38-cp38-win_amd64.whl", hash = "sha256:2c92a7bfcf4ae76a8ac72e545e99a7407e96ffe52934d690eb29a8809ee44d7b"},
{file = "protobuf-4.21.7-cp39-cp39-win32.whl", hash = "sha256:f370c0a71712f8965023dd5b13277444d3cdfecc96b2c778b0e19acbfd60df6e"}, {file = "protobuf-4.21.8-cp39-cp39-win32.whl", hash = "sha256:89d641be4b5061823fa0e463c50a2607a97833e9f8cfb36c2f91ef5ccfcc3861"},
{file = "protobuf-4.21.7-cp39-cp39-win_amd64.whl", hash = "sha256:9643684232b6b340b5e63bb69c9b4904cdd39e4303d498d1a92abddc7e895b7f"}, {file = "protobuf-4.21.8-cp39-cp39-win_amd64.whl", hash = "sha256:bc471cf70a0f53892fdd62f8cd4215f0af8b3f132eeee002c34302dff9edd9b6"},
{file = "protobuf-4.21.7-py2.py3-none-any.whl", hash = "sha256:8066322588d4b499869bf9f665ebe448e793036b552f68c585a9b28f1e393f66"}, {file = "protobuf-4.21.8-py2.py3-none-any.whl", hash = "sha256:a55545ce9eec4030cf100fcb93e861c622d927ef94070c1a3c01922902464278"},
{file = "protobuf-4.21.7-py3-none-any.whl", hash = "sha256:58b81358ec6c0b5d50df761460ae2db58405c063fd415e1101209221a0a810e1"}, {file = "protobuf-4.21.8-py3-none-any.whl", hash = "sha256:0f236ce5016becd989bf39bd20761593e6d8298eccd2d878eda33012645dc369"},
{file = "protobuf-4.21.7.tar.gz", hash = "sha256:71d9dba03ed3432c878a801e2ea51e034b0ea01cf3a4344fb60166cb5f6c8757"}, {file = "protobuf-4.21.8.tar.gz", hash = "sha256:427426593b55ff106c84e4a88cac855175330cb6eb7e889e85aaa7b5652b686d"},
] ]
psutil = [ psutil = [
{file = "psutil-5.9.2-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:8f024fbb26c8daf5d70287bb3edfafa22283c255287cf523c5d81721e8e5d82c"}, {file = "psutil-5.9.3-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:b4a247cd3feaae39bb6085fcebf35b3b8ecd9b022db796d89c8f05067ca28e71"},
{file = "psutil-5.9.2-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:b2f248ffc346f4f4f0d747ee1947963613216b06688be0be2e393986fe20dbbb"}, {file = "psutil-5.9.3-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:5fa88e3d5d0b480602553d362c4b33a63e0c40bfea7312a7bf78799e01e0810b"},
{file = "psutil-5.9.2-cp27-cp27m-win32.whl", hash = "sha256:b1928b9bf478d31fdffdb57101d18f9b70ed4e9b0e41af751851813547b2a9ab"}, {file = "psutil-5.9.3-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:767ef4fa33acda16703725c0473a91e1832d296c37c63896c7153ba81698f1ab"},
{file = "psutil-5.9.2-cp27-cp27m-win_amd64.whl", hash = "sha256:404f4816c16a2fcc4eaa36d7eb49a66df2d083e829d3e39ee8759a411dbc9ecf"}, {file = "psutil-5.9.3-cp27-cp27m-win32.whl", hash = "sha256:9a4af6ed1094f867834f5f07acd1250605a0874169a5fcadbcec864aec2496a6"},
{file = "psutil-5.9.2-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:94e621c6a4ddb2573d4d30cba074f6d1aa0186645917df42c811c473dd22b339"}, {file = "psutil-5.9.3-cp27-cp27m-win_amd64.whl", hash = "sha256:fa5e32c7d9b60b2528108ade2929b115167fe98d59f89555574715054f50fa31"},
{file = "psutil-5.9.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:256098b4f6ffea6441eb54ab3eb64db9ecef18f6a80d7ba91549195d55420f84"}, {file = "psutil-5.9.3-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:fe79b4ad4836e3da6c4650cb85a663b3a51aef22e1a829c384e18fae87e5e727"},
{file = "psutil-5.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:614337922702e9be37a39954d67fdb9e855981624d8011a9927b8f2d3c9625d9"}, {file = "psutil-5.9.3-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:db8e62016add2235cc87fb7ea000ede9e4ca0aa1f221b40cef049d02d5d2593d"},
{file = "psutil-5.9.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39ec06dc6c934fb53df10c1672e299145ce609ff0611b569e75a88f313634969"}, {file = "psutil-5.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:941a6c2c591da455d760121b44097781bc970be40e0e43081b9139da485ad5b7"},
{file = "psutil-5.9.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3ac2c0375ef498e74b9b4ec56df3c88be43fe56cac465627572dbfb21c4be34"}, {file = "psutil-5.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:71b1206e7909792d16933a0d2c1c7f04ae196186c51ba8567abae1d041f06dcb"},
{file = "psutil-5.9.2-cp310-cp310-win32.whl", hash = "sha256:e4c4a7636ffc47b7141864f1c5e7d649f42c54e49da2dd3cceb1c5f5d29bfc85"}, {file = "psutil-5.9.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f57d63a2b5beaf797b87024d018772439f9d3103a395627b77d17a8d72009543"},
{file = "psutil-5.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:f4cb67215c10d4657e320037109939b1c1d2fd70ca3d76301992f89fe2edb1f1"}, {file = "psutil-5.9.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7507f6c7b0262d3e7b0eeda15045bf5881f4ada70473b87bc7b7c93b992a7d7"},
{file = "psutil-5.9.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dc9bda7d5ced744622f157cc8d8bdd51735dafcecff807e928ff26bdb0ff097d"}, {file = "psutil-5.9.3-cp310-cp310-win32.whl", hash = "sha256:1b540599481c73408f6b392cdffef5b01e8ff7a2ac8caae0a91b8222e88e8f1e"},
{file = "psutil-5.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75291912b945a7351d45df682f9644540d564d62115d4a20d45fa17dc2d48f8"}, {file = "psutil-5.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:547ebb02031fdada635452250ff39942db8310b5c4a8102dfe9384ee5791e650"},
{file = "psutil-5.9.2-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4018d5f9b6651f9896c7a7c2c9f4652e4eea53f10751c4e7d08a9093ab587ec"}, {file = "psutil-5.9.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d8c3cc6bb76492133474e130a12351a325336c01c96a24aae731abf5a47fe088"},
{file = "psutil-5.9.2-cp36-cp36m-win32.whl", hash = "sha256:f40ba362fefc11d6bea4403f070078d60053ed422255bd838cd86a40674364c9"}, {file = "psutil-5.9.3-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07d880053c6461c9b89cd5d4808f3b8336665fa3acdefd6777662c5ed73a851a"},
{file = "psutil-5.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:9770c1d25aee91417eba7869139d629d6328a9422ce1cdd112bd56377ca98444"}, {file = "psutil-5.9.3-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e8b50241dd3c2ed498507f87a6602825073c07f3b7e9560c58411c14fe1e1c9"},
{file = "psutil-5.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:42638876b7f5ef43cef8dcf640d3401b27a51ee3fa137cb2aa2e72e188414c32"}, {file = "psutil-5.9.3-cp36-cp36m-win32.whl", hash = "sha256:828c9dc9478b34ab96be75c81942d8df0c2bb49edbb481f597314d92b6441d89"},
{file = "psutil-5.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91aa0dac0c64688667b4285fa29354acfb3e834e1fd98b535b9986c883c2ce1d"}, {file = "psutil-5.9.3-cp36-cp36m-win_amd64.whl", hash = "sha256:ed15edb14f52925869250b1375f0ff58ca5c4fa8adefe4883cfb0737d32f5c02"},
{file = "psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fb54941aac044a61db9d8eb56fc5bee207db3bc58645d657249030e15ba3727"}, {file = "psutil-5.9.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d266cd05bd4a95ca1c2b9b5aac50d249cf7c94a542f47e0b22928ddf8b80d1ef"},
{file = "psutil-5.9.2-cp37-cp37m-win32.whl", hash = "sha256:7cbb795dcd8ed8fd238bc9e9f64ab188f3f4096d2e811b5a82da53d164b84c3f"}, {file = "psutil-5.9.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e4939ff75149b67aef77980409f156f0082fa36accc475d45c705bb00c6c16a"},
{file = "psutil-5.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:5d39e3a2d5c40efa977c9a8dd4f679763c43c6c255b1340a56489955dbca767c"}, {file = "psutil-5.9.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68fa227c32240c52982cb931801c5707a7f96dd8927f9102d6c7771ea1ff5698"},
{file = "psutil-5.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd331866628d18223a4265371fd255774affd86244fc307ef66eaf00de0633d5"}, {file = "psutil-5.9.3-cp37-cp37m-win32.whl", hash = "sha256:beb57d8a1ca0ae0eb3d08ccaceb77e1a6d93606f0e1754f0d60a6ebd5c288837"},
{file = "psutil-5.9.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b315febaebae813326296872fdb4be92ad3ce10d1d742a6b0c49fb619481ed0b"}, {file = "psutil-5.9.3-cp37-cp37m-win_amd64.whl", hash = "sha256:12500d761ac091f2426567f19f95fd3f15a197d96befb44a5c1e3cbe6db5752c"},
{file = "psutil-5.9.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7929a516125f62399d6e8e026129c8835f6c5a3aab88c3fff1a05ee8feb840d"}, {file = "psutil-5.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba38cf9984d5462b506e239cf4bc24e84ead4b1d71a3be35e66dad0d13ded7c1"},
{file = "psutil-5.9.2-cp38-cp38-win32.whl", hash = "sha256:561dec454853846d1dd0247b44c2e66a0a0c490f937086930ec4b8f83bf44f06"}, {file = "psutil-5.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:46907fa62acaac364fff0b8a9da7b360265d217e4fdeaca0a2397a6883dffba2"},
{file = "psutil-5.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:67b33f27fc0427483b61563a16c90d9f3b547eeb7af0ef1b9fe024cdc9b3a6ea"}, {file = "psutil-5.9.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a04a1836894c8279e5e0a0127c0db8e198ca133d28be8a2a72b4db16f6cf99c1"},
{file = "psutil-5.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b3591616fa07b15050b2f87e1cdefd06a554382e72866fcc0ab2be9d116486c8"}, {file = "psutil-5.9.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a4e07611997acf178ad13b842377e3d8e9d0a5bac43ece9bfc22a96735d9a4f"},
{file = "psutil-5.9.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b29f581b5edab1f133563272a6011925401804d52d603c5c606936b49c8b97"}, {file = "psutil-5.9.3-cp38-cp38-win32.whl", hash = "sha256:6ced1ad823ecfa7d3ce26fe8aa4996e2e53fb49b7fed8ad81c80958501ec0619"},
{file = "psutil-5.9.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4642fd93785a29353d6917a23e2ac6177308ef5e8be5cc17008d885cb9f70f12"}, {file = "psutil-5.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:35feafe232d1aaf35d51bd42790cbccb882456f9f18cdc411532902370d660df"},
{file = "psutil-5.9.2-cp39-cp39-win32.whl", hash = "sha256:ed29ea0b9a372c5188cdb2ad39f937900a10fb5478dc077283bf86eeac678ef1"}, {file = "psutil-5.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:538fcf6ae856b5e12d13d7da25ad67f02113c96f5989e6ad44422cb5994ca7fc"},
{file = "psutil-5.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:68b35cbff92d1f7103d8f1db77c977e72f49fcefae3d3d2b91c76b0e7aef48b8"}, {file = "psutil-5.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a3d81165b8474087bb90ec4f333a638ccfd1d69d34a9b4a1a7eaac06648f9fbe"},
{file = "psutil-5.9.2.tar.gz", hash = "sha256:feb861a10b6c3bb00701063b37e4afc754f8217f0f09c42280586bd6ac712b5c"}, {file = "psutil-5.9.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a7826e68b0cf4ce2c1ee385d64eab7d70e3133171376cac53d7c1790357ec8f"},
{file = "psutil-5.9.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ec296f565191f89c48f33d9544d8d82b0d2af7dd7d2d4e6319f27a818f8d1cc"},
{file = "psutil-5.9.3-cp39-cp39-win32.whl", hash = "sha256:9ec95df684583b5596c82bb380c53a603bb051cf019d5c849c47e117c5064395"},
{file = "psutil-5.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:4bd4854f0c83aa84a5a40d3b5d0eb1f3c128f4146371e03baed4589fe4f3c931"},
{file = "psutil-5.9.3.tar.gz", hash = "sha256:7ccfcdfea4fc4b0a02ca2c31de7fcd186beb9cff8207800e14ab66f79c773af6"},
] ]
pyparsing = [ pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
@ -453,8 +457,8 @@ PyYAML = [
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
] ]
setuptools = [ setuptools = [
{file = "setuptools-65.4.1-py3-none-any.whl", hash = "sha256:1b6bdc6161661409c5f21508763dc63ab20a9ac2f8ba20029aaaa7fdb9118012"}, {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"},
{file = "setuptools-65.4.1.tar.gz", hash = "sha256:3050e338e5871e70c72983072fe34f6032ae1cdeeeb67338199c2f74e083a80e"}, {file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"},
] ]
six = [ six = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
@ -487,6 +491,6 @@ typer = [
{file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"},
] ]
typing-extensions = [ typing-extensions = [
{file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"},
{file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"},
] ]