add klms sampler
parent 8e8c60e323
commit 397f154ea7
@@ -12,12 +12,16 @@ from torchvision.utils import make_grid
 import time
 from pytorch_lightning import seed_everything
 from torch import autocast
+import torch.nn as nn
 from contextlib import contextmanager, nullcontext

 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler

+from k_diffusion.sampling import sample_lms
+from k_diffusion.external import CompVisDenoiser
+
 parser = argparse.ArgumentParser()

 parser.add_argument(
@@ -91,7 +95,7 @@ def chunk(it, size):

 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt, map_location="cuda")
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
     sd = pl_sd["state_dict"]
@@ -104,7 +108,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
+    model.to('cuda')
     model.eval()
     return model

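
Note on the two loading changes above: `torch.load(..., map_location="cuda")` places the deserialized checkpoint tensors on the GPU rather than in system RAM (so the full state dict must fit in free VRAM), and `model.to('cuda')` is the functionally equivalent spelling of `model.cuda()`. A small hedged sketch of the loading step in isolation; the checkpoint path is the one hard-coded in the script, while the CPU fallback is added here purely for illustration:

import torch

ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"        # path used by the script
device = "cuda" if torch.cuda.is_available() else "cpu"   # CPU fallback added for illustration

# map_location decides where the checkpoint tensors land when deserialized.
pl_sd = torch.load(ckpt, map_location=device)
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
print(f"loaded {len(sd)} tensors onto {device}")
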
@@ -123,13 +127,25 @@ def load_img_pil(img_pil):
 def load_img(path):
     return load_img_pil(Image.open(path))

+class CFGDenoiser(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
 config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
 model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")

 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 model = model.half().to(device)

-def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
+def dream(prompt: str, ddim_steps: int, sampler: str, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
     torch.cuda.empty_cache()

     opt.H = height
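
For context, the CFGDenoiser added above implements classifier-free guidance: the latent batch and sigma are duplicated so the unconditional and conditional predictions come out of a single batched forward pass, and the result is mixed as uncond + (cond - uncond) * cond_scale. A minimal sketch of just that mixing step on dummy tensors (the shapes and the scale value are illustrative, not taken from the script):

import torch

# Illustrative shapes: a batch of 2 latents with 4 channels at 64x64.
uncond = torch.randn(2, 4, 64, 64)  # denoiser output for the empty/unconditional prompt
cond = torch.randn(2, 4, 64, 64)    # denoiser output for the actual prompt
cond_scale = 7.5                    # guidance strength; 1.0 would reproduce `cond` exactly

# Classifier-free guidance: push the prediction away from the unconditional
# estimate, along the direction toward the conditional one.
guided = uncond + (cond - uncond) * cond_scale
assert guided.shape == cond.shape
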
@@ -137,10 +153,12 @@ def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta:

     rng_seed = seed_everything(seed)

-    if plms:
+    if sampler == 'plms':
         sampler = PLMSSampler(model)
-    else:
+    if sampler == 'ddim':
         sampler = DDIMSampler(model)
+    if sampler == 'k_lms':
+        model_wrap = CompVisDenoiser(model)

     opt.outdir = "outputs/txt2img-samples"

@@ -184,6 +202,13 @@ def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta:
                         prompts = list(prompts)
                     c = model.get_learned_conditioning(prompts)
                     shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                    if sampler == 'k_lms':
+                        sigmas = model_wrap.get_sigmas(ddim_steps)
+                        model_wrap_cfg = CFGDenoiser(model_wrap)
+                        x = torch.randn([n_samples, *shape], device=device) * sigmas[0]
+                        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
+                        samples_ddim = sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=False)
+                    else:
                         samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                          conditioning=c,
                                                          batch_size=n_samples,
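
The k_lms branch above follows the usual k-diffusion recipe: take a sigma schedule, scale the initial Gaussian noise by sigmas[0], and hand a denoiser plus extra_args to sample_lms. A self-contained sketch of that call with a toy denoiser; ToyDenoiser and the Karras schedule parameters are illustrative stand-ins, whereas the script itself uses CFGDenoiser(CompVisDenoiser(model)) and model_wrap.get_sigmas(ddim_steps):

import torch
from k_diffusion.sampling import get_sigmas_karras, sample_lms

class ToyDenoiser(torch.nn.Module):
    """Stand-in for the CFG-wrapped model: maps (noisy latent, sigma) -> denoised latent."""
    def forward(self, x, sigma, **extra_args):
        return torch.zeros_like(x)  # a real denoiser returns its estimate of the clean latent

denoiser = ToyDenoiser()
sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)  # illustrative noise schedule
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from pure noise at the largest sigma
samples = sample_lms(denoiser, x, sigmas, extra_args={}, disable=False)
print(samples.shape)  # same shape as x; decode_first_stage would turn this into an image
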
@@ -326,7 +351,7 @@ dream_interface = gr.Interface(
     inputs=[
         gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
         gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
-        gr.Checkbox(label='Enable PLMS sampling', value=False),
+        gr.Dropdown(choices=['plms', 'ddim', 'k_lms'], value='k_lms', label='Sampler'),
         gr.Checkbox(label='Enable Fixed Code sampling', value=False),
         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
         gr.Slider(minimum=1, maximum=8, step=1, label='Sampling iterations', value=2),
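
On the UI side, replacing the PLMS checkbox with a Dropdown means dream() now receives the sampler choice as a plain string ('plms', 'ddim' or 'k_lms'). A tiny hedged sketch of that wiring with a stub function; dream_stub and the text output are illustrative, while the real interface passes the value into dream() alongside the other controls:

import gradio as gr

def dream_stub(prompt, sampler):
    # The Dropdown delivers its current selection as a string, e.g. 'k_lms'.
    return f"would render {prompt!r} with the {sampler} sampler"

demo = gr.Interface(
    dream_stub,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Dropdown(choices=['plms', 'ddim', 'k_lms'], value='k_lms', label='Sampler'),
    ],
    outputs="text",
)
# demo.launch()  # uncomment to serve the stub UI locally
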
@@ -344,8 +369,6 @@ dream_interface = gr.Interface(
     description="Generate images from text with Stable Diffusion",
 )

-# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed
-
 img2img_interface = gr.Interface(
     translation,
     inputs=[