do not load aesthetic clip model until it's needed
add refresh button for aesthetic embeddings
add aesthetic params to images' infotext
commit df57064093
parent 7d6b388d71
@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):


 def create_ui():
+    import modules.ui
+
     with gr.Group():
         with gr.Accordion("Open for Clip Aesthetic!", open=False):
             with gr.Row():
@@ -55,6 +57,8 @@ def create_ui():
                                             label="Aesthetic imgs embedding",
                                             value="None")

+                modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
+
             with gr.Row():
                 aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
                                                  placeholder="This text is used to rotate the feature space of the imgs embs",
@@ -66,11 +70,21 @@ def create_ui():
     return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative


+aesthetic_clip_model = None
+
+
+def aesthetic_clip():
+    global aesthetic_clip_model
+
+    if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
+        aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
+        aesthetic_clip_model.cpu()
+
+    return aesthetic_clip_model
+
+
 def generate_imgs_embd(name, folder, batch_size):
-    # clipModel = CLIPModel.from_pretrained(
-    #     shared.sd_model.cond_stage_model.clipModel.name_or_path
-    # )
-    model = shared.clip_model.to(device)
+    model = aesthetic_clip().to(device)
     processor = CLIPProcessor.from_pretrained(model.name_or_path)

     with torch.no_grad():
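The aesthetic_clip() helper above replaces the eagerly created shared.clip_model (its eager load is removed from load_model() further down): the extra CLIP weights are only fetched the first time an aesthetic feature needs them, parked on the CPU between uses, and rebuilt if the active checkpoint's transformer path changes. A minimal sketch of the same lazy-load-and-cache pattern, using a hypothetical get_model() name rather than the webui globals:

from transformers import CLIPModel

_cached_model = None


def get_model(name_or_path):
    # Load on first use, reuse afterwards, reload if the requested path changes.
    global _cached_model
    if _cached_model is None or _cached_model.name_or_path != name_or_path:
        _cached_model = CLIPModel.from_pretrained(name_or_path)
        _cached_model.cpu()  # keep it off the GPU; callers .to(device) only while working
    return _cached_model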
@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
     path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
     torch.save(embs, path)

-    model = model.cpu()
+    model.cpu()
     del processor
     del embs
     gc.collect()
@@ -132,7 +146,7 @@ class AestheticCLIP:
         self.image_embs = None
         self.load_image_embs(None)

-    def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+    def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                              aesthetic_slerp=True, aesthetic_imgs_text="",
                              aesthetic_slerp_angle=0.15,
                              aesthetic_text_negative=False):
@@ -145,6 +159,18 @@ class AestheticCLIP:
         self.aesthetic_steps = aesthetic_steps
         self.load_image_embs(image_embs_name)

+        if self.image_embs_name is not None:
+            p.extra_generation_params.update({
+                "Aesthetic LR": aesthetic_lr,
+                "Aesthetic weight": aesthetic_weight,
+                "Aesthetic steps": aesthetic_steps,
+                "Aesthetic embedding": self.image_embs_name,
+                "Aesthetic slerp": aesthetic_slerp,
+                "Aesthetic text": aesthetic_imgs_text,
+                "Aesthetic text negative": aesthetic_text_negative,
+                "Aesthetic slerp angle": aesthetic_slerp_angle,
+            })
+
     def set_skip(self, skip):
         self.skip = skip

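Everything written into p.extra_generation_params here is merged into the standard parameter dictionary by create_infotext() (see the @@ -318,7 hunk below) and serialized as comma-separated "Key: value" pairs, so the aesthetic settings end up in the image's infotext and can later be pasted back into the UI. A rough, self-contained illustration of that flow, with made-up values rather than webui objects:

generation_params = {"Steps": 20, "Sampler": "Euler a", "CFG scale": 7}
extra_generation_params = {"Aesthetic weight": 0.9, "Aesthetic steps": 5, "Aesthetic embedding": "my_embs"}

generation_params.update(extra_generation_params)
infotext = ", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None)
print(infotext)
# Steps: 20, Sampler: Euler a, CFG scale: 7, Aesthetic weight: 0.9, Aesthetic steps: 5, Aesthetic embedding: my_embs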
@@ -168,7 +194,7 @@ class AestheticCLIP:

            tokens = torch.asarray(remade_batch_tokens).to(device)

-            model = copy.deepcopy(shared.clip_model).to(device)
+            model = copy.deepcopy(aesthetic_clip()).to(device)
             model.requires_grad_(True)
             if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
                 text_embs_2 = model.get_text_features(
@@ -4,13 +4,22 @@ import gradio as gr
 from modules.shared import script_path
 from modules import shared

-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 type_of_gr_update = type(gr.update())


+def quote(text):
+    if ',' not in str(text):
+        return text
+
+    text = str(text)
+    text = text.replace('\\', '\\\\')
+    text = text.replace('"', '\\"')
+    return f'"{text}"'
+
+
 def parse_generation_parameters(x: str):
     """parses generation parameters string, the one you see in text field under the picture in UI:
 ```
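The widened re_param_code and the new quote() work as a pair: when the infotext is written, values containing a comma are wrapped in escaped double quotes, and when it is parsed back the regex accepts either a quoted string or a plain comma-free run. A quick round-trip check using the two definitions above (the sample value is invented):

import re

re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)')


def quote(text):
    if ',' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'


line = f"Aesthetic text: {quote('dark, moody, cinematic')}, Steps: 20"
print(re_param.findall(line))
# [('Aesthetic text', '"dark, moody, cinematic"'), ('Steps', '20')]

With the old pattern the value group stopped at the first comma, so anything after it in the value was lost or misread.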
@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
            else:
                try:
                    valtype = type(output.value)
-                    val = valtype(v)
+
+                    if valtype == bool and v == "False":
+                        val = False
+                    else:
+                        val = valtype(v)
+
                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())
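The boolean special case is needed because casting the pasted string through the component's value type misbehaves for checkboxes: bool("False") is True, since any non-empty string is truthy, so a field restored from infotext would always come back enabled. A two-line demonstration:

valtype = bool
v = "False"
print(valtype(v))                                                  # True  -- the naive cast silently enables the option
print(False if valtype == bool and v == "False" else valtype(v))   # False -- the branch added above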
@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         inpainting_mask_invert=inpainting_mask_invert,
     )

-    shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
-                                               aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
-                                               aesthetic_slerp_angle,
-                                               aesthetic_text_negative)
+    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)

     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
@@ -12,7 +12,7 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

     generation_params.update(p.extra_generation_params)

-    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

     negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):

     sd_hijack.model_hijack.hijack(sd_model)

-    if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
-        shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
-
     sd_model.eval()

     print(f"Model loaded.")
@@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         firstphase_height=firstphase_height if enable_hr else None,
     )

-    shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
-                                               aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
-                                               aesthetic_text_negative)
+    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)

     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -597,27 +597,29 @@ def apply_setting(key, value):
     return value


+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    def refresh():
+        refresh_method()
+        args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+        for k, v in args.items():
+            setattr(refresh_component, k, v)
+
+        return gr.update(**(args or {}))
+
+    refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+    refresh_button.click(
+        fn=refresh,
+        inputs=[],
+        outputs=[refresh_component]
+    )
+    return refresh_button
+
+
 def create_ui(wrap_gradio_gpu_call):
     import modules.img2img
     import modules.txt2img

-    def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
-        def refresh():
-            refresh_method()
-            args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
-            for k, v in args.items():
-                setattr(refresh_component, k, v)
-
-            return gr.update(**(args or {}))
-
-        refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
-        refresh_button.click(
-            fn = refresh,
-            inputs = [],
-            outputs = [refresh_component]
-        )
-        return refresh_button
-
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
         txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
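Hoisting create_refresh_button() to module level is what lets aesthetic_clip.create_ui() reach it through modules.ui (see the first hunks of this commit). Restating that call with comments, as an illustrative sketch that assumes it runs inside the existing Gradio layout:

import gradio as gr

import modules.ui
from modules import shared

aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
                             label="Aesthetic imgs embedding", value="None")

modules.ui.create_refresh_button(
    aesthetic_imgs,                                                   # component whose choices get refreshed
    shared.update_aesthetic_embeddings,                               # re-scans the embeddings directory
    lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())},  # kwargs passed to gr.update()
    "refresh_aesthetic_embeddings",                                   # elem_id, sized by the style.css rule below
)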
@@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
            (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
            (firstphase_width, "First pass size-1"),
            (firstphase_height, "First pass size-2"),
+            (aesthetic_lr, "Aesthetic LR"),
+            (aesthetic_weight, "Aesthetic weight"),
+            (aesthetic_steps, "Aesthetic steps"),
+            (aesthetic_imgs, "Aesthetic embedding"),
+            (aesthetic_slerp, "Aesthetic slerp"),
+            (aesthetic_imgs_text, "Aesthetic text"),
+            (aesthetic_text_negative, "Aesthetic text negative"),
+            (aesthetic_slerp_angle, "Aesthetic slerp angle"),
         ]

         txt2img_preview_params = [
@@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
            (seed_resize_from_w, "Seed resize from-1"),
            (seed_resize_from_h, "Seed resize from-2"),
            (denoising_strength, "Denoising strength"),
+            (aesthetic_lr_im, "Aesthetic LR"),
+            (aesthetic_weight_im, "Aesthetic weight"),
+            (aesthetic_steps_im, "Aesthetic steps"),
+            (aesthetic_imgs_im, "Aesthetic embedding"),
+            (aesthetic_slerp_im, "Aesthetic slerp"),
+            (aesthetic_imgs_text_im, "Aesthetic text"),
+            (aesthetic_text_negative_im, "Aesthetic text negative"),
+            (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
         ]

         token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -477,7 +477,7 @@ input[type="range"]{
     padding: 0;
 }

-#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
+#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
     max-width: 2.5em;
     min-width: 2.5em;
     height: 2.4em;