feat(server): enable hf-transfer (#76)

This commit is contained in:
OlivierDehaene 2023-02-18 14:04:11 +01:00 committed by GitHub
parent 6796d38c6d
commit 17bc841b1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 27 deletions

View File

@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \ LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \ DEBIAN_FRONTEND=noninteractive \
HUGGINGFACE_HUB_CACHE=/data \ HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
MODEL_ID=bigscience/bloom-560m \ MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \ QUANTIZE=false \
NUM_SHARD=1 \ NUM_SHARD=1 \

View File

@ -98,23 +98,18 @@ fn main() -> ExitCode {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Download weights // Download weights for sharded models
if weights_cache_override.is_none() { if weights_cache_override.is_none() && num_shard > 1 {
let mut download_argv = vec![ let mut download_argv = vec![
"text-generation-server".to_string(), "text-generation-server".to_string(),
"download-weights".to_string(), "download-weights".to_string(),
model_id.clone(), model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(), "--logger-level".to_string(),
"INFO".to_string(), "INFO".to_string(),
"--json-output".to_string(), "--json-output".to_string(),
]; ];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision // Model optional revision
if let Some(ref revision) = revision { if let Some(ref revision) = revision {
@ -131,6 +126,9 @@ fn main() -> ExitCode {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// Start process // Start process
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Popen::create( let mut download_process = match Popen::create(
@ -209,12 +207,6 @@ fn main() -> ExitCode {
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
} }
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
} }
// Shared shutdown bool // Shared shutdown bool
@ -479,6 +471,9 @@ fn shard_manager(
// Safetensors load fast // Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// If huggingface_hub_cache is some, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache { if let Some(huggingface_hub_cache) = huggingface_hub_cache {

32
server/poetry.lock generated
View File

@ -192,6 +192,14 @@ grpcio = ">=1.51.1"
protobuf = ">=4.21.6,<5.0dev" protobuf = ">=4.21.6,<5.0dev"
setuptools = "*" setuptools = "*"
[[package]]
name = "hf-transfer"
version = "0.1.0"
description = ""
category = "main"
optional = false
python-versions = ">=3.7"
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.4" version = "3.4"
@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321" content-hash = "ef6da62cff76be3eeb45eac98326d6e4fac5d35796b8bdcf555575323ce97ba2"
[metadata.files] [metadata.files]
accelerate = [ accelerate = [
@ -861,6 +869,28 @@ grpcio-tools = [
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"}, {file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"}, {file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
] ]
hf-transfer = [
{file = "hf_transfer-0.1.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0f41bb04898d041b774220048f237d10560ec27e1decd01a04d323c64202e8fe"},
{file = "hf_transfer-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94510d4e3a66aa5afa06b61ff537895c3f1b93d689575a8a840ea9ec3189d3d8"},
{file = "hf_transfer-0.1.0-cp310-none-win_amd64.whl", hash = "sha256:e85134084c7e9e9daa74331c4690f821a30afedacb97355d1a66c243317c3e7a"},
{file = "hf_transfer-0.1.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3441b0cba24afad7fffbcfc0eb7e31d3df127d092feeee73e5b08bb5752c903b"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14938d44c71a07b452612a90499b5f021b2277e1c93c66f60d06d746b2d0661d"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:ae01b9844995622beee0f1b7ff0240e269bfc28ea46149eb4abbf63b4683f3e2"},
{file = "hf_transfer-0.1.0-cp311-none-win_amd64.whl", hash = "sha256:79e9505bffd3a1086be13033a805c8e6f4bb763de03a4197b959984def587e7f"},
{file = "hf_transfer-0.1.0-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:5656deb183e271d37925de0d6989d7a4b1eefae42d771f10907f41fce08bdada"},
{file = "hf_transfer-0.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb16ec07f1ad343b7189745b6f659b7f82b864a55d3b84fe34361f64e4abc76"},
{file = "hf_transfer-0.1.0-cp37-none-win_amd64.whl", hash = "sha256:5314b708bc2a8cf844885d350cd13ba0b528466d3eb9766e4d8d39d080e718c0"},
{file = "hf_transfer-0.1.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe245d0d84bbc113870144c56b50425f9b6cacc3e361b3559b0786ac076ba260"},
{file = "hf_transfer-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49099d41c05a19202dca0306bfa7f42cedaea57ccc783641b1533de860b6a1f4"},
{file = "hf_transfer-0.1.0-cp38-none-win_amd64.whl", hash = "sha256:0d7bb607a7372908ffa2d55f1e6790430c5706434c2d1d664db4928730c2c7e4"},
{file = "hf_transfer-0.1.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:70f40927bc4a19ab50605bb542bd3858eb465ad65c94cfcaf36cf36d68fc5169"},
{file = "hf_transfer-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92203155f451a9b517267d2a0966b282615037e286ff0420a2963f67d451de3"},
{file = "hf_transfer-0.1.0-cp39-none-win_amd64.whl", hash = "sha256:c1c2799154e4bd03d2b2e2907d494005f707686b80d5aa4421c859ffa612ffa3"},
{file = "hf_transfer-0.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0900bd5698a77fb44c95639eb3ec97202d13e1bd4282cde5c81ed48e3f9341e"},
{file = "hf_transfer-0.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:427032df4e83a1bedaa76383a5b825cf779fdfc206681c28b21476bc84089280"},
{file = "hf_transfer-0.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36fe371af6e31795621ffab53a2dcb107088fcfeb2951eaa2a10bfdc8b8863b"},
{file = "hf_transfer-0.1.0.tar.gz", hash = "sha256:f692ef717ded50e441b4d40b6aea625772f68b90414aeef86bb33eab40cb09a4"},
]
idna = [ idna = [
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},

View File

@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0" opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.0"
[tool.poetry.extras] [tool.poetry.extras]
bnb = ["bitsandbytes"] bnb = ["bitsandbytes"]

View File

@ -1,8 +1,6 @@
import time import time
import concurrent
import os import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
@ -147,20 +145,17 @@ def download_weights(
) )
return Path(local_file) return Path(local_file)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time() start_time = time.time()
files = [] files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)): for i, filename in enumerate(filenames):
file = download_file(filename)
elapsed = timedelta(seconds=int(time.time() - start_time)) elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1) remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0 eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}") logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(future.result()) files.append(file)
return files return files