diff --git a/webui.py b/webui.py index 95dcc7512..c8a62c4d1 100644 --- a/webui.py +++ b/webui.py @@ -97,16 +97,21 @@ class KDiffusionSampler: sigmas = self.model_wrap.get_sigmas(S) x = x_T * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model_wrap) + samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) return samples_ddim, None -def create_random_tensors(seed, shape, count, same_seed=False): +def create_random_tensors(shape, seeds): xs = [] - for i in range(count): - current_seed = seed if same_seed else seed + i - torch.manual_seed(current_seed) + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. xs.append(torch.randn(shape, device=device)) x = torch.stack(xs) return x @@ -190,7 +195,7 @@ def draw_prompt_matrix(im, width, height, all_prompts): color_inactive = (153, 153, 153) pad_top = height // 4 - pad_left = width * 3 // 4 + pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0 cols = im.width // width rows = im.height // height @@ -226,63 +231,53 @@ def draw_prompt_matrix(im, width, height, all_prompts): return result -def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): - torch.cuda.empty_cache() +def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN): + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - outpath = opt.outdir or "outputs/txt2img-samples" + assert prompt is not None + torch.cuda.empty_cache() if seed == -1: seed = random.randrange(4294967294) - seed = int(seed) - keep_same_seed = False - - if sampler_name == 'PLMS': - sampler = PLMSSampler(model) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(model) - elif sampler_name == 'k-diffusion': - sampler = KDiffusionSampler(model) - else: - raise Exception("Unknown sampler: " + sampler_name) os.makedirs(outpath, exist_ok=True) - batch_size = n_samples - - assert prompt is not None - prompts = batch_size * [prompt] - sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 - prompt_matrix_prompts = [] prompt_matrix_parts = [] if prompt_matrix: - keep_same_seed = True - + all_prompts = [] prompt_matrix_parts = prompt.split("|") - combination_count = 2 ** (len(prompt_matrix_parts)-1) + combination_count = 2 ** (len(prompt_matrix_parts) - 1) for combination_num in range(combination_count): current = prompt_matrix_parts[0] for n, text in enumerate(prompt_matrix_parts[1:]): - if combination_num & (2**n) > 0: + if combination_num & (2 ** n) > 0: current += ("" if text.strip().startswith(",") else ", ") + text - prompt_matrix_prompts.append(current) - n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size) + all_prompts.append(current) - print(f"Prompt matrix will create {len(prompt_matrix_prompts)} images using a total of {n_iter} batches.") + n_iter = math.ceil(len(all_prompts) / batch_size) + all_seeds = len(all_prompts) * [seed] + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") + else: + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] precision_scope = autocast if opt.precision == "autocast" else nullcontext output_images = [] with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): + init_data = func_init() + for n in range(n_iter): - if prompt_matrix: - prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size] + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] uc = None if cfg_scale != 1.0: @@ -290,14 +285,11 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro if isinstance(prompts, tuple): prompts = list(prompts) c = model.get_learned_conditioning(prompts) - shape = [opt_C, height // opt_f, width // opt_f] - - batch_seed = seed if keep_same_seed else seed + n * len(prompts) # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors(batch_seed, shape, count=len(prompts), same_seed=keep_same_seed) + x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x) + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -312,7 +304,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro x_sample = restored_img image = Image.fromarray(x_sample) - filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png" + filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png" image.save(os.path.join(sample_path, filename)) @@ -323,21 +315,68 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro grid = image_grid(output_images, batch_size, round_down=prompt_matrix) if prompt_matrix: - grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + + try: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + except Exception: + import traceback + print("Error creating prompt_matrix text:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + output_images.insert(0, grid) grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) grid_count += 1 - del sampler - info = f""" {prompt} -Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} - """.strip() +Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} + """.strip() return output_images, seed, info + +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): + outpath = opt.outdir or "outputs/txt2img-samples" + + if sampler_name == 'PLMS': + sampler = PLMSSampler(model) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(model) + elif sampler_name == 'k-diffusion': + sampler = KDiffusionSampler(model) + else: + raise Exception("Unknown sampler: " + sampler_name) + + def init(): + pass + + def sample(init_data, x, conditioning, unconditional_conditioning): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x) + return samples_ddim + + output_images, seed, info = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=prompt_matrix, + use_GFPGAN=use_GFPGAN + ) + + del sampler + + return output_images, seed, info + + class Flagging(gr.FlaggingCallback): def setup(self, components, flagging_dir: str): @@ -348,7 +387,7 @@ class Flagging(gr.FlaggingCallback): os.makedirs("log/images", exist_ok=True) - # those must match the "dream" function + # those must match the "txt2img" function prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data filenames = [] @@ -379,8 +418,8 @@ class Flagging(gr.FlaggingCallback): print("Logged:", filenames[0]) -dream_interface = gr.Interface( - dream, +txt2img_interface = gr.Interface( + txt2img, inputs=[ gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), @@ -406,104 +445,70 @@ dream_interface = gr.Interface( ) -def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int): - torch.cuda.empty_cache() - +def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int): outpath = opt.outdir or "outputs/img2img-samples" - if seed == -1: - seed = random.randrange(4294967294) + sampler = KDiffusionSampler(model) - model_wrap = K.external.CompVisDenoiser(model) + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) - os.makedirs(outpath, exist_ok=True) + def init(): + image = init_img.convert("RGB") + image = image.resize((width, height), resample=Image.Resampling.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) - batch_size = n_samples - - assert prompt is not None - - sample_path = os.path.join(outpath, "samples") - os.makedirs(sample_path, exist_ok=True) - base_count = len(os.listdir(sample_path)) - grid_count = len(os.listdir(outpath)) - 1 - - image = init_img.convert("RGB") - image = image.resize((width, height), resample=Image.Resampling.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - - output_images = [] - precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): init_image = 2. * image - 1. init_image = init_image.to(device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space - x0 = init_latent - assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(denoising_strength * ddim_steps) + return init_latent, - for n in range(n_iter): - prompts = batch_size * [prompt] + def sample(init_data, x, conditioning, unconditional_conditioning): + x0, = init_data - uc = None - if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc - 1] - batch_seed = seed + n * len(prompts) + xi = x0 + noise + sigma_sched = sigmas[ddim_steps - t_enc - 1:] + model_wrap_cfg = CFGDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) + return samples_ddim - sigmas = model_wrap.get_sigmas(ddim_steps) - noise = create_random_tensors(batch_seed, x0.shape[1:], count=len(prompts)) - noise = noise * sigmas[ddim_steps - t_enc - 1] + output_images, seed, info = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name='k-diffusion', + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=prompt_matrix, + use_GFPGAN=use_GFPGAN + ) - xi = x0 + noise - sigma_sched = sigmas[ddim_steps - t_enc - 1:] - model_wrap_cfg = CFGDenoiser(model_wrap) - extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale} + del sampler - samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=False) - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + return output_images, seed, info - if not opt.skip_save or not opt.skip_grid: - for i, x_sample in enumerate(x_samples_ddim): - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - - if use_GFPGAN and GFPGAN is not None: - cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) - x_sample = restored_img - - image = Image.fromarray(x_sample) - image.save(os.path.join(sample_path, f"{base_count:05}-{batch_seed+i}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png")) - - output_images.append(image) - base_count += 1 - - if not opt.skip_grid: - # additionally, save as grid - grid = image_grid(output_images, batch_size) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - return output_images, seed - - -# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed img2img_interface = gr.Interface( - translation, + img2img, inputs=[ gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1), gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), - gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), + gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), @@ -514,7 +519,8 @@ img2img_interface = gr.Interface( ], outputs=[ gr.Gallery(), - gr.Number(label='Seed') + gr.Number(label='Seed'), + gr.Textbox(label="Copy-paste generation parameters"), ], title="Stable Diffusion Image-to-Image", description="Generate images from images with Stable Diffusion", @@ -522,7 +528,7 @@ img2img_interface = gr.Interface( ) interfaces = [ - (dream_interface, "txt2img"), + (txt2img_interface, "txt2img"), (img2img_interface, "img2img") ]