diff --git a/Dockerfile b/Dockerfile index 907379dc..228909dd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \ LC_ALL=C.UTF-8 \ DEBIAN_FRONTEND=noninteractive \ HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ MODEL_ID=bigscience/bloom-560m \ QUANTIZE=false \ NUM_SHARD=1 \ diff --git a/launcher/src/main.rs b/launcher/src/main.rs index ac118566..218e3f3a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -98,23 +98,18 @@ fn main() -> ExitCode { }) .expect("Error setting Ctrl-C handler"); - // Download weights - if weights_cache_override.is_none() { + // Download weights for sharded models + if weights_cache_override.is_none() && num_shard > 1 { let mut download_argv = vec![ "text-generation-server".to_string(), "download-weights".to_string(), model_id.clone(), + "--extension".to_string(), + ".safetensors".to_string(), "--logger-level".to_string(), "INFO".to_string(), "--json-output".to_string(), ]; - if num_shard == 1 { - download_argv.push("--extension".to_string()); - download_argv.push(".bin".to_string()); - } else { - download_argv.push("--extension".to_string()); - download_argv.push(".safetensors".to_string()); - } // Model optional revision if let Some(ref revision) = revision { @@ -131,6 +126,9 @@ fn main() -> ExitCode { env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); }; + // Enable hf transfer for insane download speeds + env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into())); + // Start process tracing::info!("Starting download process."); let mut download_process = match Popen::create( @@ -209,12 +207,6 @@ fn main() -> ExitCode { } sleep(Duration::from_millis(100)); } - } else { - tracing::info!( - "weights_cache_override is set to {:?}.", - weights_cache_override - ); - tracing::info!("Skipping download.") } // Shared shutdown bool @@ -479,6 +471,9 @@ fn shard_manager( // Safetensors load fast env.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); + // Enable hf transfer for insane download speeds + env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into())); + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container if let Some(huggingface_hub_cache) = huggingface_hub_cache { diff --git a/server/poetry.lock b/server/poetry.lock index 0e0655cb..03b97512 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -192,6 +192,14 @@ grpcio = ">=1.51.1" protobuf = ">=4.21.6,<5.0dev" setuptools = "*" +[[package]] +name = "hf-transfer" +version = "0.1.0" +description = "" +category = "main" +optional = false +python-versions = ">=3.7" + [[package]] name = "idna" version = "3.4" @@ -622,7 +630,7 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321" +content-hash = "ef6da62cff76be3eeb45eac98326d6e4fac5d35796b8bdcf555575323ce97ba2" [metadata.files] accelerate = [ @@ -861,6 +869,28 @@ grpcio-tools = [ {file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"}, {file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"}, ] +hf-transfer = [ + {file = "hf_transfer-0.1.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0f41bb04898d041b774220048f237d10560ec27e1decd01a04d323c64202e8fe"}, + {file = "hf_transfer-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94510d4e3a66aa5afa06b61ff537895c3f1b93d689575a8a840ea9ec3189d3d8"}, + {file = "hf_transfer-0.1.0-cp310-none-win_amd64.whl", hash = "sha256:e85134084c7e9e9daa74331c4690f821a30afedacb97355d1a66c243317c3e7a"}, + {file = "hf_transfer-0.1.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3441b0cba24afad7fffbcfc0eb7e31d3df127d092feeee73e5b08bb5752c903b"}, + {file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14938d44c71a07b452612a90499b5f021b2277e1c93c66f60d06d746b2d0661d"}, + {file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:ae01b9844995622beee0f1b7ff0240e269bfc28ea46149eb4abbf63b4683f3e2"}, + {file = "hf_transfer-0.1.0-cp311-none-win_amd64.whl", hash = "sha256:79e9505bffd3a1086be13033a805c8e6f4bb763de03a4197b959984def587e7f"}, + {file = "hf_transfer-0.1.0-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:5656deb183e271d37925de0d6989d7a4b1eefae42d771f10907f41fce08bdada"}, + {file = "hf_transfer-0.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb16ec07f1ad343b7189745b6f659b7f82b864a55d3b84fe34361f64e4abc76"}, + {file = "hf_transfer-0.1.0-cp37-none-win_amd64.whl", hash = "sha256:5314b708bc2a8cf844885d350cd13ba0b528466d3eb9766e4d8d39d080e718c0"}, + {file = "hf_transfer-0.1.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe245d0d84bbc113870144c56b50425f9b6cacc3e361b3559b0786ac076ba260"}, + {file = "hf_transfer-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49099d41c05a19202dca0306bfa7f42cedaea57ccc783641b1533de860b6a1f4"}, + {file = "hf_transfer-0.1.0-cp38-none-win_amd64.whl", hash = "sha256:0d7bb607a7372908ffa2d55f1e6790430c5706434c2d1d664db4928730c2c7e4"}, + {file = "hf_transfer-0.1.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:70f40927bc4a19ab50605bb542bd3858eb465ad65c94cfcaf36cf36d68fc5169"}, + {file = "hf_transfer-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92203155f451a9b517267d2a0966b282615037e286ff0420a2963f67d451de3"}, + {file = "hf_transfer-0.1.0-cp39-none-win_amd64.whl", hash = "sha256:c1c2799154e4bd03d2b2e2907d494005f707686b80d5aa4421c859ffa612ffa3"}, + {file = "hf_transfer-0.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0900bd5698a77fb44c95639eb3ec97202d13e1bd4282cde5c81ed48e3f9341e"}, + {file = "hf_transfer-0.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:427032df4e83a1bedaa76383a5b825cf779fdfc206681c28b21476bc84089280"}, + {file = "hf_transfer-0.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36fe371af6e31795621ffab53a2dcb107088fcfeb2951eaa2a10bfdc8b8863b"}, + {file = "hf_transfer-0.1.0.tar.gz", hash = "sha256:f692ef717ded50e441b4d40b6aea625772f68b90414aeef86bb33eab40cb09a4"}, +] idna = [ {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index 4722f703..025aebf3 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -22,6 +22,7 @@ loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" +hf-transfer = "^0.1.0" [tool.poetry.extras] bnb = ["bitsandbytes"] diff --git a/server/text_generation/utils/hub.py b/server/text_generation/utils/hub.py index 713488b5..7166b4cb 100644 --- a/server/text_generation/utils/hub.py +++ b/server/text_generation/utils/hub.py @@ -1,8 +1,6 @@ import time -import concurrent import os -from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from loguru import logger from pathlib import Path @@ -147,20 +145,17 @@ def download_weights( ) return Path(local_file) - executor = ThreadPoolExecutor(max_workers=5) - futures = [ - executor.submit(download_file, filename=filename) for filename in filenames - ] - # We do this instead of using tqdm because we want to parse the logs with the launcher start_time = time.time() files = [] - for i, future in enumerate(concurrent.futures.as_completed(futures)): + for i, filename in enumerate(filenames): + file = download_file(filename) + elapsed = timedelta(seconds=int(time.time() - start_time)) - remaining = len(futures) - (i + 1) + remaining = len(filenames) - (i + 1) eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0 - logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}") - files.append(future.result()) + logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}") + files.append(file) return files