From a2ae5a655518b150a34b95d7afecc87a43280406 Mon Sep 17 00:00:00 2001 From: "Tiago F. Santos" Date: Thu, 24 Nov 2022 13:04:45 +0000 Subject: [PATCH] [interrogator] mkdir check --- modules/interrogate.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 1a9c758e3..f177a5a85 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -14,7 +14,8 @@ import modules.shared as shared from modules import devices, paths, lowvram blip_image_eval_size = 384 -blip_model_local = os.path.join('models', 'Interrogator', 'BLIP_model.pth') +blip_local_dir = os.path.join('models', 'Interrogator') +blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth') blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' clip_model_name = 'ViT-L/14' @@ -48,13 +49,16 @@ class InterrogateModels: def load_blip_model(self): import models.blip - if not os.path.isfile(blip_model_local): - print("Downloading BLIP...") - import requests as req - open(blip_model_local, 'wb').write(req.get(blip_model_url, allow_redirects=True).content) - print("BLIP downloaded to", blip_model_local + '.') + if not os.path.isfile(blip_local_file): + if not os.path.isdir(blip_local_dir): + os.mkdir(blip_local_dir) - blip_model = models.blip.blip_decoder(pretrained=blip_model_local, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) + print("Downloading BLIP...") + from requests import get as reqget + open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content) + print("BLIP downloaded to", blip_local_file + '.') + + blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) blip_model.eval() return blip_model