use modelloader for #4956
This commit is contained in:
parent
2a649154ec
commit
4b0dc206ed
|
@ -1,4 +1,3 @@
|
||||||
import contextlib
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -11,12 +10,9 @@ from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import devices, paths, lowvram
|
from modules import devices, paths, lowvram, modelloader
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
blip_local_dir = os.path.join('models', 'Interrogator')
|
|
||||||
blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
|
|
||||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
|
|
||||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||||
|
@ -49,16 +45,14 @@ class InterrogateModels:
|
||||||
def load_blip_model(self):
|
def load_blip_model(self):
|
||||||
import models.blip
|
import models.blip
|
||||||
|
|
||||||
if not os.path.isfile(blip_local_file):
|
files = modelloader.load_models(
|
||||||
if not os.path.isdir(blip_local_dir):
|
model_path=os.path.join(paths.models_path, "BLIP"),
|
||||||
os.mkdir(blip_local_dir)
|
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
|
||||||
|
ext_filter=[".pth"],
|
||||||
|
download_name='model_base_caption_capfilt_large.pth',
|
||||||
|
)
|
||||||
|
|
||||||
print("Downloading BLIP...")
|
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||||
from requests import get as reqget
|
|
||||||
open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
|
|
||||||
print("BLIP downloaded to", blip_local_file + '.')
|
|
||||||
|
|
||||||
blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
|
||||||
blip_model.eval()
|
blip_model.eval()
|
||||||
|
|
||||||
return blip_model
|
return blip_model
|
||||||
|
|
Loading…
Reference in New Issue