185 lines
6.7 KiB
Python
185 lines
6.7 KiB
Python
import huggingface_hub
|
|
import argparse
|
|
import shutil
|
|
import time
|
|
|
|
# Models the CI needs in the local Hugging Face cache, mapped to the revision
# to fetch for each. Every value is passed as `revision=` to the Hub API, so it
# can be a branch/tag name (e.g. "main", "2.5bpw") or a ref (e.g. "refs/pr/1").
# Iteration order is the order models are crawled and downloaded in.
REQUIRED_MODELS = {
    "bigscience/bloom-560m": "main",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "main",
    "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq": "main",
    "tiiuae/falcon-7b": "main",
    "TechxGenus/gemma-2b-GPTQ": "main",
    "google/gemma-2b": "main",
    "openai-community/gpt2": "main",
    "turboderp/Llama-3-8B-Instruct-exl2": "2.5bpw",
    "huggingface/llama-7b-gptq": "main",
    "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit": "main",
    "huggingface/llama-7b": "main",
    "FasterDecoding/medusa-vicuna-7b-v1.3": "refs/pr/1",
    "mistralai/Mistral-7B-Instruct-v0.1": "main",
    "OpenAssistant/oasst-sft-1-pythia-12b": "main",
    "stabilityai/stablelm-tuned-alpha-3b": "main",
    "google/paligemma-3b-pt-224": "main",
    "microsoft/phi-2": "main",
    "Qwen/Qwen1.5-0.5B": "main",
    "bigcode/starcoder": "main",
    "Narsil/starcoder-gptq": "main",
    "bigcode/starcoder2-3b": "main",
    "HuggingFaceM4/idefics-9b-instruct": "main",
    "HuggingFaceM4/idefics2-8b": "main",
    "llava-hf/llava-v1.6-mistral-7b-hf": "main",
    "state-spaces/mamba-130m": "main",
    "mosaicml/mpt-7b": "main",
    "bigscience/mt0-base": "main",
    "google/flan-t5-xxl": "main",
    "lmsys/vicuna-7b-v1.3": "main",
}
|
|
|
|
|
|
def _detect_extension(filenames):
    """Return the preferred checkpoint extension present in ``filenames``.

    Preference order is ".safetensors" > ".pt" > ".bin". A file only counts
    when its name *ends* with the extension. That is the same rule used to
    select checkpoint files and to build the ``*<ext>`` download pattern, so
    the detected extension always matches at least one file. A plain
    substring test would let e.g. "model.pth" select ".pt".

    Returns None when no file carries any of these extensions.
    """
    for candidate in (".safetensors", ".pt", ".bin"):
        if any(name.endswith(candidate) for name in filenames):
            return candidate
    return None


def _crawl_required_checkpoints(token):
    """Query the Hub for the checkpoint files of every model in REQUIRED_MODELS.

    Args:
        token: Hugging Face Hub token used for the API calls.

    Returns:
        ``(extension_per_model, checkpoints_per_model)`` where
        ``extension_per_model`` maps model id -> detected extension (or None),
        and ``checkpoints_per_model`` maps model id -> {filename: size in GB}.
    """
    extension_per_model = {}
    checkpoints_per_model = {}
    for model_id, revision in REQUIRED_MODELS.items():
        print(f"Crawling {model_id}...")
        all_files = huggingface_hub.list_repo_files(
            model_id,
            repo_type="model",
            revision=revision,
            token=token,
        )
        extension = _detect_extension(all_files)
        extension_per_model[model_id] = extension

        checkpoints = {}
        if extension is None:
            print(f"No checkpoint file found for {model_id}, skipping.")
        else:
            for filename in all_files:
                if not filename.endswith(extension):
                    continue
                file_url = huggingface_hub.hf_hub_url(
                    model_id, filename, revision=revision
                )
                file_metadata = huggingface_hub.get_hf_file_metadata(
                    file_url, token=token
                )
                # `size` is Optional on the returned metadata; count unknown as 0.
                checkpoints[filename] = (file_metadata.size or 0) * 1e-9  # in GB
        checkpoints_per_model[model_id] = checkpoints
    return extension_per_model, checkpoints_per_model


def _scan_cached_models(cache_dir, checkpoints_per_model):
    """Split the local HF cache into required models and evictable models.

    Args:
        cache_dir: Path of the local HF hub cache.
        checkpoints_per_model: model id -> {filename: size in GB}, as returned
            by ``_crawl_required_checkpoints``.

    Returns:
        ``(cached_required_size_per_model, cache_size_per_model,
        cached_shas_per_model)``:
        - GB of each required model's checkpoints already cached at its
          required revision;
        - GB on disk of each cached non-required model repo;
        - the cached commit hashes of each non-required model repo.
    """
    cached_required_size_per_model = {}
    cache_size_per_model = {}
    cached_shas_per_model = {}
    for repo in huggingface_hub.scan_cache_dir(cache_dir).repos:
        if repo.repo_id in REQUIRED_MODELS:
            cached_size = 0
            for checkpoint, size in checkpoints_per_model[repo.repo_id].items():
                filepath = huggingface_hub.try_to_load_from_cache(
                    repo.repo_id,
                    checkpoint,
                    cache_dir=cache_dir,
                    revision=REQUIRED_MODELS[repo.repo_id],
                )
                # A str return value is the path of the cached file (cache hit).
                if isinstance(filepath, str):
                    cached_size += size
            cached_required_size_per_model[repo.repo_id] = cached_size
        elif repo.repo_type == "model":
            cache_size_per_model[repo.repo_id] = repo.size_on_disk * 1e-9  # in GB
            cached_shas_per_model[repo.repo_id] = [
                rev.commit_hash for rev in repo.revisions
            ]
    return cached_required_size_per_model, cache_size_per_model, cached_shas_per_model


def cleanup_cache(token: str, cache_dir: str) -> dict:
    """Make room in the HF cache for the checkpoints listed in REQUIRED_MODELS.

    Queries the Hub for the size of each required model's checkpoint files and
    compares it with what is already cached at the required revision. It then
    deletes cached non-required model repos, largest first, until the missing
    checkpoints fit (with a 5% margin) on the filesystem holding ``cache_dir``.
    It aims for an extra 10 GB of headroom when enough models can be evicted.

    Args:
        token: Hugging Face Hub token used for every Hub API call.
        cache_dir: Path of the local HF hub cache to scan and clean.

    Returns:
        Mapping of model id -> checkpoint extension selected for it
        (".safetensors", ".pt" or ".bin"), or None when the repo has none.

    Raises:
        ValueError: if the missing checkpoints cannot fit even after deleting
            every non-required cached model.
    """
    extension_per_model, checkpoints_per_model = _crawl_required_checkpoints(token)

    total_required_size = sum(
        sum(checkpoints.values()) for checkpoints in checkpoints_per_model.values()
    )
    print(f"Total required disk for checkpoints: {total_required_size:.2f} GB")

    (
        cached_required_size_per_model,
        cache_size_per_model,
        cached_shas_per_model,
    ) = _scan_cached_models(cache_dir, checkpoints_per_model)

    total_required_cached_size = sum(cached_required_size_per_model.values())
    total_other_cached_size = sum(cache_size_per_model.values())

    print("total_required_size", total_required_size)
    print("total_required_cached_size", total_required_cached_size)
    total_non_cached_required_size = total_required_size - total_required_cached_size
    # Invariant: the cached required files are a subset of the required files
    # (small tolerance for float rounding).
    assert total_non_cached_required_size >= -0.001

    print(
        f"Total non-cached required models size: {total_non_cached_required_size:.2f} GB (to be downloaded)"
    )
    print(
        f"Total HF cached models size: {total_other_cached_size + total_required_cached_size:.2f} GB"
    )
    print(
        f"Total non-necessary HF cached models size: {total_other_cached_size:.2f} GB"
    )

    # Measure the filesystem that holds the cache: that is where space is
    # freed by evictions and where the downloads are written.
    free_space = shutil.disk_usage(cache_dir).free * 1e-9
    print(f"Free memory: {free_space:.2f} GB")

    # Hard requirement: what still has to be downloaded, plus a 5% margin.
    required_free = total_non_cached_required_size * 1.05
    if free_space + total_other_cached_size < required_free:
        raise ValueError(
            "Not enough space on device to execute the complete CI, please clean up the CI machine"
        )

    # Soft target: 10 GB of headroom on top, evicting the largest models first.
    while free_space < 10 + required_free:
        if not cache_size_per_model:
            if free_space >= required_free:
                # Headroom target unreachable, but the downloads still fit.
                break
            raise ValueError("This should not happen.")

        largest_model_id = max(cache_size_per_model, key=cache_size_per_model.get)
        print("Removing", largest_model_id)
        shas = cached_shas_per_model[largest_model_id]
        if shas:
            # One cache scan and one deletion strategy for all revisions of the repo.
            huggingface_hub.scan_cache_dir(cache_dir).delete_revisions(*shas).execute()
        del cache_size_per_model[largest_model_id]

        free_space = shutil.disk_usage(cache_dir).free * 1e-9

    return extension_per_model
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: free cache space, then download the
    # checkpoints of every required model into the given cache directory.
    parser = argparse.ArgumentParser(description="Cache cleaner")
    parser.add_argument(
        "--token", help="Hugging Face Hub token.", required=True, type=str
    )
    parser.add_argument("--cache-dir", help="Hub cache path.", required=True, type=str)
    args = parser.parse_args()

    # perf_counter is monotonic, so measured durations are immune to clock jumps.
    start = time.perf_counter()
    extension_per_model = cleanup_cache(args.token, args.cache_dir)
    end = time.perf_counter()

    print(f"Cache cleanup done in {end - start:.2f} s")

    print("Downloading required models")
    start = time.perf_counter()
    for model_id, revision in REQUIRED_MODELS.items():
        extension = extension_per_model[model_id]
        if extension is None:
            # No checkpoint extension was detected. Without this guard the
            # download pattern would be the meaningless "*None".
            print(f"Skipping {model_id}: no checkpoint file found.")
            continue

        print(f"Downloading {model_id}'s *{extension}...")
        huggingface_hub.snapshot_download(
            model_id,
            repo_type="model",
            revision=revision,
            token=args.token,
            allow_patterns=f"*{extension}",
            cache_dir=args.cache_dir,
        )
    end = time.perf_counter()

    print(f"Models download done in {end - start:.2f} s")
|