add klms sampler

This commit is contained in:
harubaru 2022-08-25 15:07:03 -07:00 committed by GitHub
parent 8e8c60e323
commit 397f154ea7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 40 additions and 17 deletions

View File

@ -12,12 +12,16 @@ from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
import torch.nn as nn
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from k_diffusion.sampling import sample_lms
from k_diffusion.external import CompVisDenoiser
parser = argparse.ArgumentParser()
parser.add_argument(
@ -91,7 +95,7 @@ def chunk(it, size):
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
pl_sd = torch.load(ckpt, map_location="cuda")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
@ -104,7 +108,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to('cuda')
model.eval()
return model
@ -123,13 +127,25 @@ def load_img_pil(img_pil):
def load_img(path):
return load_img_pil(Image.open(path))
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.half().to(device)
def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
opt.H = height
@ -137,10 +153,12 @@ def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta:
rng_seed = seed_everything(seed)
if plms:
if sampler == 'plms':
sampler = PLMSSampler(model)
else:
if sampler == 'ddim':
sampler = DDIMSampler(model)
if sampler == 'k_lms':
model_wrap = CompVisDenoiser(model)
opt.outdir = "outputs/txt2img-samples"
@ -184,6 +202,13 @@ def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta:
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if sampler == 'k_lms':
sigmas = model_wrap.get_sigmas(ddim_steps)
model_wrap_cfg = CFGDenoiser(model_wrap)
x = torch.randn([n_samples, *shape], device=device) * sigmas[0]
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
samples_ddim = sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False)
else:
samples_ddim, _ = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
@ -326,7 +351,7 @@ dream_interface = gr.Interface(
inputs=[
gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Checkbox(label='Enable PLMS sampling', value=False),
gr.Dropdown(choices=['plms', 'ddim', 'k_lms'], value='k_lms', label='Sampler'),
gr.Checkbox(label='Enable Fixed Code sampling', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=8, step=1, label='Sampling iterations', value=2),
@ -344,8 +369,6 @@ dream_interface = gr.Interface(
description="Generate images from text with Stable Diffusion",
)
# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed
img2img_interface = gr.Interface(
translation,
inputs=[