From c8ce9b25153e727e9b2de3efa0acf5d69922ac62 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sat, 22 Oct 2022 20:00:15 +0200 Subject: [PATCH] feat(server): Use safetensors Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> --- Cargo.toml | 3 + Dockerfile | 15 +- Makefile | 17 +- README.md | 12 +- launcher/src/main.rs | 53 ++-- server/.gitignore | 3 + server/Makefile | 32 +- server/README.md | 6 +- server/bloom_inference/cli.py | 18 +- server/bloom_inference/model.py | 141 ++++++--- server/bloom_inference/prepare_weights.py | 185 ----------- server/bloom_inference/server.py | 8 +- server/bloom_inference/utils.py | 54 +++- server/poetry.lock | 370 +++++++++++----------- 14 files changed, 418 insertions(+), 499 deletions(-) delete mode 100644 server/bloom_inference/prepare_weights.py diff --git a/Cargo.toml b/Cargo.toml index d3f7dfb0..e23b802e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,9 @@ members = [ "router/client", "launcher" ] +exclude = [ + "server/safetensors", +] [profile.release] debug = 1 diff --git a/Dockerfile b/Dockerfile index 70b91bc2..78f66c0f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,10 +36,10 @@ ENV LANG=C.UTF-8 \ SHELL ["/bin/bash", "-c"] -RUN apt-get update && apt-get install -y unzip wget libssl-dev && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/* RUN cd ~ && \ - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ chmod +x Miniconda3-latest-Linux-x86_64.sh && \ bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \ conda create -n text-generation python=3.9 -y @@ -54,13 +54,22 @@ RUN cd server && make install-torch # Install specific version of transformers RUN cd server && make install-transformers +# Install specific version of safetensors +# FIXME: This is a temporary fix while we wait for a new release +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +ENV PATH="/root/.cargo/bin:${PATH}" +RUN cd server && make install-safetensors + # Install server +COPY proto proto COPY server server RUN cd server && \ + make gen-server && \ /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir # Install router COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router +# Install launcher COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher -CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH \ No newline at end of file +CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS \ No newline at end of file diff --git a/Makefile b/Makefile index 3a80a12d..e22309f9 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ install-server: - cd server && make pip-install + cd server && make install install-router: cd router && cargo install --path . @@ -7,13 +7,16 @@ install-router: install-launcher: cd launcher && cargo install --path . -install: - make install-server - make install-router - make install-launcher +install: install-server install-router install-launcher + +server-dev: + cd server && make run-dev + +router-dev: + cd router && cargo run run-bloom-560m: - text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2 + text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 run-bloom: - text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8 + text-generation-launcher --model-name bigscience/bloom --num-shard 8 diff --git a/README.md b/README.md index 18e9865d..1af90dce 100644 --- a/README.md +++ b/README.md @@ -41,11 +41,21 @@ make run-bloom-560m ```shell curl 127.0.0.1:3000/generate \ + -v \ -X POST \ -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \ -H 'Content-Type: application/json' ``` +## Develop + +```shell +make server-dev +make router-dev +``` + ## TODO: -- [ ] Add tests for the `server/model` logic \ No newline at end of file +- [ ] Add tests for the `server/model` logic +- [ ] Backport custom CUDA kernels to Transformers +- [ ] Install safetensors with pip \ No newline at end of file diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0aedf563..6b529d60 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use std::env; use std::io::{BufRead, BufReader, Read}; use std::path::Path; use std::process::ExitCode; @@ -20,8 +21,6 @@ struct Args { model_name: String, #[clap(long, env)] num_shard: Option, - #[clap(long, env)] - shard_directory: Option, #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, #[clap(default_value = "1000", long, env)] @@ -47,7 +46,6 @@ fn main() -> ExitCode { let Args { model_name, num_shard, - shard_directory, max_concurrent_requests, max_input_length, max_batch_size, @@ -82,7 +80,6 @@ fn main() -> ExitCode { for rank in 0..num_shard { let model_name = model_name.clone(); let uds_path = shard_uds_path.clone(); - let shard_directory = shard_directory.clone(); let master_addr = master_addr.clone(); let status_sender = status_sender.clone(); let shutdown = shutdown.clone(); @@ -91,7 +88,6 @@ fn main() -> ExitCode { shard_manager( model_name, uds_path, - shard_directory, rank, num_shard, master_addr, @@ -237,7 +233,6 @@ enum ShardStatus { fn shard_manager( model_name: String, uds_path: String, - shard_directory: Option, rank: usize, world_size: usize, master_addr: String, @@ -265,10 +260,27 @@ fn shard_manager( shard_argv.push("--sharded".to_string()); } - if let Some(shard_directory) = shard_directory { - shard_argv.push("--shard-directory".to_string()); - shard_argv.push(shard_directory); - } + let mut env = vec![ + ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()), + ( + "WORLD_SIZE".parse().unwrap(), + world_size.to_string().parse().unwrap(), + ), + ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()), + ( + "MASTER_PORT".parse().unwrap(), + master_port.to_string().parse().unwrap(), + ), + ]; + + // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard + // Useful when running inside a docker container + if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") { + env.push(( + "HUGGINGFACE_HUB_CACHE".parse().unwrap(), + huggingface_hub_cache.parse().unwrap(), + )); + }; // Start process tracing::info!("Starting shard {}", rank); @@ -280,18 +292,7 @@ fn shard_manager( // Needed for the shutdown procedure setpgid: true, // NCCL env vars - env: Some(vec![ - ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()), - ( - "WORLD_SIZE".parse().unwrap(), - world_size.to_string().parse().unwrap(), - ), - ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()), - ( - "MASTER_PORT".parse().unwrap(), - master_port.to_string().parse().unwrap(), - ), - ]), + env: Some(env), ..Default::default() }, ) { @@ -312,6 +313,7 @@ fn shard_manager( let mut ready = false; let start_time = Instant::now(); + let mut wait_time = Instant::now(); loop { // Process exited if p.poll().is_some() { @@ -337,9 +339,12 @@ fn shard_manager( status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready { - tracing::info!("Waiting for shard {} to be ready...", rank); + if wait_time.elapsed() > Duration::from_secs(10) { + tracing::info!("Waiting for shard {} to be ready...", rank); + wait_time = Instant::now(); + } } - sleep(Duration::from_secs(5)); + sleep(Duration::from_millis(100)); } } diff --git a/server/.gitignore b/server/.gitignore index 0ebf3670..88bbea8f 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -153,3 +153,6 @@ dmypy.json # Cython debug symbols cython_debug/ + +transformers +safetensors \ No newline at end of file diff --git a/server/Makefile b/server/Makefile index 52b4d405..9c7a2c43 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,6 @@ gen-server: + # Compile protos + pip install grpcio-tools==1.49.1 --no-cache-dir mkdir bloom_inference/pb || true python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; @@ -7,25 +9,31 @@ gen-server: install-transformers: # Install specific version of transformers rm transformers || true - wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip + rm transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 || true + curl -L -O https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers cd transformers && python setup.py install +install-safetensors: + # Install specific version of safetensors + pip install setuptools_rust + rm safetensors || true + rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true + curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip + unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip + rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip + mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors + cd safetensors/bindings/python && python setup.py develop + install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir -pip-install: - pip install grpcio-tools - make gen-server - make install-torch - make install-transformers - pip install . +install: gen-server install-torch install-transformers install-safetensors + pip install pip --upgrade + pip install -e . --no-cache-dir -install: - poetry install - make gen-server - make install-torch - make install-transformers +run-dev: + python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file diff --git a/server/README.md b/server/README.md index 9127b978..c7054b8d 100644 --- a/server/README.md +++ b/server/README.md @@ -2,14 +2,14 @@ A Python gRPC server for BLOOM Inference -## Local Install (with poetry) +## Install ```shell make install ``` -## Local Install (with pip) +## Run ```shell -make pip-install +make run-dev ``` \ No newline at end of file diff --git a/server/bloom_inference/cli.py b/server/bloom_inference/cli.py index 751485bb..600c5273 100644 --- a/server/bloom_inference/cli.py +++ b/server/bloom_inference/cli.py @@ -2,9 +2,8 @@ import os import typer from pathlib import Path -from typing import Optional -from bloom_inference import prepare_weights, server +from bloom_inference import server, utils app = typer.Typer() @@ -13,13 +12,9 @@ app = typer.Typer() def serve( model_name: str, sharded: bool = False, - shard_directory: Optional[Path] = None, uds_path: Path = "/tmp/bloom-inference", ): if sharded: - assert ( - shard_directory is not None - ), "shard_directory must be set when sharded is True" assert ( os.getenv("RANK", None) is not None ), "RANK must be set when sharded is True" @@ -33,19 +28,14 @@ def serve( os.getenv("MASTER_PORT", None) is not None ), "MASTER_PORT must be set when sharded is True" - server.serve(model_name, sharded, uds_path, shard_directory) + server.serve(model_name, sharded, uds_path) @app.command() -def prepare_weights( +def download_weights( model_name: str, - shard_directory: Path, - cache_directory: Path, - num_shard: int = 1, ): - prepare_weights.prepare_weights( - model_name, cache_directory, shard_directory, num_shard - ) + utils.download_weights(model_name) if __name__ == "__main__": diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py index 0ba90cee..eac54002 100644 --- a/server/bloom_inference/model.py +++ b/server/bloom_inference/model.py @@ -2,19 +2,24 @@ import torch import torch.distributed from dataclasses import dataclass -from pathlib import Path from typing import List, Tuple, Optional, Dict +from accelerate import init_empty_weights +from safetensors import safe_open from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -from transformers.modeling_utils import no_init_weights +from transformers.models.bloom.parallel_layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) from bloom_inference.pb import generate_pb2 -from bloom_inference.prepare_weights import prepare_weights, match_suffix from bloom_inference.utils import ( StoppingCriteria, NextTokenChooser, initialize_torch_distributed, - set_default_dtype, + weight_files, + download_weights ) torch.manual_seed(0) @@ -63,7 +68,7 @@ class Batch: ) stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) - input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device) + input_ids = tokenizer(inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8).to(device) all_input_ids = input_ids["input_ids"].unsqueeze(-1) return cls( @@ -369,7 +374,7 @@ class BLOOM: class BLOOMSharded(BLOOM): - def __init__(self, model_name: str, shard_directory: Path): + def __init__(self, model_name: str): super(BLOOM, self).__init__() self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.master = self.rank == 0 @@ -382,30 +387,11 @@ class BLOOMSharded(BLOOM): self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - # shard state_dict - if self.master: - # TODO @thomasw21 do some caching - shard_state_dict_paths = prepare_weights( - model_name, - shard_directory / "cache", - shard_directory, - tp_world_size=self.world_size, - ) - shard_state_dict_paths = [ - str(path.absolute()) for path in shard_state_dict_paths - ] - else: - shard_state_dict_paths = [None] * self.world_size - - torch.distributed.broadcast_object_list( - shard_state_dict_paths, src=0, group=self.process_group - ) - shard_state_dict_path = shard_state_dict_paths[self.rank] - config = AutoConfig.from_pretrained( model_name, slow_but_exact=False, tp_parallel=True ) config.pad_token_id = 3 + self.num_heads = config.n_head // self.process_group.size() # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -414,37 +400,88 @@ class BLOOMSharded(BLOOM): # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. torch.backends.cudnn.allow_tf32 = True - with set_default_dtype(dtype): - with no_init_weights(): - # we can probably set the device to `meta` here? - model = AutoModelForCausalLM.from_config(config).to(dtype) + # Only download weights for small models + if self.master and model_name == "bigscience/bloom-560m": + download_weights(model_name) torch.distributed.barrier(group=self.process_group) - # print_rank_0(f"Initialized model") - state_dict = torch.load(shard_state_dict_path) - # TODO @thomasw21: HACK in order to transpose all weight prior - for key in state_dict.keys(): - do_transpose = False - if not match_suffix(key, "weight"): - continue + filenames = weight_files(model_name) - for potential_suffix in [ - "self_attention.query_key_value.weight", - "self_attention.dense.weight", - "dense_h_to_4h.weight", - "dense_4h_to_h.weight", - ]: - if match_suffix(key, potential_suffix): - do_transpose = True + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) - if do_transpose: - state_dict[key] = state_dict[key].transpose(1, 0).contiguous() - - model.load_state_dict(state_dict, strict=False) - model.tie_weights() - self.model = model.to(self.device).eval() - self.num_heads = config.n_head // self.process_group.size() torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + device=self.device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + + @staticmethod + def load_weights( + model, filenames: List[str], device: torch.device, rank: int, world_size: int + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open(file, framework="pt", device=str(device)) as f: + for name in f.keys(): + full_name = f"transformer.{name}" + + module_name, param_name = full_name.rsplit(".", 1) + module = model.get_submodule(module_name) + current_tensor = parameters[full_name] + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + if param_name == "weight": + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + tensor = tensor.transpose(1, 0) + else: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + tensor = tensor.transpose(1, 0) + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + tensor = slice_[:] + + if current_tensor.shape != tensor.shape: + raise ValueError( + f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + module._parameters[param_name] = tensor + if name == "word_embeddings.weight": + model.lm_head._parameters["weight"] = tensor def forward(self, input_ids, attention_mask, past_key_values: Optional = None): outputs = self.model.forward( diff --git a/server/bloom_inference/prepare_weights.py b/server/bloom_inference/prepare_weights.py deleted file mode 100644 index 5594998d..00000000 --- a/server/bloom_inference/prepare_weights.py +++ /dev/null @@ -1,185 +0,0 @@ -import torch -import os -import tempfile -import json - -from typing import BinaryIO -from joblib import Parallel, delayed -from functools import partial -from pathlib import Path -from tqdm import tqdm - -from huggingface_hub import hf_hub_url -from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status - - -def match_suffix(text, suffix): - return text[-len(suffix) :] == suffix - - -def http_get( - url: str, - temp_file: BinaryIO, - *, - timeout=10.0, - max_retries=0, -): - """ - Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub. - """ - r = _request_wrapper( - method="GET", - url=url, - stream=True, - timeout=timeout, - max_retries=max_retries, - ) - hf_raise_for_status(r) - for chunk in r.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - temp_file.write(chunk) - - -def cache_download_url(url: str, root_dir: Path): - filename = root_dir / url.split("/")[-1] - - if not filename.exists(): - temp_file_manager = partial( - tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False - ) - with temp_file_manager() as temp_file: - http_get(url, temp_file) - - os.replace(temp_file.name, filename) - return filename - - -def prepare_weights( - model_name: str, cache_path: Path, save_path: Path, tp_world_size: int -): - save_paths = [ - save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty" - for tp_rank in range(tp_world_size) - ] - - if all(save_path.exists() for save_path in save_paths): - print("Weights are already prepared") - return save_paths - - cache_path.mkdir(parents=True, exist_ok=True) - if model_name == "bigscience/bloom-560m": - url = hf_hub_url(model_name, filename="pytorch_model.bin") - cache_download_url(url, cache_path) - - elif model_name == "bigscience/bloom": - url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json") - index_path = cache_download_url(url, cache_path) - with index_path.open("r") as f: - index = json.load(f) - - # Get unique file names - weight_files = list( - set([filename for filename in index["weight_map"].values()]) - ) - urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files] - - Parallel(n_jobs=5)( - delayed(cache_download_url)(url, cache_path) for url in tqdm(urls) - ) - else: - raise ValueError(f"Unknown model name: {model_name}") - - shards_state_dicts = [{} for _ in range(tp_world_size)] - - for weight_path in tqdm(Path(cache_path).glob("*.bin")): - state_dict = torch.load(weight_path, map_location="cpu") - - keys = list(state_dict.keys()) - for state_name in keys: - state = state_dict[state_name] - if any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.query_key_value.weight", - "self_attention.query_key_value.bias", - "mlp.dense_h_to_4h.weight", - "mlp.dense_h_to_4h.bias", - "word_embeddings.weight", - ] - ): - output_size = state.shape[0] - assert output_size % tp_world_size == 0 - block_size = output_size // tp_world_size - sharded_weights = torch.split(state, block_size, dim=0) - assert len(sharded_weights) == tp_world_size - - for tp_rank, shard in enumerate(sharded_weights): - shards_state_dicts[tp_rank][ - "transformer." + state_name - ] = shard.detach().clone() - - elif match_suffix(state_name, "lm_head.weight"): - output_size = state.shape[0] - assert output_size % tp_world_size == 0 - block_size = output_size // tp_world_size - sharded_weights = torch.split(state, block_size, dim=0) - assert len(sharded_weights) == tp_world_size - - for tp_rank, shard in enumerate(sharded_weights): - shards_state_dicts[tp_rank][state_name] = shard.detach().clone() - - elif any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.dense.weight", - "mlp.dense_4h_to_h.weight", - ] - ): - input_size = state.shape[1] - assert input_size % tp_world_size == 0 - block_size = input_size // tp_world_size - sharded_weights = torch.split(state, block_size, dim=1) - assert len(sharded_weights) == tp_world_size - for tp_rank, shard in enumerate(sharded_weights): - shards_state_dicts[tp_rank][ - "transformer." + state_name - ] = shard.detach().clone() - - elif any( - match_suffix(state_name, candidate) - for candidate in [ - "self_attention.dense.bias", - "mlp.dense_4h_to_h.bias", - ] - ): - shards_state_dicts[0][ - "transformer." + state_name - ] = state.detach().clone() - for tp_rank in range(1, tp_world_size): - shards_state_dicts[tp_rank][ - "transformer." + state_name - ] = torch.zeros_like(state) - - else: - # We duplicate parameters across tp ranks - for tp_rank in range(tp_world_size): - shards_state_dicts[tp_rank][ - "transformer." + state_name - ] = state.detach().clone() - - del state_dict[state_name] # delete key from state_dict - del state # delete tensor - del state_dict - - # we save state_dict - for tp_rank, (save_path, shard_state_dict) in enumerate( - zip(save_paths, shards_state_dicts) - ): - save_paths.append(save_path) - save_path.parent.mkdir(parents=True, exist_ok=True) - if save_path.exists(): - print(f"Skipping {save_path} as it already exists") - else: - torch.save(shard_state_dict, save_path) - - return save_paths diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py index 734aba43..fb02db6d 100644 --- a/server/bloom_inference/server.py +++ b/server/bloom_inference/server.py @@ -69,18 +69,14 @@ def serve( model_name: str, sharded: bool, uds_path: Path, - shard_directory: Optional[Path] = None, ): async def serve_inner( model_name: str, sharded: bool = False, - shard_directory: Optional[Path] = None, ): unix_socket_template = "unix://{}-{}" if sharded: - if shard_directory is None: - raise ValueError("shard_directory must be set when sharded is True") - model = BLOOMSharded(model_name, shard_directory) + model = BLOOMSharded(model_name) server_urls = [ unix_socket_template.format(uds_path, rank) for rank in range(model.world_size) @@ -109,4 +105,4 @@ def serve( print("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_name, sharded, shard_directory)) + asyncio.run(serve_inner(model_name, sharded)) diff --git a/server/bloom_inference/utils.py b/server/bloom_inference/utils.py index c351806a..cca5403f 100644 --- a/server/bloom_inference/utils.py +++ b/server/bloom_inference/utils.py @@ -1,9 +1,14 @@ import os -import contextlib import torch import torch.distributed from datetime import timedelta + +from functools import partial +from joblib import Parallel, delayed +from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache +from huggingface_hub.utils import LocalEntryNotFoundError +from tqdm import tqdm from transformers.generation_logits_process import ( LogitsProcessorList, TemperatureLogitsWarper, @@ -87,11 +92,42 @@ def initialize_torch_distributed(): return torch.distributed.distributed_c10d._get_default_group(), rank, world_size -@contextlib.contextmanager -def set_default_dtype(dtype): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - try: - yield - finally: - torch.set_default_dtype(saved_dtype) +def weight_hub_files(model_name): + """Get the safetensors filenames on the hub""" + api = HfApi() + info = api.model_info(model_name) + filenames = [ + s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors") + ] + return filenames + + +def weight_files(model_name): + """Get the local safetensors filenames""" + filenames = weight_hub_files(model_name) + files = [] + for filename in filenames: + cache_file = try_to_load_from_cache(model_name, filename=filename) + if cache_file is None: + raise LocalEntryNotFoundError( + f"File {filename} of model {model_name} not found in " + f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " + f"Please run `bloom-inference-server download-weights {model_name}` first." + ) + files.append(cache_file) + + return files + + +def download_weights(model_name): + """Download the safetensors files from the hub""" + filenames = weight_hub_files(model_name) + + download_function = partial( + hf_hub_download, repo_id=model_name, local_files_only=False + ) + # FIXME: fix the overlapping progress bars + files = Parallel(n_jobs=5)( + delayed(download_function)(filename=filename) for filename in tqdm(filenames) + ) + return files diff --git a/server/poetry.lock b/server/poetry.lock index 3feee60d..38f14d35 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -43,7 +43,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "grpcio" -version = "1.49.1" +version = "1.50.0" description = "HTTP/2-based RPC framework" category = "main" optional = false @@ -53,31 +53,31 @@ python-versions = ">=3.7" six = ">=1.5.2" [package.extras] -protobuf = ["grpcio-tools (>=1.49.1)"] +protobuf = ["grpcio-tools (>=1.50.0)"] [[package]] name = "grpcio-reflection" -version = "1.49.1" +version = "1.50.0" description = "Standard Protobuf Reflection Service for gRPC" category = "main" optional = false python-versions = ">=3.6" [package.dependencies] -grpcio = ">=1.49.1" -protobuf = ">=4.21.3" +grpcio = ">=1.50.0" +protobuf = ">=4.21.6" [[package]] name = "grpcio-tools" -version = "1.49.1" +version = "1.50.0" description = "Protobuf code generator for gRPC" category = "dev" optional = false python-versions = ">=3.7" [package.dependencies] -grpcio = ">=1.49.1" -protobuf = ">=4.21.3,<5.0dev" +grpcio = ">=1.50.0" +protobuf = ">=4.21.6,<5.0dev" setuptools = "*" [[package]] @@ -90,7 +90,7 @@ python-versions = ">=3.7" [[package]] name = "numpy" -version = "1.23.3" +version = "1.23.4" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false @@ -109,7 +109,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "protobuf" -version = "4.21.7" +version = "4.21.8" description = "" category = "main" optional = false @@ -117,7 +117,7 @@ python-versions = ">=3.7" [[package]] name = "psutil" -version = "5.9.2" +version = "5.9.3" description = "Cross-platform lib for process and system monitoring in Python." category = "main" optional = false @@ -147,7 +147,7 @@ python-versions = ">=3.6" [[package]] name = "setuptools" -version = "65.4.1" +version = "65.5.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "dev" optional = false @@ -196,7 +196,7 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6. [[package]] name = "typing-extensions" -version = "4.3.0" +version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false @@ -221,190 +221,194 @@ colorama = [ {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] grpcio = [ - {file = "grpcio-1.49.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:fd86040232e805b8e6378b2348c928490ee595b058ce9aaa27ed8e4b0f172b20"}, - {file = "grpcio-1.49.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6fd0c9cede9552bf00f8c5791d257d5bf3790d7057b26c59df08be5e7a1e021d"}, - {file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:d0d402e158d4e84e49c158cb5204119d55e1baf363ee98d6cb5dce321c3a065d"}, - {file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:822ceec743d42a627e64ea266059a62d214c5a3cdfcd0d7fe2b7a8e4e82527c7"}, - {file = "grpcio-1.49.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2106d9c16527f0a85e2eea6e6b91a74fc99579c60dd810d8690843ea02bc0f5f"}, - {file = "grpcio-1.49.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:52dd02b7e7868233c571b49bc38ebd347c3bb1ff8907bb0cb74cb5f00c790afc"}, - {file = "grpcio-1.49.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:120fecba2ec5d14b5a15d11063b39783fda8dc8d24addd83196acb6582cabd9b"}, - {file = "grpcio-1.49.1-cp310-cp310-win32.whl", hash = "sha256:f1a3b88e3c53c1a6e6bed635ec1bbb92201bb6a1f2db186179f7f3f244829788"}, - {file = "grpcio-1.49.1-cp310-cp310-win_amd64.whl", hash = "sha256:a7d0017b92d3850abea87c1bdec6ea41104e71c77bca44c3e17f175c6700af62"}, - {file = "grpcio-1.49.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:9fb17ff8c0d56099ac6ebfa84f670c5a62228d6b5c695cf21c02160c2ac1446b"}, - {file = "grpcio-1.49.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:075f2d06e3db6b48a2157a1bcd52d6cbdca980dd18988fe6afdb41795d51625f"}, - {file = "grpcio-1.49.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46d93a1b4572b461a227f1db6b8d35a88952db1c47e5fadcf8b8a2f0e1dd9201"}, - {file = "grpcio-1.49.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc79b2b37d779ac42341ddef40ad5bf0966a64af412c89fc2b062e3ddabb093f"}, - {file = "grpcio-1.49.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5f8b3a971c7820ea9878f3fd70086240a36aeee15d1b7e9ecbc2743b0e785568"}, - {file = "grpcio-1.49.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49b301740cf5bc8fed4fee4c877570189ae3951432d79fa8e524b09353659811"}, - {file = "grpcio-1.49.1-cp311-cp311-win32.whl", hash = "sha256:1c66a25afc6c71d357867b341da594a5587db5849b48f4b7d5908d236bb62ede"}, - {file = "grpcio-1.49.1-cp311-cp311-win_amd64.whl", hash = "sha256:6b6c3a95d27846f4145d6967899b3ab25fffc6ae99544415e1adcacef84842d2"}, - {file = "grpcio-1.49.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:1cc400c8a2173d1c042997d98a9563e12d9bb3fb6ad36b7f355bc77c7663b8af"}, - {file = "grpcio-1.49.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:34f736bd4d0deae90015c0e383885b431444fe6b6c591dea288173df20603146"}, - {file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:196082b9c89ebf0961dcd77cb114bed8171964c8e3063b9da2fb33536a6938ed"}, - {file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c9f89c42749890618cd3c2464e1fbf88446e3d2f67f1e334c8e5db2f3272bbd"}, - {file = "grpcio-1.49.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64419cb8a5b612cdb1550c2fd4acbb7d4fb263556cf4625f25522337e461509e"}, - {file = "grpcio-1.49.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:8a5272061826e6164f96e3255405ef6f73b88fd3e8bef464c7d061af8585ac62"}, - {file = "grpcio-1.49.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ea9d0172445241ad7cb49577314e39d0af2c5267395b3561d7ced5d70458a9f3"}, - {file = "grpcio-1.49.1-cp37-cp37m-win32.whl", hash = "sha256:2070e87d95991473244c72d96d13596c751cb35558e11f5df5414981e7ed2492"}, - {file = "grpcio-1.49.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fcedcab49baaa9db4a2d240ac81f2d57eb0052b1c6a9501b46b8ae912720fbf"}, - {file = "grpcio-1.49.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:afbb3475cf7f4f7d380c2ca37ee826e51974f3e2665613996a91d6a58583a534"}, - {file = "grpcio-1.49.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:a4f9ba141380abde6c3adc1727f21529137a2552002243fa87c41a07e528245c"}, - {file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:cf0a1fb18a7204b9c44623dfbd1465b363236ce70c7a4ed30402f9f60d8b743b"}, - {file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17bb6fe72784b630728c6cff9c9d10ccc3b6d04e85da6e0a7b27fb1d135fac62"}, - {file = "grpcio-1.49.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18305d5a082d1593b005a895c10041f833b16788e88b02bb81061f5ebcc465df"}, - {file = "grpcio-1.49.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b6a1b39e59ac5a3067794a0e498911cf2e37e4b19ee9e9977dc5e7051714f13f"}, - {file = "grpcio-1.49.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0e20d59aafc086b1cc68400463bddda6e41d3e5ed30851d1e2e0f6a2e7e342d3"}, - {file = "grpcio-1.49.1-cp38-cp38-win32.whl", hash = "sha256:e1e83233d4680863a421f3ee4a7a9b80d33cd27ee9ed7593bc93f6128302d3f2"}, - {file = "grpcio-1.49.1-cp38-cp38-win_amd64.whl", hash = "sha256:221d42c654d2a41fa31323216279c73ed17d92f533bc140a3390cc1bd78bf63c"}, - {file = "grpcio-1.49.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:fa9e6e61391e99708ac87fc3436f6b7b9c6b845dc4639b406e5e61901e1aacde"}, - {file = "grpcio-1.49.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:9b449e966ef518ce9c860d21f8afe0b0f055220d95bc710301752ac1db96dd6a"}, - {file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:aa34d2ad9f24e47fa9a3172801c676e4037d862247e39030165fe83821a7aafd"}, - {file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5207f4eed1b775d264fcfe379d8541e1c43b878f2b63c0698f8f5c56c40f3d68"}, - {file = "grpcio-1.49.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b24a74651438d45619ac67004638856f76cc13d78b7478f2457754cbcb1c8ad"}, - {file = "grpcio-1.49.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fe763781669790dc8b9618e7e677c839c87eae6cf28b655ee1fa69ae04eea03f"}, - {file = "grpcio-1.49.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2f2ff7ba0f8f431f32d4b4bc3a3713426949d3533b08466c4ff1b2b475932ca8"}, - {file = "grpcio-1.49.1-cp39-cp39-win32.whl", hash = "sha256:08ff74aec8ff457a89b97152d36cb811dcc1d17cd5a92a65933524e363327394"}, - {file = "grpcio-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:274ffbb39717918c514b35176510ae9be06e1d93121e84d50b350861dcb9a705"}, - {file = "grpcio-1.49.1.tar.gz", hash = "sha256:d4725fc9ec8e8822906ae26bb26f5546891aa7fbc3443de970cc556d43a5c99f"}, + {file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"}, + {file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"}, + {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:4b123fbb7a777a2fedec684ca0b723d85e1d2379b6032a9a9b7851829ed3ca9a"}, + {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2f77a90ba7b85bfb31329f8eab9d9540da2cf8a302128fb1241d7ea239a5469"}, + {file = "grpcio-1.50.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eea18a878cffc804506d39c6682d71f6b42ec1c151d21865a95fae743fda500"}, + {file = "grpcio-1.50.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b71916fa8f9eb2abd93151fafe12e18cebb302686b924bd4ec39266211da525"}, + {file = "grpcio-1.50.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:95ce51f7a09491fb3da8cf3935005bff19983b77c4e9437ef77235d787b06842"}, + {file = "grpcio-1.50.0-cp310-cp310-win32.whl", hash = "sha256:f7025930039a011ed7d7e7ef95a1cb5f516e23c5a6ecc7947259b67bea8e06ca"}, + {file = "grpcio-1.50.0-cp310-cp310-win_amd64.whl", hash = "sha256:05f7c248e440f538aaad13eee78ef35f0541e73498dd6f832fe284542ac4b298"}, + {file = "grpcio-1.50.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:ca8a2254ab88482936ce941485c1c20cdeaef0efa71a61dbad171ab6758ec998"}, + {file = "grpcio-1.50.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:3b611b3de3dfd2c47549ca01abfa9bbb95937eb0ea546ea1d762a335739887be"}, + {file = "grpcio-1.50.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a4cd8cb09d1bc70b3ea37802be484c5ae5a576108bad14728f2516279165dd7"}, + {file = "grpcio-1.50.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:156f8009e36780fab48c979c5605eda646065d4695deea4cfcbcfdd06627ddb6"}, + {file = "grpcio-1.50.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de411d2b030134b642c092e986d21aefb9d26a28bf5a18c47dd08ded411a3bc5"}, + {file = "grpcio-1.50.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d144ad10eeca4c1d1ce930faa105899f86f5d99cecfe0d7224f3c4c76265c15e"}, + {file = "grpcio-1.50.0-cp311-cp311-win32.whl", hash = "sha256:92d7635d1059d40d2ec29c8bf5ec58900120b3ce5150ef7414119430a4b2dd5c"}, + {file = "grpcio-1.50.0-cp311-cp311-win_amd64.whl", hash = "sha256:ce8513aee0af9c159319692bfbf488b718d1793d764798c3d5cff827a09e25ef"}, + {file = "grpcio-1.50.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:8e8999a097ad89b30d584c034929f7c0be280cd7851ac23e9067111167dcbf55"}, + {file = "grpcio-1.50.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:a50a1be449b9e238b9bd43d3857d40edf65df9416dea988929891d92a9f8a778"}, + {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:cf151f97f5f381163912e8952eb5b3afe89dec9ed723d1561d59cabf1e219a35"}, + {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a23d47f2fc7111869f0ff547f771733661ff2818562b04b9ed674fa208e261f4"}, + {file = "grpcio-1.50.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84d04dec64cc4ed726d07c5d17b73c343c8ddcd6b59c7199c801d6bbb9d9ed1"}, + {file = "grpcio-1.50.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:67dd41a31f6fc5c7db097a5c14a3fa588af54736ffc174af4411d34c4f306f68"}, + {file = "grpcio-1.50.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8d4c8e73bf20fb53fe5a7318e768b9734cf122fe671fcce75654b98ba12dfb75"}, + {file = "grpcio-1.50.0-cp37-cp37m-win32.whl", hash = "sha256:7489dbb901f4fdf7aec8d3753eadd40839c9085967737606d2c35b43074eea24"}, + {file = "grpcio-1.50.0-cp37-cp37m-win_amd64.whl", hash = "sha256:531f8b46f3d3db91d9ef285191825d108090856b3bc86a75b7c3930f16ce432f"}, + {file = "grpcio-1.50.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:d534d169673dd5e6e12fb57cc67664c2641361e1a0885545495e65a7b761b0f4"}, + {file = "grpcio-1.50.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:1d8d02dbb616c0a9260ce587eb751c9c7dc689bc39efa6a88cc4fa3e9c138a7b"}, + {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:baab51dcc4f2aecabf4ed1e2f57bceab240987c8b03533f1cef90890e6502067"}, + {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40838061e24f960b853d7bce85086c8e1b81c6342b1f4c47ff0edd44bbae2722"}, + {file = "grpcio-1.50.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:931e746d0f75b2a5cff0a1197d21827a3a2f400c06bace036762110f19d3d507"}, + {file = "grpcio-1.50.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:15f9e6d7f564e8f0776770e6ef32dac172c6f9960c478616c366862933fa08b4"}, + {file = "grpcio-1.50.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4c23e54f58e016761b576976da6a34d876420b993f45f66a2bfb00363ecc1f9"}, + {file = "grpcio-1.50.0-cp38-cp38-win32.whl", hash = "sha256:3e4244c09cc1b65c286d709658c061f12c61c814be0b7030a2d9966ff02611e0"}, + {file = "grpcio-1.50.0-cp38-cp38-win_amd64.whl", hash = "sha256:8e69aa4e9b7f065f01d3fdcecbe0397895a772d99954bb82eefbb1682d274518"}, + {file = "grpcio-1.50.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:af98d49e56605a2912cf330b4627e5286243242706c3a9fa0bcec6e6f68646fc"}, + {file = "grpcio-1.50.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:080b66253f29e1646ac53ef288c12944b131a2829488ac3bac8f52abb4413c0d"}, + {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:ab5d0e3590f0a16cb88de4a3fa78d10eb66a84ca80901eb2c17c1d2c308c230f"}, + {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb11464f480e6103c59d558a3875bd84eed6723f0921290325ebe97262ae1347"}, + {file = "grpcio-1.50.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e07fe0d7ae395897981d16be61f0db9791f482f03fee7d1851fe20ddb4f69c03"}, + {file = "grpcio-1.50.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d75061367a69808ab2e84c960e9dce54749bcc1e44ad3f85deee3a6c75b4ede9"}, + {file = "grpcio-1.50.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ae23daa7eda93c1c49a9ecc316e027ceb99adbad750fbd3a56fa9e4a2ffd5ae0"}, + {file = "grpcio-1.50.0-cp39-cp39-win32.whl", hash = "sha256:177afaa7dba3ab5bfc211a71b90da1b887d441df33732e94e26860b3321434d9"}, + {file = "grpcio-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:ea8ccf95e4c7e20419b7827aa5b6da6f02720270686ac63bd3493a651830235c"}, + {file = "grpcio-1.50.0.tar.gz", hash = "sha256:12b479839a5e753580b5e6053571de14006157f2ef9b71f38c56dc9b23b95ad6"}, ] grpcio-reflection = [ - {file = "grpcio-reflection-1.49.1.tar.gz", hash = "sha256:b755dfe61d5255a02fb8d0d845bd0027847dee68bf0763a2b286d664ed07ec4d"}, - {file = "grpcio_reflection-1.49.1-py3-none-any.whl", hash = "sha256:70a325a83c1c1ab583d368711e5733cbef5e068ad2c17cbe77df6e47e0311d1f"}, + {file = "grpcio-reflection-1.50.0.tar.gz", hash = "sha256:d75ff6a12f70f28ba6465d5a29781cf2d9c740eb809711990f61699019af1f72"}, + {file = "grpcio_reflection-1.50.0-py3-none-any.whl", hash = "sha256:92fea4a9527a8bcc52b615b6a7af655a5a118ab0f57227cc814dc2192f8a7455"}, ] grpcio-tools = [ - {file = "grpcio-tools-1.49.1.tar.gz", hash = "sha256:84cc64e5b46bad43d5d7bd2fd772b656eba0366961187a847e908e2cb735db91"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:2dfb6c7ece84d46bd690b23d3e060d18115c8bc5047d2e8a33e6747ed323a348"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:8f452a107c054a04db2570f7851a07f060313c6e841b0d394ce6030d598290e6"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:6a198871b582287213c4d70792bf275e1d7cf34eed1d019f534ddf4cd15ab039"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0cca67a7d0287bdc855d81fdd38dc949c4273273a74f832f9e520abe4f20bc6"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdaff4c89eecb37c247b93025410db68114d97fa093cbb028e9bd7cda5912473"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bb8773118ad315db317d7b22b5ff75d649ca20931733281209e7cbd8c0fad53e"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7cc5534023735b8a8f56760b7c533918f874ce5a9064d7c5456d2709ae2b31f9"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-win32.whl", hash = "sha256:d277642acbe305f5586f9597b78fb9970d6633eb9f89c61e429c92c296c37129"}, - {file = "grpcio_tools-1.49.1-cp310-cp310-win_amd64.whl", hash = "sha256:eed599cf08fc1a06c72492d3c5750c32f58de3750eddd984af1f257c14326701"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:9e5c13809ab2f245398e8446c4c3b399a62d591db651e46806cccf52a700452e"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:ab3d0ee9623720ee585fdf3753b3755d3144a4a8ae35bca8e3655fa2f41056be"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ba87e3512bc91d78bf9febcfb522eadda171d2d4ddaf886066b0f01aa4929ad"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e13b3643e7577a3ec13b79689eb4d7548890b1e104c04b9ed6557a3c3dd452"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:324f67d9cb4b7058b6ce45352fb64c20cc1fa04c34d97ad44772cfe6a4ae0cf5"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a64bab81b220c50033f584f57978ebbea575f09c1ccee765cd5c462177988098"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-win32.whl", hash = "sha256:f632d376f92f23e5931697a3acf1b38df7eb719774213d93c52e02acd2d529ac"}, - {file = "grpcio_tools-1.49.1-cp311-cp311-win_amd64.whl", hash = "sha256:28ff2b978d9509474928b9c096a0cce4eaa9c8f7046136aee1545f6211ed8126"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:46afd3cb7e555187451a5d283f108cdef397952a662cb48680afc615b158864a"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:9284568b728e41fa8f7e9c2e7399545d605f75d8072ef0e9aa2a05655cb679eb"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:aa34442cf787732cb41f2aa6172007e24f480b8b9d3dc5166de80d63e9072ea4"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8c9eb5a4250905414cd53a68caea3eb8f0c515aadb689e6e81b71ebe9ab5c6"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab15db024051bf21feb21c29cb2c3ea0a2e4f5cf341d46ef76e17fcf6aaef164"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:502084b622f758bef620a9107c2db9fcdf66d26c7e0e481d6bb87db4dc917d70"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4085890b77c640085f82bf1e90a0ea166ce48000bc2f5180914b974783c9c0a8"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-win32.whl", hash = "sha256:da0edb984699769ce02e18e3392d54b59a7a3f93acd285a68043f5bde4fc028e"}, - {file = "grpcio_tools-1.49.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9887cd622770271101a7dd1832845d64744c3f88fd11ccb2620394079197a42e"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:8440fe7dae6a40c279e3a24b82793735babd38ecbb0d07bb712ff9c8963185d9"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:b5de2bb7dd6b6231da9b1556ade981513330b740e767f1d902c71ceee0a7d196"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:1e6f06a763aea7836b63d9c117347f2bf7038008ceef72758815c9e09c5fb1fc"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e31562f90120318c5395aabec0f2f69ad8c14b6676996b7730d9d2eaf9415d57"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49ef9a4e389a618157a9daa9fafdfeeaef1ece9adda7f50f85db928f24d4b3e8"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b384cb8e8d9bcb55ee8f9b064374561c7a1a05d848249581403d36fc7060032f"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:73732f77943ac3e898879cbb29c27253aa3c47566b8a59780fd24c6a54de1b66"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-win32.whl", hash = "sha256:b594b2745a5ba9e7a76ce561bc5ab40bc65bb44743c505529b1e4f12af29104d"}, - {file = "grpcio_tools-1.49.1-cp38-cp38-win_amd64.whl", hash = "sha256:680fbc88f8709ddcabb88f86749f2d8e429160890cff2c70680880a6970d4eef"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:e8c3869121860f6767eedb7d24fc54dfd71e737fdfbb26e1334684606f3274fd"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:73e9d7c886ba10e20c97d1dab0ff961ba5800757ae5e31be21b1cda8130c52f8"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:1760de2dd2c4f08de87b039043a4797f3c17193656e7e3eb84e92f0517083c0c"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cd4b1e216dd04d9245ee8f4e601a1f98c25e6e417ea5cf8d825c50589a8b447e"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1c28751ab5955cae563d07677e799233f0fe1c0fc49d9cbd61ff1957e83617f"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c24239c3ee9ed16314c14b4e24437b5079ebc344f343f33629a582f8699f583b"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:892d3dacf1942820f0b7a868a30e6fbcdf5bec08543b682c7274b0101cee632d"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-win32.whl", hash = "sha256:704d21509ec06efc9d034dbe70e7152715aac004941f4f0f553cf3a0aff15bd5"}, - {file = "grpcio_tools-1.49.1-cp39-cp39-win_amd64.whl", hash = "sha256:1efa0c221c719433f441ac0e026fc3c4dbc9a1a08a552ecdc707775e2f2fbbae"}, + {file = "grpcio-tools-1.50.0.tar.gz", hash = "sha256:88b75f2afd889c7c6939f58d76b58ab84de4723c7de882a1f8448af6632e256f"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:fad4f44220847691605e4079ec67b6663064d5b1283bee5a89dd6f7726672d08"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4fe1cd201f0e0f601551c20bea2a8c1e109767ce341ac263e88d5a0acd0a124c"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:66113bc60db8ccb15e2d7b79efd5757bba33b15e653f20f473c776857bd4415d"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:71e98662f76d3bcf031f36f727ffee594c57006d5dce7b6d031d030d194f054d"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1891aca6042fc403c1d492d03ce1c4bac0053c26132063b7b0171665a0aa6086"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:54528cec27238761055b28a34ea29392c8e90ac8d8966ee3ccfc9f906eb26fd3"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89010e8e0d158966af478d9b345d29d98e52d62464cb93ef192125e94fb904c1"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-win32.whl", hash = "sha256:e42ac5aa09debbf7b72444af7f390713e85b8808557e920d9d8071f14127c0eb"}, + {file = "grpcio_tools-1.50.0-cp310-cp310-win_amd64.whl", hash = "sha256:c876dbaa6897e45ed074df1562006ac9fdeb5c7235f623cda2c29894fe4f8fba"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:3995d96635d6decaf9d5d2e071677782b08ea3186d7cb8ad93f6c2dd97ec9097"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1736defd9fc54f942f500dd71555816487e52f07d25b0a58ff4089c375f0b5c3"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fecbc5e773c1c1971641227d8fcd5de97533dbbc0bd7854b436ceb57243f4b2a"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8178a475de70e6ae0a7e3f6d7ec1415c9dd10680654dbea9a1caf79665b0ed4"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d9b8b59a76f67270ce8a4fd7af005732356d5b4ebaffa4598343ea63677bad63"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ead51910d83123648b1e36e2de8483b66e9ba358ca1d8965005ba7713cbcb7db"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-win32.whl", hash = "sha256:4c928610ceda7890cf7ddd3cb4cc16fb81995063bf4fdb5171bba572d508079b"}, + {file = "grpcio_tools-1.50.0-cp311-cp311-win_amd64.whl", hash = "sha256:56c28f50b9c88cbfc62861de86a1f8b2454d3b48cff85a73d523ec22f130635d"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-linux_armv7l.whl", hash = "sha256:675634367846fc906e8b6016830b3d361f9e788bce7820b4e6432d914cb0e25d"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:f6dbd9d86aa2fbeb8a104ae8edd23c55f182829db58cae3f28842024e08b1071"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:2869fba92670d5be3730f1ad108cbacd3f9a7b7afa57187a3c39379513b3fde7"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f211a8a6f159b496097cabbf4bdbb6d01db0fd7a8d516dd0e577ee2ddccded0"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27f734e44631d755481a179ffd5f2659eaf1e9755cf3539e291d7a28d8373cc8"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:9bc97b52d577406c240b0d6acf68ab5100c67c903792f1f9ce04d871ca83de53"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a9816a6c8d1b7adb65217d01624006704faf02b5daa15b6e145b63a5e559d791"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-win32.whl", hash = "sha256:c5dde47bae615554349dd09185eaaeffd9a28e961de347cf0b3b98a853b43a6c"}, + {file = "grpcio_tools-1.50.0-cp37-cp37m-win_amd64.whl", hash = "sha256:e51133b78b67bd4c06e99bea9c631f29acd52af2b2a254926d192f277f9067b5"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fec4336e067f2bdd45d74c2466dcd06f245e972d2139187d87cfccf847dbaf50"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d74672162474e2592c50708408b4ee145994451ee5171c9b8f580504b987ad3"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:e6baae6fa9a88c1895193f9610b54d483753ad3e3f91f5d5a10a506afe342a43"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:927c8cbbfbf3bca46cd05bbabfd8cb6c99f5dbaeb860f2786f9286ce6894ad3d"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fe40a66d3ccde959f7290b881c78e8c420fd54289fede5c9119d993703e1a49"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:86a9e9b2cc30797472bd421c75b42fe1062519ef5675243cb33d99de33af279c"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c1646bfec6d707bcaafea69088770151ba291ede788f2cda5efdc09c1eac5c1c"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-win32.whl", hash = "sha256:dedfe2947ad4ccc8cb2f7d0cac5766dd623901f29e1953a2854126559c317885"}, + {file = "grpcio_tools-1.50.0-cp38-cp38-win_amd64.whl", hash = "sha256:73ba21fe500e2476d391cc92aa05545fdc71bdeedb5b9ec2b22e9a3ccc298a2f"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:23381ac6a19de0d3cb470f4d1fb450267014234a3923e3c43abb858049f5caf4"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:c61730ae21533ef992a89df0321ebe59e344f9f0d1d54572b92d9468969c75b3"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e5adcbec235dce997efacb44cca5cde24db47ad9c432fcbca875185d2095c55a"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c40f918d1fe0cb094bfb09ab7567a92f5368fa0b9935235680fe205ddc37518"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2313c8c73a5003d789fe05f4d01112ccf8330d3d141f5b77b514200d742193f6"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:578f9137ebca0f152136b47b4895a6b4dabcf17e41f957b03637d1ae096ff67b"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d8aaff094d94d2119fffe7e0d08ca59d42076971f9d85857e99288e500ab31f0"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"}, + {file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"}, ] joblib = [ {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, ] numpy = [ - {file = "numpy-1.23.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c9f707b5bb73bf277d812ded9896f9512a43edff72712f31667d0a8c2f8e71ee"}, - {file = "numpy-1.23.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffcf105ecdd9396e05a8e58e81faaaf34d3f9875f137c7372450baa5d77c9a54"}, - {file = "numpy-1.23.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ea3f98a0ffce3f8f57675eb9119f3f4edb81888b6874bc1953f91e0b1d4f440"}, - {file = "numpy-1.23.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:004f0efcb2fe1c0bd6ae1fcfc69cc8b6bf2407e0f18be308612007a0762b4089"}, - {file = "numpy-1.23.3-cp310-cp310-win32.whl", hash = "sha256:98dcbc02e39b1658dc4b4508442a560fe3ca5ca0d989f0df062534e5ca3a5c1a"}, - {file = "numpy-1.23.3-cp310-cp310-win_amd64.whl", hash = "sha256:39a664e3d26ea854211867d20ebcc8023257c1800ae89773cbba9f9e97bae036"}, - {file = "numpy-1.23.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1f27b5322ac4067e67c8f9378b41c746d8feac8bdd0e0ffede5324667b8a075c"}, - {file = "numpy-1.23.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ad3ec9a748a8943e6eb4358201f7e1c12ede35f510b1a2221b70af4bb64295c"}, - {file = "numpy-1.23.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdc9febce3e68b697d931941b263c59e0c74e8f18861f4064c1f712562903411"}, - {file = "numpy-1.23.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:301c00cf5e60e08e04d842fc47df641d4a181e651c7135c50dc2762ffe293dbd"}, - {file = "numpy-1.23.3-cp311-cp311-win32.whl", hash = "sha256:7cd1328e5bdf0dee621912f5833648e2daca72e3839ec1d6695e91089625f0b4"}, - {file = "numpy-1.23.3-cp311-cp311-win_amd64.whl", hash = "sha256:8355fc10fd33a5a70981a5b8a0de51d10af3688d7a9e4a34fcc8fa0d7467bb7f"}, - {file = "numpy-1.23.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc6e8da415f359b578b00bcfb1d08411c96e9a97f9e6c7adada554a0812a6cc6"}, - {file = "numpy-1.23.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:22d43376ee0acd547f3149b9ec12eec2f0ca4a6ab2f61753c5b29bb3e795ac4d"}, - {file = "numpy-1.23.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a64403f634e5ffdcd85e0b12c08f04b3080d3e840aef118721021f9b48fc1460"}, - {file = "numpy-1.23.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efd9d3abe5774404becdb0748178b48a218f1d8c44e0375475732211ea47c67e"}, - {file = "numpy-1.23.3-cp38-cp38-win32.whl", hash = "sha256:f8c02ec3c4c4fcb718fdf89a6c6f709b14949408e8cf2a2be5bfa9c49548fd85"}, - {file = "numpy-1.23.3-cp38-cp38-win_amd64.whl", hash = "sha256:e868b0389c5ccfc092031a861d4e158ea164d8b7fdbb10e3b5689b4fc6498df6"}, - {file = "numpy-1.23.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:09f6b7bdffe57fc61d869a22f506049825d707b288039d30f26a0d0d8ea05164"}, - {file = "numpy-1.23.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8c79d7cf86d049d0c5089231a5bcd31edb03555bd93d81a16870aa98c6cfb79d"}, - {file = "numpy-1.23.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5d5420053bbb3dd64c30e58f9363d7a9c27444c3648e61460c1237f9ec3fa14"}, - {file = "numpy-1.23.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5422d6a1ea9b15577a9432e26608c73a78faf0b9039437b075cf322c92e98e7"}, - {file = "numpy-1.23.3-cp39-cp39-win32.whl", hash = "sha256:c1ba66c48b19cc9c2975c0d354f24058888cdc674bebadceb3cdc9ec403fb5d1"}, - {file = "numpy-1.23.3-cp39-cp39-win_amd64.whl", hash = "sha256:78a63d2df1d947bd9d1b11d35564c2f9e4b57898aae4626638056ec1a231c40c"}, - {file = "numpy-1.23.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:17c0e467ade9bda685d5ac7f5fa729d8d3e76b23195471adae2d6a6941bd2c18"}, - {file = "numpy-1.23.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91b8d6768a75247026e951dce3b2aac79dc7e78622fc148329135ba189813584"}, - {file = "numpy-1.23.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:94c15ca4e52671a59219146ff584488907b1f9b3fc232622b47e2cf832e94fb8"}, - {file = "numpy-1.23.3.tar.gz", hash = "sha256:51bf49c0cd1d52be0a240aa66f3458afc4b95d8993d2d04f0d91fa60c10af6cd"}, + {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"}, + {file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"}, + {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c237129f0e732885c9a6076a537e974160482eab8f10db6292e92154d4c67d71"}, + {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8365b942f9c1a7d0f0dc974747d99dd0a0cdfc5949a33119caf05cb314682d3"}, + {file = "numpy-1.23.4-cp310-cp310-win32.whl", hash = "sha256:2341f4ab6dba0834b685cce16dad5f9b6606ea8a00e6da154f5dbded70fdc4dd"}, + {file = "numpy-1.23.4-cp310-cp310-win_amd64.whl", hash = "sha256:d331afac87c92373826af83d2b2b435f57b17a5c74e6268b79355b970626e329"}, + {file = "numpy-1.23.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:488a66cb667359534bc70028d653ba1cf307bae88eab5929cd707c761ff037db"}, + {file = "numpy-1.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce03305dd694c4873b9429274fd41fc7eb4e0e4dea07e0af97a933b079a5814f"}, + {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8981d9b5619569899666170c7c9748920f4a5005bf79c72c07d08c8a035757b0"}, + {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a70a7d3ce4c0e9284e92285cba91a4a3f5214d87ee0e95928f3614a256a1488"}, + {file = "numpy-1.23.4-cp311-cp311-win32.whl", hash = "sha256:5e13030f8793e9ee42f9c7d5777465a560eb78fa7e11b1c053427f2ccab90c79"}, + {file = "numpy-1.23.4-cp311-cp311-win_amd64.whl", hash = "sha256:7607b598217745cc40f751da38ffd03512d33ec06f3523fb0b5f82e09f6f676d"}, + {file = "numpy-1.23.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7ab46e4e7ec63c8a5e6dbf5c1b9e1c92ba23a7ebecc86c336cb7bf3bd2fb10e5"}, + {file = "numpy-1.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8aae2fb3180940011b4862b2dd3756616841c53db9734b27bb93813cd79fce6"}, + {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c053d7557a8f022ec823196d242464b6955a7e7e5015b719e76003f63f82d0f"}, + {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0882323e0ca4245eb0a3d0a74f88ce581cc33aedcfa396e415e5bba7bf05f68"}, + {file = "numpy-1.23.4-cp38-cp38-win32.whl", hash = "sha256:dada341ebb79619fe00a291185bba370c9803b1e1d7051610e01ed809ef3a4ba"}, + {file = "numpy-1.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:0fe563fc8ed9dc4474cbf70742673fc4391d70f4363f917599a7fa99f042d5a8"}, + {file = "numpy-1.23.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c67b833dbccefe97cdd3f52798d430b9d3430396af7cdb2a0c32954c3ef73894"}, + {file = "numpy-1.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f76025acc8e2114bb664294a07ede0727aa75d63a06d2fae96bf29a81747e4a7"}, + {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ac457b63ec8ded85d85c1e17d85efd3c2b0967ca39560b307a35a6703a4735"}, + {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95de7dc7dc47a312f6feddd3da2500826defdccbc41608d0031276a24181a2c0"}, + {file = "numpy-1.23.4-cp39-cp39-win32.whl", hash = "sha256:f2f390aa4da44454db40a1f0201401f9036e8d578a25f01a6e237cea238337ef"}, + {file = "numpy-1.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:f260da502d7441a45695199b4e7fd8ca87db659ba1c78f2bbf31f934fe76ae0e"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61be02e3bf810b60ab74e81d6d0d36246dbfb644a462458bb53b595791251911"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:296d17aed51161dbad3c67ed6d164e51fcd18dbcd5dd4f9d0a9c6055dce30810"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4d52914c88b4930dafb6c48ba5115a96cbab40f45740239d9f4159c4ba779962"}, + {file = "numpy-1.23.4.tar.gz", hash = "sha256:ed2cc92af0efad20198638c69bb0fc2870a58dabfba6eb722c933b48556c686c"}, ] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] protobuf = [ - {file = "protobuf-4.21.7-cp310-abi3-win32.whl", hash = "sha256:c7cb105d69a87416bd9023e64324e1c089593e6dae64d2536f06bcbe49cd97d8"}, - {file = "protobuf-4.21.7-cp310-abi3-win_amd64.whl", hash = "sha256:3ec85328a35a16463c6f419dbce3c0fc42b3e904d966f17f48bae39597c7a543"}, - {file = "protobuf-4.21.7-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:db9056b6a11cb5131036d734bcbf91ef3ef9235d6b681b2fc431cbfe5a7f2e56"}, - {file = "protobuf-4.21.7-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:ca200645d6235ce0df3ccfdff1567acbab35c4db222a97357806e015f85b5744"}, - {file = "protobuf-4.21.7-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:b019c79e23a80735cc8a71b95f76a49a262f579d6b84fd20a0b82279f40e2cc1"}, - {file = "protobuf-4.21.7-cp37-cp37m-win32.whl", hash = "sha256:d3f89ccf7182293feba2de2739c8bf34fed1ed7c65a5cf987be00311acac57c1"}, - {file = "protobuf-4.21.7-cp37-cp37m-win_amd64.whl", hash = "sha256:a74d96cd960b87b4b712797c741bb3ea3a913f5c2dc4b6cbe9c0f8360b75297d"}, - {file = "protobuf-4.21.7-cp38-cp38-win32.whl", hash = "sha256:8e09d1916386eca1ef1353767b6efcebc0a6859ed7f73cb7fb974feba3184830"}, - {file = "protobuf-4.21.7-cp38-cp38-win_amd64.whl", hash = "sha256:9e355f2a839d9930d83971b9f562395e13493f0e9211520f8913bd11efa53c02"}, - {file = "protobuf-4.21.7-cp39-cp39-win32.whl", hash = "sha256:f370c0a71712f8965023dd5b13277444d3cdfecc96b2c778b0e19acbfd60df6e"}, - {file = "protobuf-4.21.7-cp39-cp39-win_amd64.whl", hash = "sha256:9643684232b6b340b5e63bb69c9b4904cdd39e4303d498d1a92abddc7e895b7f"}, - {file = "protobuf-4.21.7-py2.py3-none-any.whl", hash = "sha256:8066322588d4b499869bf9f665ebe448e793036b552f68c585a9b28f1e393f66"}, - {file = "protobuf-4.21.7-py3-none-any.whl", hash = "sha256:58b81358ec6c0b5d50df761460ae2db58405c063fd415e1101209221a0a810e1"}, - {file = "protobuf-4.21.7.tar.gz", hash = "sha256:71d9dba03ed3432c878a801e2ea51e034b0ea01cf3a4344fb60166cb5f6c8757"}, + {file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"}, + {file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"}, + {file = "protobuf-4.21.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bbececaf3cfea9ea65ebb7974e6242d310d2a7772a6f015477e0d79993af4511"}, + {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:b02eabb9ebb1a089ed20626a90ad7a69cee6bcd62c227692466054b19c38dd1f"}, + {file = "protobuf-4.21.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:4761201b93e024bb70ee3a6a6425d61f3152ca851f403ba946fb0cde88872661"}, + {file = "protobuf-4.21.8-cp37-cp37m-win32.whl", hash = "sha256:f2d55ff22ec300c4d954d3b0d1eeb185681ec8ad4fbecff8a5aee6a1cdd345ba"}, + {file = "protobuf-4.21.8-cp37-cp37m-win_amd64.whl", hash = "sha256:c5f94911dd8feb3cd3786fc90f7565c9aba7ce45d0f254afd625b9628f578c3f"}, + {file = "protobuf-4.21.8-cp38-cp38-win32.whl", hash = "sha256:b37b76efe84d539f16cba55ee0036a11ad91300333abd213849cbbbb284b878e"}, + {file = "protobuf-4.21.8-cp38-cp38-win_amd64.whl", hash = "sha256:2c92a7bfcf4ae76a8ac72e545e99a7407e96ffe52934d690eb29a8809ee44d7b"}, + {file = "protobuf-4.21.8-cp39-cp39-win32.whl", hash = "sha256:89d641be4b5061823fa0e463c50a2607a97833e9f8cfb36c2f91ef5ccfcc3861"}, + {file = "protobuf-4.21.8-cp39-cp39-win_amd64.whl", hash = "sha256:bc471cf70a0f53892fdd62f8cd4215f0af8b3f132eeee002c34302dff9edd9b6"}, + {file = "protobuf-4.21.8-py2.py3-none-any.whl", hash = "sha256:a55545ce9eec4030cf100fcb93e861c622d927ef94070c1a3c01922902464278"}, + {file = "protobuf-4.21.8-py3-none-any.whl", hash = "sha256:0f236ce5016becd989bf39bd20761593e6d8298eccd2d878eda33012645dc369"}, + {file = "protobuf-4.21.8.tar.gz", hash = "sha256:427426593b55ff106c84e4a88cac855175330cb6eb7e889e85aaa7b5652b686d"}, ] psutil = [ - {file = "psutil-5.9.2-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:8f024fbb26c8daf5d70287bb3edfafa22283c255287cf523c5d81721e8e5d82c"}, - {file = "psutil-5.9.2-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:b2f248ffc346f4f4f0d747ee1947963613216b06688be0be2e393986fe20dbbb"}, - {file = "psutil-5.9.2-cp27-cp27m-win32.whl", hash = "sha256:b1928b9bf478d31fdffdb57101d18f9b70ed4e9b0e41af751851813547b2a9ab"}, - {file = "psutil-5.9.2-cp27-cp27m-win_amd64.whl", hash = "sha256:404f4816c16a2fcc4eaa36d7eb49a66df2d083e829d3e39ee8759a411dbc9ecf"}, - {file = "psutil-5.9.2-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:94e621c6a4ddb2573d4d30cba074f6d1aa0186645917df42c811c473dd22b339"}, - {file = "psutil-5.9.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:256098b4f6ffea6441eb54ab3eb64db9ecef18f6a80d7ba91549195d55420f84"}, - {file = "psutil-5.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:614337922702e9be37a39954d67fdb9e855981624d8011a9927b8f2d3c9625d9"}, - {file = "psutil-5.9.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39ec06dc6c934fb53df10c1672e299145ce609ff0611b569e75a88f313634969"}, - {file = "psutil-5.9.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3ac2c0375ef498e74b9b4ec56df3c88be43fe56cac465627572dbfb21c4be34"}, - {file = "psutil-5.9.2-cp310-cp310-win32.whl", hash = "sha256:e4c4a7636ffc47b7141864f1c5e7d649f42c54e49da2dd3cceb1c5f5d29bfc85"}, - {file = "psutil-5.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:f4cb67215c10d4657e320037109939b1c1d2fd70ca3d76301992f89fe2edb1f1"}, - {file = "psutil-5.9.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:dc9bda7d5ced744622f157cc8d8bdd51735dafcecff807e928ff26bdb0ff097d"}, - {file = "psutil-5.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75291912b945a7351d45df682f9644540d564d62115d4a20d45fa17dc2d48f8"}, - {file = "psutil-5.9.2-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4018d5f9b6651f9896c7a7c2c9f4652e4eea53f10751c4e7d08a9093ab587ec"}, - {file = "psutil-5.9.2-cp36-cp36m-win32.whl", hash = "sha256:f40ba362fefc11d6bea4403f070078d60053ed422255bd838cd86a40674364c9"}, - {file = "psutil-5.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:9770c1d25aee91417eba7869139d629d6328a9422ce1cdd112bd56377ca98444"}, - {file = "psutil-5.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:42638876b7f5ef43cef8dcf640d3401b27a51ee3fa137cb2aa2e72e188414c32"}, - {file = "psutil-5.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91aa0dac0c64688667b4285fa29354acfb3e834e1fd98b535b9986c883c2ce1d"}, - {file = "psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fb54941aac044a61db9d8eb56fc5bee207db3bc58645d657249030e15ba3727"}, - {file = "psutil-5.9.2-cp37-cp37m-win32.whl", hash = "sha256:7cbb795dcd8ed8fd238bc9e9f64ab188f3f4096d2e811b5a82da53d164b84c3f"}, - {file = "psutil-5.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:5d39e3a2d5c40efa977c9a8dd4f679763c43c6c255b1340a56489955dbca767c"}, - {file = "psutil-5.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd331866628d18223a4265371fd255774affd86244fc307ef66eaf00de0633d5"}, - {file = "psutil-5.9.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b315febaebae813326296872fdb4be92ad3ce10d1d742a6b0c49fb619481ed0b"}, - {file = "psutil-5.9.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7929a516125f62399d6e8e026129c8835f6c5a3aab88c3fff1a05ee8feb840d"}, - {file = "psutil-5.9.2-cp38-cp38-win32.whl", hash = "sha256:561dec454853846d1dd0247b44c2e66a0a0c490f937086930ec4b8f83bf44f06"}, - {file = "psutil-5.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:67b33f27fc0427483b61563a16c90d9f3b547eeb7af0ef1b9fe024cdc9b3a6ea"}, - {file = "psutil-5.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b3591616fa07b15050b2f87e1cdefd06a554382e72866fcc0ab2be9d116486c8"}, - {file = "psutil-5.9.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b29f581b5edab1f133563272a6011925401804d52d603c5c606936b49c8b97"}, - {file = "psutil-5.9.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4642fd93785a29353d6917a23e2ac6177308ef5e8be5cc17008d885cb9f70f12"}, - {file = "psutil-5.9.2-cp39-cp39-win32.whl", hash = "sha256:ed29ea0b9a372c5188cdb2ad39f937900a10fb5478dc077283bf86eeac678ef1"}, - {file = "psutil-5.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:68b35cbff92d1f7103d8f1db77c977e72f49fcefae3d3d2b91c76b0e7aef48b8"}, - {file = "psutil-5.9.2.tar.gz", hash = "sha256:feb861a10b6c3bb00701063b37e4afc754f8217f0f09c42280586bd6ac712b5c"}, + {file = "psutil-5.9.3-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:b4a247cd3feaae39bb6085fcebf35b3b8ecd9b022db796d89c8f05067ca28e71"}, + {file = "psutil-5.9.3-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:5fa88e3d5d0b480602553d362c4b33a63e0c40bfea7312a7bf78799e01e0810b"}, + {file = "psutil-5.9.3-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:767ef4fa33acda16703725c0473a91e1832d296c37c63896c7153ba81698f1ab"}, + {file = "psutil-5.9.3-cp27-cp27m-win32.whl", hash = "sha256:9a4af6ed1094f867834f5f07acd1250605a0874169a5fcadbcec864aec2496a6"}, + {file = "psutil-5.9.3-cp27-cp27m-win_amd64.whl", hash = "sha256:fa5e32c7d9b60b2528108ade2929b115167fe98d59f89555574715054f50fa31"}, + {file = "psutil-5.9.3-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:fe79b4ad4836e3da6c4650cb85a663b3a51aef22e1a829c384e18fae87e5e727"}, + {file = "psutil-5.9.3-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:db8e62016add2235cc87fb7ea000ede9e4ca0aa1f221b40cef049d02d5d2593d"}, + {file = "psutil-5.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:941a6c2c591da455d760121b44097781bc970be40e0e43081b9139da485ad5b7"}, + {file = "psutil-5.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:71b1206e7909792d16933a0d2c1c7f04ae196186c51ba8567abae1d041f06dcb"}, + {file = "psutil-5.9.3-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f57d63a2b5beaf797b87024d018772439f9d3103a395627b77d17a8d72009543"}, + {file = "psutil-5.9.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7507f6c7b0262d3e7b0eeda15045bf5881f4ada70473b87bc7b7c93b992a7d7"}, + {file = "psutil-5.9.3-cp310-cp310-win32.whl", hash = "sha256:1b540599481c73408f6b392cdffef5b01e8ff7a2ac8caae0a91b8222e88e8f1e"}, + {file = "psutil-5.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:547ebb02031fdada635452250ff39942db8310b5c4a8102dfe9384ee5791e650"}, + {file = "psutil-5.9.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d8c3cc6bb76492133474e130a12351a325336c01c96a24aae731abf5a47fe088"}, + {file = "psutil-5.9.3-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07d880053c6461c9b89cd5d4808f3b8336665fa3acdefd6777662c5ed73a851a"}, + {file = "psutil-5.9.3-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e8b50241dd3c2ed498507f87a6602825073c07f3b7e9560c58411c14fe1e1c9"}, + {file = "psutil-5.9.3-cp36-cp36m-win32.whl", hash = "sha256:828c9dc9478b34ab96be75c81942d8df0c2bb49edbb481f597314d92b6441d89"}, + {file = "psutil-5.9.3-cp36-cp36m-win_amd64.whl", hash = "sha256:ed15edb14f52925869250b1375f0ff58ca5c4fa8adefe4883cfb0737d32f5c02"}, + {file = "psutil-5.9.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d266cd05bd4a95ca1c2b9b5aac50d249cf7c94a542f47e0b22928ddf8b80d1ef"}, + {file = "psutil-5.9.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e4939ff75149b67aef77980409f156f0082fa36accc475d45c705bb00c6c16a"}, + {file = "psutil-5.9.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68fa227c32240c52982cb931801c5707a7f96dd8927f9102d6c7771ea1ff5698"}, + {file = "psutil-5.9.3-cp37-cp37m-win32.whl", hash = "sha256:beb57d8a1ca0ae0eb3d08ccaceb77e1a6d93606f0e1754f0d60a6ebd5c288837"}, + {file = "psutil-5.9.3-cp37-cp37m-win_amd64.whl", hash = "sha256:12500d761ac091f2426567f19f95fd3f15a197d96befb44a5c1e3cbe6db5752c"}, + {file = "psutil-5.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba38cf9984d5462b506e239cf4bc24e84ead4b1d71a3be35e66dad0d13ded7c1"}, + {file = "psutil-5.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:46907fa62acaac364fff0b8a9da7b360265d217e4fdeaca0a2397a6883dffba2"}, + {file = "psutil-5.9.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a04a1836894c8279e5e0a0127c0db8e198ca133d28be8a2a72b4db16f6cf99c1"}, + {file = "psutil-5.9.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a4e07611997acf178ad13b842377e3d8e9d0a5bac43ece9bfc22a96735d9a4f"}, + {file = "psutil-5.9.3-cp38-cp38-win32.whl", hash = "sha256:6ced1ad823ecfa7d3ce26fe8aa4996e2e53fb49b7fed8ad81c80958501ec0619"}, + {file = "psutil-5.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:35feafe232d1aaf35d51bd42790cbccb882456f9f18cdc411532902370d660df"}, + {file = "psutil-5.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:538fcf6ae856b5e12d13d7da25ad67f02113c96f5989e6ad44422cb5994ca7fc"}, + {file = "psutil-5.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a3d81165b8474087bb90ec4f333a638ccfd1d69d34a9b4a1a7eaac06648f9fbe"}, + {file = "psutil-5.9.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a7826e68b0cf4ce2c1ee385d64eab7d70e3133171376cac53d7c1790357ec8f"}, + {file = "psutil-5.9.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ec296f565191f89c48f33d9544d8d82b0d2af7dd7d2d4e6319f27a818f8d1cc"}, + {file = "psutil-5.9.3-cp39-cp39-win32.whl", hash = "sha256:9ec95df684583b5596c82bb380c53a603bb051cf019d5c849c47e117c5064395"}, + {file = "psutil-5.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:4bd4854f0c83aa84a5a40d3b5d0eb1f3c128f4146371e03baed4589fe4f3c931"}, + {file = "psutil-5.9.3.tar.gz", hash = "sha256:7ccfcdfea4fc4b0a02ca2c31de7fcd186beb9cff8207800e14ab66f79c773af6"}, ] pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, @@ -453,8 +457,8 @@ PyYAML = [ {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] setuptools = [ - {file = "setuptools-65.4.1-py3-none-any.whl", hash = "sha256:1b6bdc6161661409c5f21508763dc63ab20a9ac2f8ba20029aaaa7fdb9118012"}, - {file = "setuptools-65.4.1.tar.gz", hash = "sha256:3050e338e5871e70c72983072fe34f6032ae1cdeeeb67338199c2f74e083a80e"}, + {file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"}, + {file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, @@ -487,6 +491,6 @@ typer = [ {file = "typer-0.6.1.tar.gz", hash = "sha256:2d5720a5e63f73eaf31edaa15f6ab87f35f0690f8ca233017d7d23d743a91d73"}, ] typing-extensions = [ - {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, - {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, + {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, + {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ]