fixed bug with images not resizing for img2img
added GFPGAN as an option for img2img added GFPGAN as a tab added autodetection for row counts for grids, enabled by default removed Fixed Code sampling because no one can figure out what it does; maybe someone will be upset by removal and will tell me
This commit is contained in:
parent
3324f31e84
commit
b63d0726cd
116
webui.py
116
webui.py
|
@ -13,6 +13,7 @@ from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import random
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
@ -31,7 +32,7 @@ parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
|
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
|
||||||
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
|
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
|
||||||
parser.add_argument("--skip_save", action='store_true', help="do not save indiviual samples. For speed measurements.",)
|
parser.add_argument("--skip_save", action='store_true', help="do not save indiviual samples. For speed measurements.",)
|
||||||
parser.add_argument("--n_rows", type=int, default=0, help="rows in the grid (default: n_samples)",)
|
parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; use -1 for autodetect and 0 for n_rows to be same as batch_size (default: -1)",)
|
||||||
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
|
@ -118,6 +119,7 @@ if os.path.exists(GFPGAN_dir):
|
||||||
print("Error loading GFPGAN:", file=sys.stderr)
|
print("Error loading GFPGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
|
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
|
||||||
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
|
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
|
||||||
|
|
||||||
|
@ -125,18 +127,26 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
||||||
model = model.half().to(device)
|
model = model.half().to(device)
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, rows):
|
def image_grid(imgs, batch_size):
|
||||||
cols = len(imgs) // rows
|
if opt.n_rows > 0:
|
||||||
|
rows = opt.n_rows
|
||||||
|
elif opt.n_rows == 0:
|
||||||
|
rows = batch_size
|
||||||
|
else:
|
||||||
|
rows = round(math.sqrt(len(imgs)))
|
||||||
|
|
||||||
|
cols = math.ceil(len(imgs) / rows)
|
||||||
|
|
||||||
w, h = imgs[0].size
|
w, h = imgs[0].size
|
||||||
grid = Image.new('RGB', size=(cols * w, rows * h))
|
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
|
||||||
|
|
||||||
for i, img in enumerate(imgs):
|
for i, img in enumerate(imgs):
|
||||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||||
|
|
||||||
return grid
|
return grid
|
||||||
|
|
||||||
def dream(prompt: str, ddim_steps: int, sampler_name: str, fixed_code: bool, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
|
||||||
|
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
outpath = opt.outdir or "outputs/txt2img-samples"
|
||||||
|
@ -165,7 +175,6 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, fixed_code: bool, use
|
||||||
os.makedirs(outpath, exist_ok=True)
|
os.makedirs(outpath, exist_ok=True)
|
||||||
|
|
||||||
batch_size = n_samples
|
batch_size = n_samples
|
||||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
|
||||||
|
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
@ -175,15 +184,9 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, fixed_code: bool, use
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
start_code = None
|
|
||||||
if fixed_code:
|
|
||||||
start_code = torch.randn([n_samples, opt_C, height // opt_f, width // opt_f], device=device)
|
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
||||||
all_samples = []
|
|
||||||
|
|
||||||
for n in range(n_iter):
|
for n in range(n_iter):
|
||||||
for batch_index, prompts in enumerate(data):
|
for batch_index, prompts in enumerate(data):
|
||||||
uc = None
|
uc = None
|
||||||
|
@ -204,7 +207,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, fixed_code: bool, use
|
||||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
|
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
|
||||||
|
|
||||||
elif sampler is not None:
|
elif sampler is not None:
|
||||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=start_code)
|
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
@ -224,12 +227,9 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, fixed_code: bool, use
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
|
|
||||||
if not opt.skip_grid:
|
|
||||||
all_samples.append(x_sample)
|
|
||||||
|
|
||||||
if not opt.skip_grid:
|
if not opt.skip_grid:
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = image_grid(output_images, rows=n_rows)
|
grid = image_grid(output_images, batch_size)
|
||||||
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
|
@ -251,7 +251,6 @@ dream_interface = gr.Interface(
|
||||||
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
||||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||||
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
||||||
gr.Checkbox(label='Enable Fixed Code sampling', value=False),
|
|
||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||||
gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
|
gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
|
||||||
|
@ -272,7 +271,7 @@ dream_interface = gr.Interface(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
|
def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
outpath = opt.outdir or "outputs/img2img-samples"
|
outpath = opt.outdir or "outputs/img2img-samples"
|
||||||
|
@ -280,14 +279,11 @@ def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter:
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
seed = random.randrange(4294967294)
|
seed = random.randrange(4294967294)
|
||||||
|
|
||||||
sampler = DDIMSampler(model)
|
|
||||||
|
|
||||||
model_wrap = K.external.CompVisDenoiser(model)
|
model_wrap = K.external.CompVisDenoiser(model)
|
||||||
|
|
||||||
os.makedirs(outpath, exist_ok=True)
|
os.makedirs(outpath, exist_ok=True)
|
||||||
|
|
||||||
batch_size = n_samples
|
batch_size = n_samples
|
||||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
|
||||||
|
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
@ -299,28 +295,23 @@ def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter:
|
||||||
seedit = 0
|
seedit = 0
|
||||||
|
|
||||||
image = init_img.convert("RGB")
|
image = init_img.convert("RGB")
|
||||||
w, h = image.size
|
image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
with torch.no_grad():
|
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
||||||
with precision_scope("cuda"):
|
|
||||||
init_image = 2. * image - 1.
|
init_image = 2. * image - 1.
|
||||||
init_image = init_image.to(device)
|
init_image = init_image.to(device)
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||||
x0 = init_latent
|
x0 = init_latent
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
|
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
t_enc = int(denoising_strength * ddim_steps)
|
t_enc = int(denoising_strength * ddim_steps)
|
||||||
print(f"target t_enc is {t_enc} steps")
|
|
||||||
with model.ema_scope():
|
|
||||||
all_samples = list()
|
|
||||||
for n in range(n_iter):
|
for n in range(n_iter):
|
||||||
for batch_index, prompts in enumerate(data):
|
for batch_index, prompts in enumerate(data):
|
||||||
uc = None
|
uc = None
|
||||||
|
@ -338,7 +329,6 @@ def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter:
|
||||||
noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc - 1] # for GPU draw
|
noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc - 1] # for GPU draw
|
||||||
xi = x0 + noise
|
xi = x0 + noise
|
||||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||||
# x = torch.randn([n_samples, *shape]).to(device) * sigmas[0] # for CPU draw
|
|
||||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||||
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
|
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
|
||||||
|
|
||||||
|
@ -346,31 +336,27 @@ def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter:
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if not opt.skip_save:
|
if not opt.skip_save or not opt.skip_grid:
|
||||||
for x_sample in x_samples_ddim:
|
for x_sample in x_samples_ddim:
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
image = Image.fromarray(x_sample.astype(np.uint8))
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
|
if use_GFPGAN and GFPGAN is not None:
|
||||||
|
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
|
x_sample = restored_img
|
||||||
|
|
||||||
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
|
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
seedit += 1
|
|
||||||
|
|
||||||
if not opt.skip_grid:
|
|
||||||
all_samples.append(x_samples_ddim)
|
|
||||||
|
|
||||||
if not opt.skip_grid:
|
if not opt.skip_grid:
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = torch.stack(all_samples, 0)
|
grid = image_grid(output_images, batch_size)
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
grid = make_grid(grid, nrow=n_rows)
|
|
||||||
|
|
||||||
# to image
|
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
|
||||||
Image.fromarray(grid.astype(np.uint8))
|
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
del sampler
|
|
||||||
return output_images, seed
|
return output_images, seed
|
||||||
|
|
||||||
|
|
||||||
|
@ -382,9 +368,10 @@ img2img_interface = gr.Interface(
|
||||||
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
||||||
gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
|
gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
|
||||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||||
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||||
gr.Slider(minimum=1, maximum=50, step=1, label='Sampling iterations', value=2),
|
gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
|
||||||
gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=2),
|
gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
|
@ -399,6 +386,37 @@ img2img_interface = gr.Interface(
|
||||||
description="Generate images from images with Stable Diffusion",
|
description="Generate images from images with Stable Diffusion",
|
||||||
)
|
)
|
||||||
|
|
||||||
demo = gr.TabbedInterface(interface_list=[dream_interface, img2img_interface], tab_names=["Dream", "Image Translation"])
|
interfaces = [
|
||||||
|
(dream_interface, "Dream"),
|
||||||
|
(img2img_interface, "Image Translation")
|
||||||
|
]
|
||||||
|
|
||||||
|
def run_GFPGAN(image, strength):
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
if strength < 1.0:
|
||||||
|
res = PIL.Image.blend(image, res, strength)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
if GFPGAN is not None:
|
||||||
|
interfaces.append((gr.Interface(
|
||||||
|
run_GFPGAN,
|
||||||
|
inputs=[
|
||||||
|
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
|
||||||
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Effect strength", value=100),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
gr.Image(label="Result"),
|
||||||
|
],
|
||||||
|
title="GFPGAN",
|
||||||
|
description="Fix faces on images",
|
||||||
|
), "GFPGAN"))
|
||||||
|
|
||||||
|
demo = gr.TabbedInterface(interface_list=[x[0] for x in interfaces], tab_names=[x[1] for x in interfaces])
|
||||||
|
|
||||||
demo.launch()
|
demo.launch()
|
||||||
|
|
Loading…
Reference in New Issue