Added prompt matrix to img2img
refactoring: separated duplicate code from img2img and txt2img into a single function
This commit is contained in:
parent
cb118c4036
commit
aa67540eba
254
webui.py
254
webui.py
|
@ -97,16 +97,21 @@ class KDiffusionSampler:
|
|||
sigmas = self.model_wrap.get_sigmas(S)
|
||||
x = x_T * sigmas[0]
|
||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
|
||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||
|
||||
return samples_ddim, None
|
||||
|
||||
|
||||
def create_random_tensors(seed, shape, count, same_seed=False):
|
||||
def create_random_tensors(shape, seeds):
|
||||
xs = []
|
||||
for i in range(count):
|
||||
current_seed = seed if same_seed else seed + i
|
||||
torch.manual_seed(current_seed)
|
||||
for seed in seeds:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||
# but the original script had it like this so i do not dare change it for now because
|
||||
# it will break everyone's seeds.
|
||||
xs.append(torch.randn(shape, device=device))
|
||||
x = torch.stack(xs)
|
||||
return x
|
||||
|
@ -190,7 +195,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
|||
color_inactive = (153, 153, 153)
|
||||
|
||||
pad_top = height // 4
|
||||
pad_left = width * 3 // 4
|
||||
pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0
|
||||
|
||||
cols = im.width // width
|
||||
rows = im.height // height
|
||||
|
@ -226,63 +231,53 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
|||
return result
|
||||
|
||||
|
||||
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||
torch.cuda.empty_cache()
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
||||
assert prompt is not None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if seed == -1:
|
||||
seed = random.randrange(4294967294)
|
||||
|
||||
seed = int(seed)
|
||||
keep_same_seed = False
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(model)
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(model)
|
||||
elif sampler_name == 'k-diffusion':
|
||||
sampler = KDiffusionSampler(model)
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
|
||||
batch_size = n_samples
|
||||
|
||||
assert prompt is not None
|
||||
prompts = batch_size * [prompt]
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
prompt_matrix_prompts = []
|
||||
prompt_matrix_parts = []
|
||||
if prompt_matrix:
|
||||
keep_same_seed = True
|
||||
|
||||
all_prompts = []
|
||||
prompt_matrix_parts = prompt.split("|")
|
||||
combination_count = 2 ** (len(prompt_matrix_parts)-1)
|
||||
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
|
||||
for combination_num in range(combination_count):
|
||||
current = prompt_matrix_parts[0]
|
||||
|
||||
for n, text in enumerate(prompt_matrix_parts[1:]):
|
||||
if combination_num & (2**n) > 0:
|
||||
if combination_num & (2 ** n) > 0:
|
||||
current += ("" if text.strip().startswith(",") else ", ") + text
|
||||
|
||||
prompt_matrix_prompts.append(current)
|
||||
n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
|
||||
all_prompts.append(current)
|
||||
|
||||
print(f"Prompt matrix will create {len(prompt_matrix_prompts)} images using a total of {n_iter} batches.")
|
||||
n_iter = math.ceil(len(all_prompts) / batch_size)
|
||||
all_seeds = len(all_prompts) * [seed]
|
||||
|
||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
||||
else:
|
||||
all_prompts = batch_size * n_iter * [prompt]
|
||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
output_images = []
|
||||
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
||||
init_data = func_init()
|
||||
|
||||
for n in range(n_iter):
|
||||
if prompt_matrix:
|
||||
prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
|
||||
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
||||
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
||||
|
||||
uc = None
|
||||
if cfg_scale != 1.0:
|
||||
|
@ -290,14 +285,11 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
|
|||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [opt_C, height // opt_f, width // opt_f]
|
||||
|
||||
batch_seed = seed if keep_same_seed else seed + n * len(prompts)
|
||||
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors(batch_seed, shape, count=len(prompts), same_seed=keep_same_seed)
|
||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
||||
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
@ -312,7 +304,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
|
|||
x_sample = restored_img
|
||||
|
||||
image = Image.fromarray(x_sample)
|
||||
filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
|
||||
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
|
||||
|
||||
image.save(os.path.join(sample_path, filename))
|
||||
|
||||
|
@ -323,21 +315,68 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
|
|||
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
|
||||
|
||||
if prompt_matrix:
|
||||
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
|
||||
|
||||
try:
|
||||
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
|
||||
except Exception:
|
||||
import traceback
|
||||
print("Error creating prompt_matrix text:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
output_images.insert(0, grid)
|
||||
|
||||
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
del sampler
|
||||
|
||||
info = f"""
|
||||
{prompt}
|
||||
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||
""".strip()
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||
""".strip()
|
||||
|
||||
return output_images, seed, info
|
||||
|
||||
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(model)
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(model)
|
||||
elif sampler_name == 'k-diffusion':
|
||||
sampler = KDiffusionSampler(model)
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
|
||||
return samples_ddim
|
||||
|
||||
output_images, seed, info = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info
|
||||
|
||||
|
||||
class Flagging(gr.FlaggingCallback):
|
||||
|
||||
def setup(self, components, flagging_dir: str):
|
||||
|
@ -348,7 +387,7 @@ class Flagging(gr.FlaggingCallback):
|
|||
|
||||
os.makedirs("log/images", exist_ok=True)
|
||||
|
||||
# those must match the "dream" function
|
||||
# those must match the "txt2img" function
|
||||
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
|
||||
|
||||
filenames = []
|
||||
|
@ -379,8 +418,8 @@ class Flagging(gr.FlaggingCallback):
|
|||
print("Logged:", filenames[0])
|
||||
|
||||
|
||||
dream_interface = gr.Interface(
|
||||
dream,
|
||||
txt2img_interface = gr.Interface(
|
||||
txt2img,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||
|
@ -406,104 +445,70 @@ dream_interface = gr.Interface(
|
|||
)
|
||||
|
||||
|
||||
def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
|
||||
outpath = opt.outdir or "outputs/img2img-samples"
|
||||
|
||||
if seed == -1:
|
||||
seed = random.randrange(4294967294)
|
||||
sampler = KDiffusionSampler(model)
|
||||
|
||||
model_wrap = K.external.CompVisDenoiser(model)
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
def init():
|
||||
image = init_img.convert("RGB")
|
||||
image = image.resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
batch_size = n_samples
|
||||
|
||||
assert prompt is not None
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
image = init_img.convert("RGB")
|
||||
image = image.resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
output_images = []
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
||||
init_image = 2. * image - 1.
|
||||
init_image = init_image.to(device)
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||
x0 = init_latent
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
return init_latent,
|
||||
|
||||
for n in range(n_iter):
|
||||
prompts = batch_size * [prompt]
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning):
|
||||
x0, = init_data
|
||||
|
||||
uc = None
|
||||
if cfg_scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc - 1]
|
||||
|
||||
batch_seed = seed + n * len(prompts)
|
||||
xi = x0 + noise
|
||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
||||
return samples_ddim
|
||||
|
||||
sigmas = model_wrap.get_sigmas(ddim_steps)
|
||||
noise = create_random_tensors(batch_seed, x0.shape[1:], count=len(prompts))
|
||||
noise = noise * sigmas[ddim_steps - t_enc - 1]
|
||||
output_images, seed, info = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name='k-diffusion',
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN
|
||||
)
|
||||
|
||||
xi = x0 + noise
|
||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||
model_wrap_cfg = CFGDenoiser(model_wrap)
|
||||
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
|
||||
del sampler
|
||||
|
||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=False)
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
return output_images, seed, info
|
||||
|
||||
if not opt.skip_save or not opt.skip_grid:
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
if use_GFPGAN and GFPGAN is not None:
|
||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
x_sample = restored_img
|
||||
|
||||
image = Image.fromarray(x_sample)
|
||||
image.save(os.path.join(sample_path, f"{base_count:05}-{batch_seed+i}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
|
||||
|
||||
output_images.append(image)
|
||||
base_count += 1
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = image_grid(output_images, batch_size)
|
||||
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
return output_images, seed
|
||||
|
||||
|
||||
# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed
|
||||
|
||||
img2img_interface = gr.Interface(
|
||||
translation,
|
||||
img2img,
|
||||
inputs=[
|
||||
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
||||
gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
|
||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||
|
@ -514,7 +519,8 @@ img2img_interface = gr.Interface(
|
|||
],
|
||||
outputs=[
|
||||
gr.Gallery(),
|
||||
gr.Number(label='Seed')
|
||||
gr.Number(label='Seed'),
|
||||
gr.Textbox(label="Copy-paste generation parameters"),
|
||||
],
|
||||
title="Stable Diffusion Image-to-Image",
|
||||
description="Generate images from images with Stable Diffusion",
|
||||
|
@ -522,7 +528,7 @@ img2img_interface = gr.Interface(
|
|||
)
|
||||
|
||||
interfaces = [
|
||||
(dream_interface, "txt2img"),
|
||||
(txt2img_interface, "txt2img"),
|
||||
(img2img_interface, "img2img")
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue