feat(server): add retry on download (#384)
parent 444400b457
commit 87dc034b59
File diff suppressed because it is too large
@@ -37,7 +37,7 @@ class FlashRW(FlashCausalLM):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             raise NotImplementedError("RW is only available on GPU")
@@ -124,7 +124,7 @@ class FlashRWSharded(FlashRW):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
@@ -16,7 +16,7 @@ class RW(CausalLM):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
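The three hunks above apply the same one-line change: FlashRW, FlashRWSharded, and RW all switch their default GPU dtype from torch.bfloat16 to torch.float16. Since the device/dtype selection is repeated verbatim in each class, a shared helper is one way to keep them in sync; the sketch below is illustrative only (get_device_and_dtype is a name invented here, not part of the commit):

from typing import Optional, Tuple

import torch

def get_device_and_dtype(
    model_name: str, rank: Optional[int] = None
) -> Tuple[torch.device, torch.dtype]:
    # Hypothetical helper mirroring the selection logic that the three
    # hunks above each repeat inline.
    if torch.cuda.is_available():
        device = torch.device("cuda" if rank is None else f"cuda:{rank}")
        # This commit switches the default from torch.bfloat16 to torch.float16.
        return device, torch.float16
    raise NotImplementedError(f"{model_name} is only available on GPU")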
@@ -23,7 +23,11 @@ def weight_hub_files(
     """Get the weights filenames on the hub"""
     api = HfApi()
     info = api.model_info(model_id, revision=revision)
-    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
+    filenames = [
+        s.rfilename
+        for s in info.siblings
+        if s.rfilename.endswith(extension) and len(s.rfilename.split("/")) == 1
+    ]

     if not filenames:
         raise EntryNotFoundError(
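The list comprehension now also requires len(s.rfilename.split("/")) == 1, so only files at the root of the model repository are kept and weights nested in subdirectories are ignored. A quick stdlib-only check of that predicate (the helper name is made up for illustration):

def is_root_file(rfilename: str, extension: str = ".safetensors") -> bool:
    # Mirrors the new filter: matching extension and no "/" in the path.
    return rfilename.endswith(extension) and len(rfilename.split("/")) == 1

assert is_root_file("model-00001-of-00002.safetensors")
assert not is_root_file("coreml/model.safetensors")  # nested file: now excluded
assert not is_root_file("config.json")               # wrong extension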
@@ -130,24 +134,31 @@ def download_weights(
 ) -> List[Path]:
     """Download the safetensors files from the hub"""

-    def download_file(filename):
+    def download_file(filename, tries=5):
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
             return Path(local_file)

-        logger.info(f"Download file: {filename}")
-        start_time = time.time()
-        local_file = hf_hub_download(
-            filename=filename,
-            repo_id=model_id,
-            revision=revision,
-            local_files_only=False,
-        )
-        logger.info(
-            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
-        )
-        return Path(local_file)
+        for i in range(tries):
+            try:
+                logger.info(f"Download file: {filename}")
+                start_time = time.time()
+                local_file = hf_hub_download(
+                    filename=filename,
+                    repo_id=model_id,
+                    revision=revision,
+                    local_files_only=False,
+                )
+                logger.info(
+                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+                )
+                return Path(local_file)
+            except Exception as e:
+                if i + 1 == tries:
+                    raise e
+                logger.error(e)
+                logger.info(f"Retry {i + 1}/{tries - 1}")

     # We do this instead of using tqdm because we want to parse the logs with the launcher
     start_time = time.time()
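download_file now wraps the download in a retry loop: up to tries attempts (5 by default), logging each failure and re-raising only the last one; note there is no backoff between attempts. The same control flow as a standalone, testable sketch (retry_call is a hypothetical helper, not part of this codebase):

import logging
from typing import Callable, TypeVar

T = TypeVar("T")
logger = logging.getLogger(__name__)

def retry_call(fn: Callable[[], T], tries: int = 5) -> T:
    # Same shape as download_file: on the final attempt the exception
    # propagates; earlier failures are logged and retried immediately.
    for i in range(tries):
        try:
            return fn()
        except Exception as e:
            if i + 1 == tries:
                raise
            logger.error(e)
            logger.info(f"Retry {i + 1}/{tries - 1}")
    raise AssertionError("unreachable for tries >= 1")

# Example: retry_call(lambda: hf_hub_download(repo_id=model_id, filename=f))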