Merge branch 'ae'
This commit is contained in:
commit
7d6b388d71
|
@ -71,6 +71,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
||||
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
|
||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||
- Aesthetic Gradients, a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
|
|
|
@ -0,0 +1,215 @@
|
|||
import copy
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
import html
|
||||
import gc
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import optim
|
||||
|
||||
from modules import shared
|
||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
|
||||
from tqdm.auto import tqdm, trange
|
||||
from modules.shared import opts, device
|
||||
|
||||
|
||||
def get_all_images_in_folder(folder):
|
||||
return [os.path.join(folder, f) for f in os.listdir(folder) if
|
||||
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
|
||||
|
||||
|
||||
def check_is_valid_image_file(filename):
|
||||
return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
|
||||
|
||||
|
||||
def batched(dataset, total, n=1):
|
||||
for ndx in range(0, total, n):
|
||||
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||
|
||||
|
||||
def iter_to_batched(iterable, n=1):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = tuple(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def create_ui():
|
||||
with gr.Group():
|
||||
with gr.Accordion("Open for Clip Aesthetic!", open=False):
|
||||
with gr.Row():
|
||||
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
|
||||
value=0.9)
|
||||
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
|
||||
placeholder="Aesthetic learning rate", value="0.0001")
|
||||
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
|
||||
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Aesthetic imgs embedding",
|
||||
value="None")
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
|
||||
placeholder="This text is used to rotate the feature space of the imgs embs",
|
||||
value="")
|
||||
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
|
||||
value=0.1)
|
||||
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
|
||||
|
||||
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
|
||||
|
||||
|
||||
def generate_imgs_embd(name, folder, batch_size):
|
||||
# clipModel = CLIPModel.from_pretrained(
|
||||
# shared.sd_model.cond_stage_model.clipModel.name_or_path
|
||||
# )
|
||||
model = shared.clip_model.to(device)
|
||||
processor = CLIPProcessor.from_pretrained(model.name_or_path)
|
||||
|
||||
with torch.no_grad():
|
||||
embs = []
|
||||
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
|
||||
desc=f"Generating embeddings for {name}"):
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
|
||||
outputs = model.get_image_features(**inputs).cpu()
|
||||
embs.append(torch.clone(outputs))
|
||||
inputs.to("cpu")
|
||||
del inputs, outputs
|
||||
|
||||
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
|
||||
|
||||
# The generated embedding will be located here
|
||||
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
|
||||
torch.save(embs, path)
|
||||
|
||||
model = model.cpu()
|
||||
del processor
|
||||
del embs
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
res = f"""
|
||||
Done generating embedding for {name}!
|
||||
Aesthetic embedding saved to {html.escape(path)}
|
||||
"""
|
||||
shared.update_aesthetic_embeddings()
|
||||
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
|
||||
value="None"), \
|
||||
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Imgs embedding",
|
||||
value="None"), res, ""
|
||||
|
||||
|
||||
def slerp(low, high, val):
|
||||
low_norm = low / torch.norm(low, dim=1, keepdim=True)
|
||||
high_norm = high / torch.norm(high, dim=1, keepdim=True)
|
||||
omega = torch.acos((low_norm * high_norm).sum(1))
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
|
||||
class AestheticCLIP:
|
||||
def __init__(self):
|
||||
self.skip = False
|
||||
self.aesthetic_steps = 0
|
||||
self.aesthetic_weight = 0
|
||||
self.aesthetic_lr = 0
|
||||
self.slerp = False
|
||||
self.aesthetic_text_negative = ""
|
||||
self.aesthetic_slerp_angle = 0
|
||||
self.aesthetic_imgs_text = ""
|
||||
|
||||
self.image_embs_name = None
|
||||
self.image_embs = None
|
||||
self.load_image_embs(None)
|
||||
|
||||
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
|
||||
aesthetic_slerp=True, aesthetic_imgs_text="",
|
||||
aesthetic_slerp_angle=0.15,
|
||||
aesthetic_text_negative=False):
|
||||
self.aesthetic_imgs_text = aesthetic_imgs_text
|
||||
self.aesthetic_slerp_angle = aesthetic_slerp_angle
|
||||
self.aesthetic_text_negative = aesthetic_text_negative
|
||||
self.slerp = aesthetic_slerp
|
||||
self.aesthetic_lr = aesthetic_lr
|
||||
self.aesthetic_weight = aesthetic_weight
|
||||
self.aesthetic_steps = aesthetic_steps
|
||||
self.load_image_embs(image_embs_name)
|
||||
|
||||
def set_skip(self, skip):
|
||||
self.skip = skip
|
||||
|
||||
def load_image_embs(self, image_embs_name):
|
||||
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
|
||||
image_embs_name = None
|
||||
self.image_embs_name = None
|
||||
if image_embs_name is not None and self.image_embs_name != image_embs_name:
|
||||
self.image_embs_name = image_embs_name
|
||||
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
|
||||
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
|
||||
self.image_embs.requires_grad_(False)
|
||||
|
||||
def __call__(self, z, remade_batch_tokens):
|
||||
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
|
||||
tokenizer = shared.sd_model.cond_stage_model.tokenizer
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [
|
||||
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
|
||||
remade_batch_tokens]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
|
||||
model = copy.deepcopy(shared.clip_model).to(device)
|
||||
model.requires_grad_(True)
|
||||
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
|
||||
text_embs_2 = model.get_text_features(
|
||||
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
|
||||
if self.aesthetic_text_negative:
|
||||
text_embs_2 = self.image_embs - text_embs_2
|
||||
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
|
||||
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
|
||||
else:
|
||||
img_embs = self.image_embs
|
||||
|
||||
with torch.enable_grad():
|
||||
|
||||
# We optimize the model to maximize the similarity
|
||||
optimizer = optim.Adam(
|
||||
model.text_model.parameters(), lr=self.aesthetic_lr
|
||||
)
|
||||
|
||||
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
|
||||
text_embs = model.get_text_features(input_ids=tokens)
|
||||
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
|
||||
sim = text_embs @ img_embs.T
|
||||
loss = -sim
|
||||
optimizer.zero_grad()
|
||||
loss.mean().backward()
|
||||
optimizer.step()
|
||||
|
||||
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
zn = model.text_model.final_layer_norm(zn)
|
||||
else:
|
||||
zn = zn.last_hidden_state
|
||||
model.cpu()
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
|
||||
if self.slerp:
|
||||
z = slerp(z, zn, self.aesthetic_weight)
|
||||
else:
|
||||
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
|
||||
|
||||
return z
|
|
@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
|
||||
|
@ -109,6 +109,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
)
|
||||
|
||||
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
|
||||
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
|
||||
aesthetic_slerp_angle,
|
||||
aesthetic_text_negative)
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
def apply_optimizations():
|
||||
undo_optimizations()
|
||||
|
||||
|
@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
|
@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
|
@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
|
@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
|
@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
|
@ -333,19 +333,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
z = shared.aesthetic_clip(z, remade_batch_tokens)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
|
||||
return z
|
||||
|
||||
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
|
@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
vecs.append(tensor)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
|
|||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging
|
||||
from transformers import logging, CLIPModel
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
|
@ -234,6 +234,9 @@ def load_model(checkpoint_info=None):
|
|||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
|
||||
shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
|
||||
|
||||
sd_model.eval()
|
||||
|
||||
print(f"Model loaded.")
|
||||
|
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
|
@ -30,6 +31,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
|
|||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
|
@ -106,6 +108,21 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
|||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
|
||||
aesthetic_embeddings = {}
|
||||
|
||||
|
||||
def update_aesthetic_embeddings():
|
||||
global aesthetic_embeddings
|
||||
aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
|
||||
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
|
||||
aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
|
||||
|
||||
|
||||
update_aesthetic_embeddings()
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
global hypernetworks
|
||||
|
||||
|
@ -387,6 +404,11 @@ sd_upscalers = []
|
|||
|
||||
sd_model = None
|
||||
|
||||
clip_model = None
|
||||
|
||||
from modules.aesthetic_clip import AestheticCLIP
|
||||
aesthetic_clip = AestheticCLIP()
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
|
||||
|
|
|
@ -276,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import modules.scripts
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
|
@ -35,6 +36,10 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
firstphase_height=firstphase_height if enable_hr else None,
|
||||
)
|
||||
|
||||
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
|
||||
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
|
||||
aesthetic_text_negative)
|
||||
|
||||
if cmd_opts.enable_console_prompts:
|
||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
@ -53,4 +58,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
|
||||
|
|
|
@ -25,7 +25,9 @@ import gradio.routes
|
|||
|
||||
from modules import sd_hijack, sd_models, localization
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
from modules.deepbooru import get_deepbooru_tags
|
||||
import modules.shared as shared
|
||||
|
@ -41,8 +43,11 @@ from modules import prompt_parser
|
|||
from modules.images import save_image
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
|
||||
import modules.aesthetic_clip as aesthetic_clip
|
||||
import modules.images_history as img_his
|
||||
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
|
@ -655,6 +660,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||
|
||||
|
@ -709,7 +716,16 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
denoising_strength,
|
||||
firstphase_width,
|
||||
firstphase_height,
|
||||
aesthetic_lr,
|
||||
aesthetic_weight,
|
||||
aesthetic_steps,
|
||||
aesthetic_imgs,
|
||||
aesthetic_slerp,
|
||||
aesthetic_imgs_text,
|
||||
aesthetic_slerp_angle,
|
||||
aesthetic_text_negative
|
||||
] + custom_inputs,
|
||||
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
|
@ -870,6 +886,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||
|
||||
|
@ -960,6 +978,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
inpainting_mask_invert,
|
||||
img2img_batch_input_dir,
|
||||
img2img_batch_output_dir,
|
||||
aesthetic_lr_im,
|
||||
aesthetic_weight_im,
|
||||
aesthetic_steps_im,
|
||||
aesthetic_imgs_im,
|
||||
aesthetic_slerp_im,
|
||||
aesthetic_imgs_text_im,
|
||||
aesthetic_slerp_angle_im,
|
||||
aesthetic_text_negative_im,
|
||||
] + custom_inputs,
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
|
@ -1220,6 +1246,18 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create embedding", variant='primary')
|
||||
|
||||
with gr.Tab(label="Create aesthetic images embedding"):
|
||||
|
||||
new_embedding_name_ae = gr.Textbox(label="Name")
|
||||
process_src_ae = gr.Textbox(label='Source directory')
|
||||
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
|
||||
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
|
@ -1309,6 +1347,21 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
]
|
||||
)
|
||||
|
||||
create_embedding_ae.click(
|
||||
fn=aesthetic_clip.generate_imgs_embd,
|
||||
inputs=[
|
||||
new_embedding_name_ae,
|
||||
process_src_ae,
|
||||
batch_ae
|
||||
],
|
||||
outputs=[
|
||||
aesthetic_imgs,
|
||||
aesthetic_imgs_im,
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
create_hypernetwork.click(
|
||||
fn=modules.hypernetworks.ui.create_hypernetwork,
|
||||
inputs=[
|
||||
|
|
Loading…
Reference in New Issue