feat(server): add retry on download (#384)

This commit is contained in:
OlivierDehaene 2023-05-31 10:57:53 +02:00 committed by GitHub
parent 444400b457
commit 87dc034b59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 309 additions and 298 deletions

View File

@@ -37,7 +37,7 @@ class FlashRW(FlashCausalLM):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 dtype = torch.float16
else: else:
raise NotImplementedError("RW is only available on GPU") raise NotImplementedError("RW is only available on GPU")
@@ -124,7 +124,7 @@ class FlashRWSharded(FlashRW):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 dtype = torch.float16
else: else:
raise NotImplementedError("FlashRW is only available on GPU") raise NotImplementedError("FlashRW is only available on GPU")

View File

@@ -16,7 +16,7 @@ class RW(CausalLM):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")

View File

@@ -23,7 +23,11 @@ def weight_hub_files(
"""Get the weights filenames on the hub""" """Get the weights filenames on the hub"""
api = HfApi() api = HfApi()
info = api.model_info(model_id, revision=revision) info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] filenames = [
s.rfilename
for s in info.siblings
if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
]
if not filenames: if not filenames:
raise EntryNotFoundError( raise EntryNotFoundError(
@@ -130,24 +134,31 @@ def download_weights(
) -> List[Path]: ) -> List[Path]:
"""Download the safetensors files from the hub""" """Download the safetensors files from the hub"""
def download_file(filename): def download_file(filename, tries=5):
local_file = try_to_load_from_cache(model_id, revision, filename) local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None: if local_file is not None:
logger.info(f"File {filename} already present in cache.") logger.info(f"File {filename} already present in cache.")
return Path(local_file) return Path(local_file)
logger.info(f"Download file: {filename}") for i in range(tries):
start_time = time.time() try:
local_file = hf_hub_download( logger.info(f"Download file: {filename}")
filename=filename, start_time = time.time()
repo_id=model_id, local_file = hf_hub_download(
revision=revision, filename=filename,
local_files_only=False, repo_id=model_id,
) revision=revision,
logger.info( local_files_only=False,
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}." )
) logger.info(
return Path(local_file) f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
)
return Path(local_file)
except Exception as e:
if i + 1 == tries:
raise e
logger.error(e)
logger.info(f"Retry {i + 1}/{tries - 1}")
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time() start_time = time.time()