extra samplers from K-diffusion
This commit is contained in:
parent
91dc8710ec
commit
c9579b51a6
55
webui.py
55
webui.py
|
@ -1,4 +1,6 @@
|
|||
import argparse, os, sys, glob
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
@ -16,7 +18,7 @@ import time
|
|||
import json
|
||||
import traceback
|
||||
|
||||
import k_diffusion as K
|
||||
import k_diffusion.sampling
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
@ -60,6 +62,19 @@ css_hide_progressbar = """
|
|||
.meta-text { display:none!important; }
|
||||
"""
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
|
||||
samplers = [
|
||||
*[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [
|
||||
('LMS', 'sample_lms'),
|
||||
('Heun', 'sample_heun'),
|
||||
('Euler', 'sample_euler'),
|
||||
('Euler ancestral', 'sample_euler_ancestral'),
|
||||
('DPM 2', 'sample_dpm_2'),
|
||||
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
|
||||
] if hasattr(k_diffusion.sampling, x[1])],
|
||||
SamplerData('DDIM', lambda model: DDIMSampler(model)),
|
||||
SamplerData('PLMS', lambda model: PLMSSampler(model)),
|
||||
]
|
||||
|
||||
|
||||
class Options:
|
||||
|
@ -142,16 +157,18 @@ class CFGDenoiser(nn.Module):
|
|||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, m):
|
||||
def __init__(self, m, funcname):
|
||||
self.model = m
|
||||
self.model_wrap = K.external.CompVisDenoiser(m)
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
|
||||
self.funcname = funcname
|
||||
|
||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
||||
sigmas = self.model_wrap.get_sigmas(S)
|
||||
x = x_T * sigmas[0]
|
||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
|
||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||
fun = getattr(k_diffusion.sampling, self.funcname)
|
||||
samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
||||
|
||||
return samples_ddim, None
|
||||
|
||||
|
@ -526,7 +543,7 @@ def get_learned_conditioning_with_embeddings(model, prompts):
|
|||
return model.get_learned_conditioning(prompts)
|
||||
|
||||
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
assert prompt is not None
|
||||
|
@ -579,7 +596,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||
def infotext():
|
||||
return f"""
|
||||
{prompt}
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||
Steps: {steps}, Sampler: {samplers[sampler_index].name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||
""".strip() + "".join(["\n\n" + x for x in comments])
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
|
@ -645,17 +662,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
|
|||
return output_images, seed, infotext()
|
||||
|
||||
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||
outpath = opts.outdir or "outputs/txt2img-samples"
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(model)
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(model)
|
||||
elif sampler_name == 'k-diffusion':
|
||||
sampler = KDiffusionSampler(model)
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
sampler = samplers[sampler_index].constructor(model)
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
@ -670,7 +680,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
|
|||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
sampler_index=sampler_index,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
|
@ -732,7 +742,7 @@ txt2img_interface = gr.Interface(
|
|||
inputs=[
|
||||
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
|
||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
|
||||
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
|
||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||
|
@ -756,7 +766,7 @@ txt2img_interface = gr.Interface(
|
|||
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||
outpath = opts.outdir or "outputs/img2img-samples"
|
||||
|
||||
sampler = KDiffusionSampler(model)
|
||||
sampler = KDiffusionSampler(model, 'sample_lms')
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
|
@ -785,7 +795,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||
xi = x0 + noise
|
||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
||||
samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
||||
return samples_ddim
|
||||
|
||||
if loopback:
|
||||
|
@ -800,7 +810,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name='k-diffusion',
|
||||
sampler_index=0,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
|
@ -835,7 +845,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name='k-diffusion',
|
||||
sampler_index=0,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
|
@ -877,10 +887,10 @@ img2img_interface = gr.Interface(
|
|||
gr.Number(label='Seed'),
|
||||
gr.HTML(),
|
||||
],
|
||||
title="Stable Diffusion Image-to-Image",
|
||||
allow_flagging="never",
|
||||
)
|
||||
|
||||
|
||||
def run_GFPGAN(image, strength):
|
||||
image = image.convert("RGB")
|
||||
|
||||
|
@ -904,7 +914,6 @@ gfpgan_interface = gr.Interface(
|
|||
gr.Number(label='Seed', visible=False),
|
||||
gr.HTML(),
|
||||
],
|
||||
title="GFPGAN",
|
||||
description="Fix faces on images",
|
||||
allow_flagging="never",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue