diff --git a/webui.py b/webui.py
index bb53e5fff..5b990a5f9 100644
--- a/webui.py
+++ b/webui.py
@@ -8,12 +8,12 @@ from omegaconf import OmegaConf
 from PIL import Image
 from itertools import islice
 from einops import rearrange, repeat
-from torchvision.utils import make_grid
 from torch import autocast
 from contextlib import contextmanager, nullcontext
 import mimetypes
 import random
 import math
+import csv
 
 import k_diffusion as K
 from ldm.util import instantiate_from_config
@@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
 opt_C = 4
 opt_f = 8
 
+invalid_filename_chars = '<>:"/\\|?*'
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
 parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
@@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
 model = model.half().to(device)
 
 
-def image_grid(imgs, batch_size):
+def image_grid(imgs, batch_size, round_down=False):
     if opt.n_rows > 0:
         rows = opt.n_rows
     elif opt.n_rows == 0:
         rows = batch_size
     else:
-        rows = round(math.sqrt(len(imgs)))
+        rows = math.sqrt(len(imgs))
+        rows = int(rows) if round_down else round(rows)
 
     cols = math.ceil(len(imgs) / rows)
 
@@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
     return grid
 
 
-def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
+def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
     torch.cuda.empty_cache()
 
     outpath = opt.outdir or "outputs/txt2img-samples"
@@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
         seed = random.randrange(4294967294)
 
     seed = int(seed)
+    keep_same_seed = False
 
     is_PLMS = sampler_name == 'PLMS'
     is_DDIM = sampler_name == 'DDIM'
@@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
     batch_size = n_samples
 
     assert prompt is not None
-    data = [batch_size * [prompt]]
+    prompts = batch_size * [prompt]
 
     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
 
+    prompt_matrix_prompts = []
+    comment = ""
+    if prompt_matrix:
+        keep_same_seed = True
+        comment = "Image prompts:\n\n"
+
+        items = prompt.split("|")
+        combination_count = 2 ** (len(items)-1)
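+        # combination_num works as a bitmask: items[0] is always kept, and bit n
+        # decides whether item n+1 gets appended, so this enumerates all
+        # 2**(len(items)-1) combinations of the |-separated prompt parts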
+        for combination_num in range(combination_count):
+            current = items[0]
+            label = 'A'
+
+            for n, text in enumerate(items[1:]):
+                if combination_num & (2**n) > 0:
+                    current += ("" if text.strip().startswith(",") else ", ") + text
+                    label += chr(ord('B') + n)
+
+            comment += " - " + label + "\n"
+
+            prompt_matrix_prompts.append(current)
+        n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
+
+        comment += "\nwhere:\n"
+        for n, text in enumerate(items):
+            comment += " " + chr(ord('A') + n) + " = " + text + "\n"
+
     precision_scope = autocast if opt.precision == "autocast" else nullcontext
     output_images = []
     with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
         for n in range(n_iter):
-            for batch_index, prompts in enumerate(data):
-                uc = None
-                if cfg_scale != 1.0:
-                    uc = model.get_learned_conditioning(batch_size * [""])
-                if isinstance(prompts, tuple):
-                    prompts = list(prompts)
-                c = model.get_learned_conditioning(prompts)
-                shape = [opt_C, height // opt_f, width // opt_f]
+            if prompt_matrix:
+                prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
 
-                current_seed = seed + n * len(data) + batch_index
+            uc = None
+            if cfg_scale != 1.0:
+                uc = model.get_learned_conditioning(len(prompts) * [""])
+            if isinstance(prompts, tuple):
+                prompts = list(prompts)
+            c = model.get_learned_conditioning(prompts)
+            shape = [opt_C, height // opt_f, width // opt_f]
+
+            batch_seed = seed if keep_same_seed else seed + n * len(prompts)
+
+            # we manually generate all input noises because each one should have a specific seed
+            xs = []
+            for i in range(len(prompts)):
+                current_seed = seed if keep_same_seed else batch_seed + i
                 torch.manual_seed(current_seed)
+                xs.append(torch.randn(shape, device=device))
+            x = torch.stack(xs)
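+            # the same pre-seeded noise serves both sampler paths: k-diffusion
+            # scales it by sigmas[0] below, while DDIM/PLMS receive it unscaled via x_T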
 
-                if is_Kdif:
-                    sigmas = model_wrap.get_sigmas(ddim_steps)
-                    x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw
-                    model_wrap_cfg = CFGDenoiser(model_wrap)
-                    samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
+            if is_Kdif:
+                sigmas = model_wrap.get_sigmas(ddim_steps)
+                x = x * sigmas[0]
+                model_wrap_cfg = CFGDenoiser(model_wrap)
+                samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
 
-                elif sampler is not None:
-                    samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
+            elif sampler is not None:
+                samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
 
-                x_samples_ddim = model.decode_first_stage(samples_ddim)
-                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+            x_samples_ddim = model.decode_first_stage(samples_ddim)
+            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
-                if not opt.skip_save or not opt.skip_grid:
-                    for x_sample in x_samples_ddim:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        x_sample = x_sample.astype(np.uint8)
+            if not opt.skip_save or not opt.skip_grid:
+                for i, x_sample in enumerate(x_samples_ddim):
+                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                    x_sample = x_sample.astype(np.uint8)
 
-                        if use_GFPGAN and GFPGAN is not None:
-                            cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
-                            x_sample = restored_img
-                        image = Image.fromarray(x_sample)
-                        image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
-                        output_images.append(image)
-                        base_count += 1
+                    if use_GFPGAN and GFPGAN is not None:
+                        cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
+                        x_sample = restored_img
+
+                    image = Image.fromarray(x_sample)
+                    filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
+
+                    image.save(os.path.join(sample_path, filename))
+
+                    output_images.append(image)
+                    base_count += 1
 
     if not opt.skip_grid:
         # additionally, save as grid
-        grid = image_grid(output_images, batch_size)
+        grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
 
         grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
         grid_count += 1
@@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
 Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
 """.strip()
 
+    if len(comment) > 0:
+        info += "\n\n" + comment
+
     return output_images, seed, info
 
 
+class Flagging(gr.FlaggingCallback):
+
+    def setup(self, components, flagging_dir: str):
+        pass
+
+    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
+        os.makedirs("log/images", exist_ok=True)
+
+        # these must match the "dream" function's inputs, followed by its outputs
+        prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
+
+        filenames = []
+
+        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
+            import time
+            import base64
+
+            at_start = file.tell() == 0
+            writer = csv.writer(file)
+            if at_start:
+                writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
+
+            filename_base = str(int(time.time() * 1000))
+            for i, filedata in enumerate(images):
+                filename = "log/images/" + filename_base + ("" if len(images) == 1 else "-" + str(i+1)) + ".png"
+
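+                # flagged images may arrive as base64 data URLs; strip the scheme
+                # prefix, if present, before decoding to raw PNG bytes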
+                if filedata.startswith("data:image/png;base64,"):
+                    filedata = filedata[len("data:image/png;base64,"):]
+
+                with open(filename, "wb") as imgfile:
+                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
+
+                filenames.append(filename)
+
+            writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
+
+        print("Logged:", filenames[0])
+
+
 dream_interface = gr.Interface(
     dream,
     [
@@ -252,10 +337,11 @@ dream_interface = gr.Interface(
         gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
         gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
         gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
+        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
-        gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
-        gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
-        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
+        gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
+        gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
+        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
         gr.Number(label='Seed', value=-1),
         gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
         gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
@@ -267,7 +353,7 @@ dream_interface = gr.Interface(
     ],
     title="Stable Diffusion Text-to-Image K",
     description="Generate images from text with Stable Diffusion (using K-LMS)",
-    allow_flagging="never"
+    flagging_callback=Flagging()
 )
 
 
@@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
         x_sample = restored_img
 
         image = Image.fromarray(x_sample)
-        image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
+        image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
         output_images.append(image)
         base_count += 1