From 3395c29127e2dfc4467f04b40b2aec7ef3ec1196 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 23 Aug 2022 00:34:49 +0300 Subject: [PATCH] added prompt matrix feature all images in batches now have proper seeds, not just the first one added code to remove bad characters from filenames added code to flag output which writes it to csv and saves images renamed some fields in UI for clarity --- webui.py | 166 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 126 insertions(+), 40 deletions(-) diff --git a/webui.py b/webui.py index bb53e5fff..5b990a5f9 100644 --- a/webui.py +++ b/webui.py @@ -8,12 +8,12 @@ from omegaconf import OmegaConf from PIL import Image from itertools import islice from einops import rearrange, repeat -from torchvision.utils import make_grid from torch import autocast from contextlib import contextmanager, nullcontext import mimetypes import random import math +import csv import k_diffusion as K from ldm.util import instantiate_from_config @@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js') opt_C = 4 opt_f = 8 +invalid_filename_chars = '<>:"/\|?*' + parser = argparse.ArgumentParser() parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None) parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",) @@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp model = model.half().to(device) -def image_grid(imgs, batch_size): +def image_grid(imgs, batch_size, round_down=False): if opt.n_rows > 0: rows = opt.n_rows elif opt.n_rows == 0: rows = batch_size else: - rows = round(math.sqrt(len(imgs))) + rows = math.sqrt(len(imgs)) + rows = int(rows) if round_down else round(rows) cols = math.ceil(len(imgs) / rows) @@ -146,7 +149,7 @@ def image_grid(imgs, batch_size): return grid -def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): +def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): torch.cuda.empty_cache() outpath = opt.outdir or "outputs/txt2img-samples" @@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi seed = random.randrange(4294967294) seed = int(seed) + keep_same_seed = False is_PLMS = sampler_name == 'PLMS' is_DDIM = sampler_name == 'DDIM' @@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi batch_size = n_samples assert prompt is not None - data = [batch_size * [prompt]] + prompts = batch_size * [prompt] sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 + prompt_matrix_prompts = [] + comment = "" + if prompt_matrix: + keep_same_seed = True + comment = "Image prompts:\n\n" + + items = prompt.split("|") + combination_count = 2 ** (len(items)-1) + for combination_num in range(combination_count): + current = items[0] + label = 'A' + + for n, text in enumerate(items[1:]): + if combination_num & (2**n) > 0: + current += ("" if text.strip().startswith(",") else ", ") + text + label += chr(ord('B') + n) + + comment += " - " + label + "\n" + + prompt_matrix_prompts.append(current) + n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size) + + comment += "\nwhere:\n" + for n, text in enumerate(items): + comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n" + precision_scope = autocast if opt.precision == "autocast" else nullcontext output_images = [] with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): for n in range(n_iter): - for batch_index, prompts in enumerate(data): - uc = None - if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt_C, height // opt_f, width // opt_f] + if prompt_matrix: + prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size] - current_seed = seed + n * len(data) + batch_index + uc = None + if cfg_scale != 1.0: + uc = model.get_learned_conditioning(len(prompts) * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt_C, height // opt_f, width // opt_f] + + batch_seed = seed if keep_same_seed else seed + n * len(prompts) + + # we manually generate all input noises because each one should have a specific seed + xs = [] + for i in range(len(prompts)): + current_seed = seed if keep_same_seed else batch_seed + i torch.manual_seed(current_seed) + xs.append(torch.randn(shape, device=device)) + x = torch.stack(xs) - if is_Kdif: - sigmas = model_wrap.get_sigmas(ddim_steps) - x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw - model_wrap_cfg = CFGDenoiser(model_wrap) - samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False) + if is_Kdif: + sigmas = model_wrap.get_sigmas(ddim_steps) + x = x * sigmas[0] + model_wrap_cfg = CFGDenoiser(model_wrap) + samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False) - elif sampler is not None: - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None) + elif sampler is not None: + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x) - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - if not opt.skip_save or not opt.skip_grid: - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) + if not opt.skip_save or not opt.skip_grid: + for i, x_sample in enumerate(x_samples_ddim): + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + + if use_GFPGAN and GFPGAN is not None: + cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) + x_sample = restored_img + + image = Image.fromarray(x_sample) + filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png" + + image.save(os.path.join(sample_path, filename)) + + output_images.append(image) + base_count += 1 - if use_GFPGAN and GFPGAN is not None: - cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) - x_sample = restored_img - image = Image.fromarray(x_sample) - image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png")) - output_images.append(image) - base_count += 1 if not opt.skip_grid: # additionally, save as grid - grid = image_grid(output_images, batch_size) + grid = image_grid(output_images, batch_size, round_down=prompt_matrix) grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) grid_count += 1 @@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} """.strip() + if len(comment) > 0: + info += "\n\n" + comment + return output_images, seed, info +class Flagging(gr.FlaggingCallback): + + def setup(self, components, flagging_dir: str): + pass + + def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int: + os.makedirs("log/images", exist_ok=True) + + # those must match the "dream" function + prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data + + filenames = [] + + with open("log/log.csv", "a", encoding="utf8", newline='') as file: + import time + import base64 + + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"]) + + filename_base = str(int(time.time() * 1000)) + for i, filedata in enumerate(images): + filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png" + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + with open(filename, "wb") as imgfile: + imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) + + filenames.append(filename) + + writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]]) + + print("Logged:", filenames[0]) + dream_interface = gr.Interface( dream, @@ -252,10 +337,11 @@ dream_interface = gr.Interface( gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), + gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), - gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1), - gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1), - gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0), + gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), + gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), + gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly should the image follow the prompt)', value=7.0), gr.Number(label='Seed', value=-1), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512), @@ -267,7 +353,7 @@ dream_interface = gr.Interface( ], title="Stable Diffusion Text-to-Image K", description="Generate images from text with Stable Diffusion (using K-LMS)", - allow_flagging="never" + flagging_callback=Flagging() ) @@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e x_sample = restored_img image = Image.fromarray(x_sample) + image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png")) - image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png")) output_images.append(image) base_count += 1