feat(server): enable hf-transfer (#76)
This commit is contained in:
parent
6796d38c6d
commit
17bc841b1b
|
@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
|
|||
LC_ALL=C.UTF-8 \
|
||||
DEBIAN_FRONTEND=noninteractive \
|
||||
HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
MODEL_ID=bigscience/bloom-560m \
|
||||
QUANTIZE=false \
|
||||
NUM_SHARD=1 \
|
||||
|
|
|
@ -98,23 +98,18 @@ fn main() -> ExitCode {
|
|||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Download weights
|
||||
if weights_cache_override.is_none() {
|
||||
// Download weights for sharded models
|
||||
if weights_cache_override.is_none() && num_shard > 1 {
|
||||
let mut download_argv = vec![
|
||||
"text-generation-server".to_string(),
|
||||
"download-weights".to_string(),
|
||||
model_id.clone(),
|
||||
"--extension".to_string(),
|
||||
".safetensors".to_string(),
|
||||
"--logger-level".to_string(),
|
||||
"INFO".to_string(),
|
||||
"--json-output".to_string(),
|
||||
];
|
||||
if num_shard == 1 {
|
||||
download_argv.push("--extension".to_string());
|
||||
download_argv.push(".bin".to_string());
|
||||
} else {
|
||||
download_argv.push("--extension".to_string());
|
||||
download_argv.push(".safetensors".to_string());
|
||||
}
|
||||
|
||||
// Model optional revision
|
||||
if let Some(ref revision) = revision {
|
||||
|
@ -131,6 +126,9 @@ fn main() -> ExitCode {
|
|||
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
|
||||
};
|
||||
|
||||
// Enable hf transfer for insane download speeds
|
||||
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
|
||||
|
||||
// Start process
|
||||
tracing::info!("Starting download process.");
|
||||
let mut download_process = match Popen::create(
|
||||
|
@ -209,12 +207,6 @@ fn main() -> ExitCode {
|
|||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
} else {
|
||||
tracing::info!(
|
||||
"weights_cache_override is set to {:?}.",
|
||||
weights_cache_override
|
||||
);
|
||||
tracing::info!("Skipping download.")
|
||||
}
|
||||
|
||||
// Shared shutdown bool
|
||||
|
@ -479,6 +471,9 @@ fn shard_manager(
|
|||
// Safetensors load fast
|
||||
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
|
||||
|
||||
// Enable hf transfer for insane download speeds
|
||||
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
|
|
|
@ -192,6 +192,14 @@ grpcio = ">=1.51.1"
|
|||
protobuf = ">=4.21.6,<5.0dev"
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "hf-transfer"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.4"
|
||||
|
@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321"
|
||||
content-hash = "ef6da62cff76be3eeb45eac98326d6e4fac5d35796b8bdcf555575323ce97ba2"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
|
@ -861,6 +869,28 @@ grpcio-tools = [
|
|||
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
|
||||
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
|
||||
]
|
||||
hf-transfer = [
|
||||
{file = "hf_transfer-0.1.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0f41bb04898d041b774220048f237d10560ec27e1decd01a04d323c64202e8fe"},
|
||||
{file = "hf_transfer-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94510d4e3a66aa5afa06b61ff537895c3f1b93d689575a8a840ea9ec3189d3d8"},
|
||||
{file = "hf_transfer-0.1.0-cp310-none-win_amd64.whl", hash = "sha256:e85134084c7e9e9daa74331c4690f821a30afedacb97355d1a66c243317c3e7a"},
|
||||
{file = "hf_transfer-0.1.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3441b0cba24afad7fffbcfc0eb7e31d3df127d092feeee73e5b08bb5752c903b"},
|
||||
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14938d44c71a07b452612a90499b5f021b2277e1c93c66f60d06d746b2d0661d"},
|
||||
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:ae01b9844995622beee0f1b7ff0240e269bfc28ea46149eb4abbf63b4683f3e2"},
|
||||
{file = "hf_transfer-0.1.0-cp311-none-win_amd64.whl", hash = "sha256:79e9505bffd3a1086be13033a805c8e6f4bb763de03a4197b959984def587e7f"},
|
||||
{file = "hf_transfer-0.1.0-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:5656deb183e271d37925de0d6989d7a4b1eefae42d771f10907f41fce08bdada"},
|
||||
{file = "hf_transfer-0.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb16ec07f1ad343b7189745b6f659b7f82b864a55d3b84fe34361f64e4abc76"},
|
||||
{file = "hf_transfer-0.1.0-cp37-none-win_amd64.whl", hash = "sha256:5314b708bc2a8cf844885d350cd13ba0b528466d3eb9766e4d8d39d080e718c0"},
|
||||
{file = "hf_transfer-0.1.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe245d0d84bbc113870144c56b50425f9b6cacc3e361b3559b0786ac076ba260"},
|
||||
{file = "hf_transfer-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49099d41c05a19202dca0306bfa7f42cedaea57ccc783641b1533de860b6a1f4"},
|
||||
{file = "hf_transfer-0.1.0-cp38-none-win_amd64.whl", hash = "sha256:0d7bb607a7372908ffa2d55f1e6790430c5706434c2d1d664db4928730c2c7e4"},
|
||||
{file = "hf_transfer-0.1.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:70f40927bc4a19ab50605bb542bd3858eb465ad65c94cfcaf36cf36d68fc5169"},
|
||||
{file = "hf_transfer-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92203155f451a9b517267d2a0966b282615037e286ff0420a2963f67d451de3"},
|
||||
{file = "hf_transfer-0.1.0-cp39-none-win_amd64.whl", hash = "sha256:c1c2799154e4bd03d2b2e2907d494005f707686b80d5aa4421c859ffa612ffa3"},
|
||||
{file = "hf_transfer-0.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0900bd5698a77fb44c95639eb3ec97202d13e1bd4282cde5c81ed48e3f9341e"},
|
||||
{file = "hf_transfer-0.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:427032df4e83a1bedaa76383a5b825cf779fdfc206681c28b21476bc84089280"},
|
||||
{file = "hf_transfer-0.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36fe371af6e31795621ffab53a2dcb107088fcfeb2951eaa2a10bfdc8b8863b"},
|
||||
{file = "hf_transfer-0.1.0.tar.gz", hash = "sha256:f692ef717ded50e441b4d40b6aea625772f68b90414aeef86bb33eab40cb09a4"},
|
||||
]
|
||||
idna = [
|
||||
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
|
||||
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
|
||||
|
|
|
@ -22,6 +22,7 @@ loguru = "^0.6.0"
|
|||
opentelemetry-api = "^1.15.0"
|
||||
opentelemetry-exporter-otlp = "^1.15.0"
|
||||
opentelemetry-instrumentation-grpc = "^0.36b0"
|
||||
hf-transfer = "^0.1.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
bnb = ["bitsandbytes"]
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
import time
|
||||
import concurrent
|
||||
import os
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
@ -147,20 +145,17 @@ def download_weights(
|
|||
)
|
||||
return Path(local_file)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=5)
|
||||
futures = [
|
||||
executor.submit(download_file, filename=filename) for filename in filenames
|
||||
]
|
||||
|
||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||
start_time = time.time()
|
||||
files = []
|
||||
for i, future in enumerate(concurrent.futures.as_completed(futures)):
|
||||
for i, filename in enumerate(filenames):
|
||||
file = download_file(filename)
|
||||
|
||||
elapsed = timedelta(seconds=int(time.time() - start_time))
|
||||
remaining = len(futures) - (i + 1)
|
||||
remaining = len(filenames) - (i + 1)
|
||||
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
|
||||
|
||||
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
|
||||
files.append(future.result())
|
||||
logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
|
Loading…
Reference in New Issue