79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
|
import itertools
|
||
|
import os
|
||
|
from pathlib import Path
|
||
|
import html
|
||
|
import gc
|
||
|
|
||
|
import gradio as gr
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
from modules import shared
|
||
|
from modules.shared import device, aesthetic_embeddings
|
||
|
from transformers import CLIPModel, CLIPProcessor
|
||
|
|
||
|
from tqdm.auto import tqdm
|
||
|
|
||
|
|
||
|
def get_all_images_in_folder(folder):
    """Return the paths of the image files located directly inside *folder*.

    Each entry reported by ``os.listdir(folder)`` is joined with *folder*; it is
    kept when the joined path is a regular file and its name is accepted by
    ``check_is_valid_image_file``. Sub-directories are not descended into, and
    the result follows the order produced by ``os.listdir``.
    """
    image_paths = []
    for entry in os.listdir(folder):
        candidate = os.path.join(folder, entry)
        # Skip directories and anything else that is not a regular file.
        if not os.path.isfile(candidate):
            continue
        if check_is_valid_image_file(entry):
            image_paths.append(candidate)
    return image_paths
|
||
|
|
||
|
|
||
|
def check_is_valid_image_file(filename):
    """Tell whether *filename* ends with a supported image suffix.

    The comparison is case-insensitive and accepts ``.png``, ``.jpg`` and
    ``.jpeg``. Only the name is inspected; the file itself is never opened.
    """
    accepted_suffixes = ('.png', '.jpg', '.jpeg')
    lowered_name = filename.lower()
    return lowered_name.endswith(accepted_suffixes)
|
||
|
|
||
|
|
||
|
def batched(dataset, total, n=1):
    """Yield lists of up to *n* consecutive items taken from *dataset* by index.

    Indices ``0 .. total - 1`` are walked in steps of *n*; every item is fetched
    through ``dataset.__getitem__``. The final list is shorter than *n* when
    *total* is not a multiple of *n*.
    """
    for start in range(0, total, n):
        # Clamp the upper bound so the last batch never reads past `total`.
        stop = min(start + n, total)
        batch = []
        for index in range(start, stop):
            batch.append(dataset.__getitem__(index))
        yield batch
|
||
|
|
||
|
|
||
|
def iter_to_batched(iterable, n=1):
    """Yield successive tuples of up to *n* items drawn from *iterable*.

    Works on any iterable, including one-shot iterators. The last tuple may hold
    fewer than *n* items; iteration ends as soon as an empty tuple would be
    produced.
    """
    source = iter(iterable)
    # Two-argument iter(): call the lambda repeatedly until it returns the
    # sentinel `()`, i.e. until the underlying iterator is exhausted.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())
|
||
|
|
||
|
|
||
|
def generate_imgs_embd(name, folder, batch_size):
    """Build an averaged CLIP image embedding from a folder and save it as ``<name>.pt``.

    The images returned by ``get_all_images_in_folder(folder)`` are encoded in
    batches of ``batch_size`` with ``CLIPModel.get_image_features``. The features
    of all processed images are concatenated, averaged over dim 0 (with
    ``keepdim=True``) and written with ``torch.save`` into
    ``shared.cmd_opts.aesthetic_embeddings_dir``.

    If ``shared.state.interrupted`` becomes true, processing stops and whatever
    was encoded so far is averaged and saved. When nothing was encoded at all
    (no valid images, or interrupted before the first batch) no file is written
    and the returned message says so.

    Args:
        name: Base name (without ``.pt``) of the embedding file to create.
        folder: Directory scanned for images.
        batch_size: Number of images encoded per forward pass.

    Returns:
        A 3-tuple ``(dropdown, message, "")``: a ``gr.Dropdown`` listing the
        known embeddings, a status message, and an empty string.
    """
    clip_path = shared.sd_model.cond_stage_model.clipModel.name_or_path
    processor = CLIPProcessor.from_pretrained(clip_path)
    model = CLIPModel.from_pretrained(clip_path).to(device)

    saved_path = None
    batch_features = []
    try:
        with torch.no_grad():
            for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                              desc=f"Generating embeddings for {name}"):
                if shared.state.interrupted:
                    break

                # Load each image fully and release its file handle right away
                # instead of relying on garbage collection to close it.
                images = []
                for image_path in paths:
                    with Image.open(image_path) as opened:
                        images.append(opened.copy())

                inputs = processor(images=images, return_tensors="pt").to(device)
                # Keep results on the CPU so device memory only ever holds one batch.
                batch_features.append(model.get_image_features(**inputs).cpu())
                del inputs, images

        # torch.cat raises on an empty list, so only aggregate when at least
        # one batch was actually encoded.
        if batch_features:
            mean_embedding = torch.cat(batch_features, dim=0).mean(dim=0, keepdim=True)

            # The generated embedding will be located here
            saved_path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
            torch.save(mean_embedding, saved_path)
            del mean_embedding
    finally:
        # Always release the CLIP model, even when loading an image or the
        # forward pass failed; otherwise it would stay resident on the device.
        model.cpu()
        del model, processor
        batch_features.clear()
        gc.collect()
        torch.cuda.empty_cache()

    # `path` was already HTML-escaped in the message; escape `name` as well.
    safe_name = html.escape(str(name))
    if saved_path is not None:
        res = f"""
    Done generating embedding for {safe_name}!
    Embedding saved to {html.escape(saved_path)}
    """
    else:
        res = f"""
    No embedding was generated for {safe_name}:
    no valid images were processed (none found in the folder, or the job was interrupted).
    """

    shared.update_aesthetic_embeddings()
    # NOTE(review): read the registry through the module rather than the
    # from-imported name, which would be stale if update_aesthetic_embeddings()
    # rebinds `shared.aesthetic_embeddings` - confirm against modules.shared.
    choices = sorted(shared.aesthetic_embeddings)
    return gr.Dropdown(choices, label="Imgs embedding",
                       value=choices[0] if choices else None), res, ""
|