diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index da77e1a..f99a8ab 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -1,3 +1,5 @@
+import PIL
+import gradio as gr
 import argparse, os, sys, glob
 import torch
 import numpy as np
@@ -5,7 +7,7 @@ from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
 from itertools import islice
-from einops import rearrange
+from einops import rearrange, repeat
 from torchvision.utils import make_grid
 import time
 from pytorch_lightning import seed_everything
@@ -16,6 +18,71 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--outdir",
+    type=str,
+    nargs="?",
+    help="dir to write results to",
+    default="outputs/img2img-samples"
+)
+
+parser.add_argument(
+    "--skip_grid",
+    action='store_true',
+    help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+)
+
+parser.add_argument(
+    "--skip_save",
+    action='store_true',
+    help="do not save individual samples. For speed measurements.",
+)
+parser.add_argument(
+    "--C",
+    type=int,
+    default=4,
+    help="latent channels",
+)
+parser.add_argument(
+    "--f",
+    type=int,
+    default=8,
+    help="downsampling factor, most often 8 or 16",
+)
+parser.add_argument(
+    "--n_rows",
+    type=int,
+    default=0,
+    help="rows in the grid (default: n_samples)",
+)
+parser.add_argument(
+    "--from-file",
+    type=str,
+    help="if specified, load prompts from this file",
+)
+parser.add_argument(
+    "--config",
+    type=str,
+    default="configs/stable-diffusion/v1-inference.yaml",
+    help="path to config which constructs model",
+)
+parser.add_argument(
+    "--ckpt",
+    type=str,
+    default="models/ldm/stable-diffusion-v1/model.ckpt",
+    help="path to checkpoint of model",
+)
+parser.add_argument(
+    "--precision",
+    type=str,
+    help="evaluate at this precision",
+    choices=["full", "autocast"],
+    default="autocast"
+)
+
+opt = parser.parse_args()
 
 def chunk(it, size):
     it = iter(it)
@@ -41,167 +108,48 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
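+# Image helpers: convert a PIL image to the 1x3xHxW float tensor in [-1, 1]
+# that the first-stage autoencoder expects; load_img is a thin path-based
+# wrapper around load_img_pil.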
For speed measurements.", - ) - parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", - ) - parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", - ) - parser.add_argument( - "--laion400m", - action='store_true', - help="uses the LAION400M model", - ) - parser.add_argument( - "--fixed_code", - action='store_true', - help="if enabled, uses the same starting code across samples ", - ) - parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", - ) - parser.add_argument( - "--n_iter", - type=int, - default=2, - help="sample this often", - ) - parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", - ) - parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", - ) - parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", - ) - parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", - ) - parser.add_argument( - "--n_samples", - type=int, - default=3, - help="how many samples to produce for each given prompt. A.k.a. batch size", - ) - parser.add_argument( - "--n_rows", - type=int, - default=0, - help="rows in the grid (default: n_samples)", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", - ) - parser.add_argument( - "--from-file", - type=str, - help="if specified, load prompts from this file", - ) - parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/v1-inference.yaml", - help="path to config which constructs model", - ) - parser.add_argument( - "--ckpt", - type=str, - default="models/ldm/stable-diffusion-v1/model.ckpt", - help="path to checkpoint of model", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="the seed (for reproducible sampling)", - ) - parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" - ) - opt = parser.parse_args() +config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") +model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt") - if opt.laion400m: - print("Falling back to LAION 400M model...") - opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" - opt.ckpt = "models/ldm/text2img-large/model.ckpt" - opt.outdir = "outputs/txt2img-samples-laion400m" +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +model = model.half().to(device) - seed_everything(opt.seed) +def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): + torch.cuda.empty_cache() - config = OmegaConf.load(f"{opt.config}") - model = load_model_from_config(config, f"{opt.ckpt}") + opt.H = height + opt.W = width - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model = model.to(device) + rng_seed = seed_everything(seed) - if opt.plms: + if plms: sampler = PLMSSampler(model) else: sampler = DDIMSampler(model) + opt.outdir = "outputs/txt2img-samples" + os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir - batch_size = opt.n_samples + batch_size = n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size if not opt.from_file: - prompt 
 
-    seed_everything(opt.seed)
+def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
+    torch.cuda.empty_cache()
 
-    config = OmegaConf.load(f"{opt.config}")
-    model = load_model_from_config(config, f"{opt.ckpt}")
+    opt.H = height
+    opt.W = width
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+    rng_seed = seed_everything(seed)
 
-    if opt.plms:
+    if plms:
         sampler = PLMSSampler(model)
     else:
         sampler = DDIMSampler(model)
 
+    opt.outdir = "outputs/txt2img-samples"
+
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
 
-    batch_size = opt.n_samples
+    batch_size = n_samples
     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
     if not opt.from_file:
-        prompt = opt.prompt
         assert prompt is not None
         data = [batch_size * [prompt]]
@@ -217,32 +165,33 @@ def main():
     grid_count = len(os.listdir(outpath)) - 1
 
     start_code = None
-    if opt.fixed_code:
-        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+    if fixed_code:
+        start_code = torch.randn([n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
 
     precision_scope = autocast if opt.precision=="autocast" else nullcontext
+    output_images = []
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
+                for n in trange(n_iter, desc="Sampling"):
                     for prompts in tqdm(data, desc="data"):
                         uc = None
-                        if opt.scale != 1.0:
+                        if cfg_scale != 1.0:
                             uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                          conditioning=c,
-                                                         batch_size=opt.n_samples,
+                                                         batch_size=n_samples,
                                                          shape=shape,
                                                          verbose=False,
-                                                         unconditional_guidance_scale=opt.scale,
+                                                         unconditional_guidance_scale=cfg_scale,
                                                          unconditional_conditioning=uc,
-                                                         eta=opt.ddim_eta,
+                                                         eta=ddim_eta,
                                                          x_T=start_code)
 
                         x_samples_ddim = model.decode_first_stage(samples_ddim)
@@ -252,7 +201,8 @@ def main():
                             for x_sample in x_samples_ddim:
                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                 Image.fromarray(x_sample.astype(np.uint8)).save(
-                                    os.path.join(sample_path, f"{base_count:05}.png"))
+                                    os.path.join(sample_path, f"{base_count:05}-{rng_seed}_{prompt.replace(' ', '_')[:128]}.png"))
+                                output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
                                 base_count += 1
 
                 if not opt.skip_grid:
@@ -270,10 +220,155 @@ def main():
                     grid_count += 1
 
                 toc = time.time()
+    del sampler
+    return output_images, rng_seed
 
-    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f" \nEnjoy.")
 
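+# img2img: instead of sampling from pure noise, encode the init image into
+# latent space, noise it for t_enc of the ddim_steps (denoising_strength sets
+# the fraction), then denoise back down conditioned on the prompt.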
+def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
+    torch.cuda.empty_cache()
+    rng_seed = seed_everything(seed)
+    sampler = DDIMSampler(model)
 
-if __name__ == "__main__":
-    main()
+    opt.outdir = "outputs/img2img-samples"
+
+    os.makedirs(opt.outdir, exist_ok=True)
+    outpath = opt.outdir
+
+    batch_size = n_samples
+    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+    if not opt.from_file:
+        assert prompt is not None
+        data = [batch_size * [prompt]]
+    else:
+        print(f"reading prompts from {opt.from_file}")
+        with open(opt.from_file, "r") as f:
+            data = f.read().splitlines()
+            data = list(chunk(data, batch_size))
+
+    sample_path = os.path.join(outpath, "samples")
+    os.makedirs(sample_path, exist_ok=True)
+    base_count = len(os.listdir(sample_path))
+    grid_count = len(os.listdir(outpath)) - 1
+
+    image = init_img.convert("RGB")
+    w, h = image.size
+    print(f"loaded input image of size ({w}, {h})")
+    w, h = map(lambda x: x - x % 32, (width, height))  # round requested size down to integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    print(f"resized image to size ({w}, {h})")
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+
+    output_images = []
+    precision_scope = autocast if opt.precision == "autocast" else nullcontext
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            init_image = 2.*image - 1.
+            init_image = init_image.to(device)
+            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
+            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
+
+            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
+
+            assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+            t_enc = int(denoising_strength * ddim_steps)
+            print(f"target t_enc is {t_enc} steps")
+            with model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
+                        uc = None
+                        if cfg_scale != 1.0:
+                            uc = model.get_learned_conditioning(batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = model.get_learned_conditioning(prompts)
+
+                        # encode (scaled latent)
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+                        # decode it
+                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                                 unconditional_conditioning=uc,)
+
+                        x_samples = model.decode_first_stage(samples)
+                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        if not opt.skip_save:
+                            for x_sample in x_samples:
+                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                Image.fromarray(x_sample.astype(np.uint8)).save(
+                                    os.path.join(sample_path, f"{base_count:05}-{rng_seed}_{prompt.replace(' ', '_')[:128]}.png"))
+                                output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
+                                base_count += 1
+                        all_samples.append(x_samples)
+
+                if not opt.skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
+
+                    # to image
+                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
+
+                toc = time.time()
+    del sampler
+    return output_images, rng_seed
 
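+# Gradio passes the widgets in `inputs` positionally, so each list below must
+# match the parameter order of the function it drives.
+# dream: prompt, ddim_steps, plms, fixed_code, ddim_eta, n_iter, n_samples, cfg_scale, seed, height, width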
+dream_interface = gr.Interface(
+    dream,
+    inputs=[
+        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
+        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
+        gr.Checkbox(label='Enable PLMS sampling', value=False),
+        gr.Checkbox(label='Enable Fixed Code sampling', value=False),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
+        gr.Slider(minimum=1, maximum=8, step=1, label='Sampling iterations', value=2),
+        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=2),
+        gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
+        gr.Number(label='Seed', value=-1),
+        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
+        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
+    ],
+    outputs=[
+        gr.Gallery(),
+        gr.Number(label='Seed')
+    ],
+    title="Stable Diffusion Text-to-Image",
+    description="Generate images from text with Stable Diffusion",
+)
+
+# translation: prompt, init_img, ddim_steps, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed, height, width
+
+img2img_interface = gr.Interface(
+    translation,
+    inputs=[
+        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
+        gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
+        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
+        gr.Slider(minimum=1, maximum=8, step=1, label='Sampling iterations', value=2),
+        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=2),
+        gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
+        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
+        gr.Number(label='Seed', value=-1),
+        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Height", value=512),
+        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Width", value=512),
+    ],
+    outputs=[
+        gr.Gallery(),
+        gr.Number(label='Seed')
+    ],
+    title="Stable Diffusion Image-to-Image",
+    description="Generate images from images with Stable Diffusion",
+)
+
+demo = gr.TabbedInterface(interface_list=[dream_interface, img2img_interface], tab_names=["Dream", "Image Translation"])
+
+demo.launch()
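+# launch() blocks and serves the UI on http://127.0.0.1:7860 by default;
+# pass share=True for a temporary public URL.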