extra samplers from K-diffusion

This commit is contained in:
AUTOMATIC 2022-08-25 23:31:44 +03:00
parent 91dc8710ec
commit c9579b51a6
1 changed files with 32 additions and 23 deletions

View File

@ -1,4 +1,6 @@
import argparse, os, sys, glob
from collections import namedtuple
import torch
import torch.nn as nn
import numpy as np
@ -16,7 +18,7 @@ import time
import json
import traceback
import k_diffusion as K
import k_diffusion.sampling
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -60,6 +62,19 @@ css_hide_progressbar = """
.meta-text { display:none!important; }
"""
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
('Euler ancestral', 'sample_euler_ancestral'),
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda model: DDIMSampler(model)),
SamplerData('PLMS', lambda model: PLMSSampler(model)),
]
class Options:
@ -142,16 +157,18 @@ class CFGDenoiser(nn.Module):
class KDiffusionSampler:
def __init__(self, m):
def __init__(self, m, funcname):
self.model = m
self.model_wrap = K.external.CompVisDenoiser(m)
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
self.funcname = funcname
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
fun = getattr(k_diffusion.sampling, self.funcname)
samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None
@ -526,7 +543,7 @@ def get_learned_conditioning_with_embeddings(model, prompts):
return model.get_learned_conditioning(prompts)
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
@ -579,7 +596,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
def infotext():
return f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
Steps: {steps}, Sampler: {samplers[sampler_index].name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir):
@ -645,17 +662,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
return output_images, seed, infotext()
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples"
if sampler_name == 'PLMS':
sampler = PLMSSampler(model)
elif sampler_name == 'DDIM':
sampler = DDIMSampler(model)
elif sampler_name == 'k-diffusion':
sampler = KDiffusionSampler(model)
else:
raise Exception("Unknown sampler: " + sampler_name)
sampler = samplers[sampler_index].constructor(model)
def init():
pass
@ -670,7 +680,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
@ -732,7 +742,7 @@ txt2img_interface = gr.Interface(
inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
@ -756,7 +766,7 @@ txt2img_interface = gr.Interface(
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
sampler = KDiffusionSampler(model)
sampler = KDiffusionSampler(model, 'sample_lms')
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
@ -785,7 +795,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim
if loopback:
@ -800,7 +810,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
sampler_index=0,
batch_size=1,
n_iter=1,
steps=ddim_steps,
@ -835,7 +845,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
sampler_index=0,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
@ -877,10 +887,10 @@ img2img_interface = gr.Interface(
gr.Number(label='Seed'),
gr.HTML(),
],
title="Stable Diffusion Image-to-Image",
allow_flagging="never",
)
def run_GFPGAN(image, strength):
image = image.convert("RGB")
@ -904,7 +914,6 @@ gfpgan_interface = gr.Interface(
gr.Number(label='Seed', visible=False),
gr.HTML(),
],
title="GFPGAN",
description="Fix faces on images",
allow_flagging="never",
)