feat(server): enable hf-transfer (#76)

This commit is contained in:
OlivierDehaene 2023-02-18 14:04:11 +01:00 committed by GitHub
parent 6796d38c6d
commit 17bc841b1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 27 deletions

View File

@@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \
HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \
NUM_SHARD=1 \

View File

@@ -98,23 +98,18 @@ fn main() -> ExitCode {
})
.expect("Error setting Ctrl-C handler");
// Download weights
if weights_cache_override.is_none() {
// Download weights for sharded models
if weights_cache_override.is_none() && num_shard > 1 {
let mut download_argv = vec![
"text-generation-server".to_string(),
"download-weights".to_string(),
model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision
if let Some(ref revision) = revision {
@@ -131,6 +126,9 @@ fn main() -> ExitCode {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// Start process
tracing::info!("Starting download process.");
let mut download_process = match Popen::create(
@@ -209,12 +207,6 @@ fn main() -> ExitCode {
}
sleep(Duration::from_millis(100));
}
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
}
// Shared shutdown bool
@@ -479,6 +471,9 @@ fn shard_manager(
// Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {

32
server/poetry.lock generated
View File

@@ -192,6 +192,14 @@ grpcio = ">=1.51.1"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "hf-transfer"
version = "0.1.0"
description = ""
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "idna"
version = "3.4"
@@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321"
content-hash = "ef6da62cff76be3eeb45eac98326d6e4fac5d35796b8bdcf555575323ce97ba2"
[metadata.files]
accelerate = [
@@ -861,6 +869,28 @@ grpcio-tools = [
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
]
hf-transfer = [
{file = "hf_transfer-0.1.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0f41bb04898d041b774220048f237d10560ec27e1decd01a04d323c64202e8fe"},
{file = "hf_transfer-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94510d4e3a66aa5afa06b61ff537895c3f1b93d689575a8a840ea9ec3189d3d8"},
{file = "hf_transfer-0.1.0-cp310-none-win_amd64.whl", hash = "sha256:e85134084c7e9e9daa74331c4690f821a30afedacb97355d1a66c243317c3e7a"},
{file = "hf_transfer-0.1.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3441b0cba24afad7fffbcfc0eb7e31d3df127d092feeee73e5b08bb5752c903b"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14938d44c71a07b452612a90499b5f021b2277e1c93c66f60d06d746b2d0661d"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:ae01b9844995622beee0f1b7ff0240e269bfc28ea46149eb4abbf63b4683f3e2"},
{file = "hf_transfer-0.1.0-cp311-none-win_amd64.whl", hash = "sha256:79e9505bffd3a1086be13033a805c8e6f4bb763de03a4197b959984def587e7f"},
{file = "hf_transfer-0.1.0-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:5656deb183e271d37925de0d6989d7a4b1eefae42d771f10907f41fce08bdada"},
{file = "hf_transfer-0.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb16ec07f1ad343b7189745b6f659b7f82b864a55d3b84fe34361f64e4abc76"},
{file = "hf_transfer-0.1.0-cp37-none-win_amd64.whl", hash = "sha256:5314b708bc2a8cf844885d350cd13ba0b528466d3eb9766e4d8d39d080e718c0"},
{file = "hf_transfer-0.1.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe245d0d84bbc113870144c56b50425f9b6cacc3e361b3559b0786ac076ba260"},
{file = "hf_transfer-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49099d41c05a19202dca0306bfa7f42cedaea57ccc783641b1533de860b6a1f4"},
{file = "hf_transfer-0.1.0-cp38-none-win_amd64.whl", hash = "sha256:0d7bb607a7372908ffa2d55f1e6790430c5706434c2d1d664db4928730c2c7e4"},
{file = "hf_transfer-0.1.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:70f40927bc4a19ab50605bb542bd3858eb465ad65c94cfcaf36cf36d68fc5169"},
{file = "hf_transfer-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92203155f451a9b517267d2a0966b282615037e286ff0420a2963f67d451de3"},
{file = "hf_transfer-0.1.0-cp39-none-win_amd64.whl", hash = "sha256:c1c2799154e4bd03d2b2e2907d494005f707686b80d5aa4421c859ffa612ffa3"},
{file = "hf_transfer-0.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0900bd5698a77fb44c95639eb3ec97202d13e1bd4282cde5c81ed48e3f9341e"},
{file = "hf_transfer-0.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:427032df4e83a1bedaa76383a5b825cf779fdfc206681c28b21476bc84089280"},
{file = "hf_transfer-0.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36fe371af6e31795621ffab53a2dcb107088fcfeb2951eaa2a10bfdc8b8863b"},
{file = "hf_transfer-0.1.0.tar.gz", hash = "sha256:f692ef717ded50e441b4d40b6aea625772f68b90414aeef86bb33eab40cb09a4"},
]
idna = [
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},

View File

@@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.0"
[tool.poetry.extras]
bnb = ["bitsandbytes"]

View File

@@ -1,8 +1,6 @@
import time
import concurrent
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from loguru import logger
from pathlib import Path
@@ -147,20 +145,17 @@ def download_weights(
)
return Path(local_file)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()
files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)):
for i, filename in enumerate(filenames):
file = download_file(filename)
elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1)
remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
files.append(future.result())
logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(file)
return files