2022-09-06 01:00:21 -06:00
import argparse , os , sys , glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm , trange
from einops import rearrange
from torchvision . utils import make_grid , save_image
from ldm . util import instantiate_from_config
from ldm . models . diffusion . ddim import DDIMSampler
from ldm . models . diffusion . plms import PLMSSampler
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load weights from *ckpt*.

    Args:
        config: Config object whose ``model`` node is handed to
            ``instantiate_from_config`` to build the (untrained) model.
        ckpt: Path to a checkpoint readable by ``torch.load``. If the loaded
            object has a ``"state_dict"`` entry (e.g. a Lightning checkpoint),
            that entry is used; otherwise the object itself is treated as the
            state dict.
        verbose: If True, print the missing / unexpected keys reported by the
            non-strict ``load_state_dict`` call.

    Returns:
        The model in eval mode. It is moved to CUDA only when a CUDA device is
        available; otherwise it stays on the CPU.
    """
    print(f"Loading model from {ckpt}")
    # torch.load unpickles arbitrary objects: only load checkpoints from
    # trusted sources. map_location="cpu" avoids needing the GPU the
    # checkpoint was saved from.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches and just report them below.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    # An unconditional .cuda() raises on CPU-only machines, even though
    # callers may place the model on a device themselves.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def _parse_args(argv=None):
    """Build this script's command-line parser and parse *argv* (default: ``sys.argv``)."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )
    # H and W are floor-divided by 8 to get the latent shape in main();
    # presumably they should be multiples of 8 -- TODO confirm.
    parser.add_argument(
        "--H",
        type=int,
        default=256,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=256,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model",
    )
    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint",
    )
    return parser.parse_args(argv)


def _prompt_to_filename(prompt, max_bytes=200):
    """Return a filesystem-safe file-name stem derived from *prompt*.

    Spaces become "-" (the convention the grid file name already used).
    Forward and back slashes become "_", so a prompt cannot point the output
    into a sub-directory that does not exist. The stem is cut to at most
    *max_bytes* UTF-8 bytes, without splitting a character, so that adding a
    suffix such as "_12.png" stays under the common 255-byte name limit.
    An empty prompt yields "prompt" rather than a bare ".jpg" name.
    """
    stem = prompt.replace(" ", "-").replace("/", "_").replace("\\", "_")
    stem = stem.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")
    return stem or "prompt"


def main():
    """Sample images for ``--prompt`` and write them under ``--outdir``.

    Outputs:
      * ``<outdir>/samples/NNNN.jpg``: one JPEG per sample. Numbering
        continues from the number of entries already in that directory.
      * ``<outdir>/<stem>_<i>.png``: every sample again, as PNG.
      * ``<outdir>/<stem>.jpg``: all samples tiled into a grid with
        ``n_samples`` images per row.
    ``<stem>`` is the prompt made filesystem-safe by ``_prompt_to_filename``.
    """
    opt = _parse_args()

    # TODO: Optionally download from same location as ckpt and change this logic
    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml")
    model = load_model_from_config(config, opt.ckpt_path)  # TODO: check path
    # NOTE(review): --embedding_path is parsed but never used in this script,
    # so no embedding-manager checkpoint is loaded here. Confirm whether that
    # is intended.

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    prompt = opt.prompt
    # One sanitized stem for every prompt-derived file name, so the
    # per-image PNGs and the grid JPEG are named consistently.
    file_stem = _prompt_to_filename(prompt)

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

    all_samples = []
    with torch.no_grad(), model.ema_scope():
        uc = None
        # Unconditional (empty-prompt) conditioning is only needed for guidance.
        if opt.scale != 1.0:
            uc = model.get_learned_conditioning(opt.n_samples * [""])
        for _ in trange(opt.n_iter, desc="Sampling"):
            c = model.get_learned_conditioning(opt.n_samples * [prompt])
            # Latent shape: 4 channels at 1/8 of the pixel resolution.
            shape = [4, opt.H // 8, opt.W // 8]
            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                             conditioning=c,
                                             batch_size=opt.n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=opt.scale,
                                             unconditional_conditioning=uc,
                                             eta=opt.ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map decoder output from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.jpg"))
                base_count += 1
            all_samples.append(x_samples_ddim)

    # additionally, save as grid
    grid = torch.stack(all_samples, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')

    for i, image in enumerate(grid):
        save_image(image, os.path.join(outpath, f"{file_stem}_{i}.png"))
    grid = make_grid(grid, nrow=opt.n_samples)

    # to image
    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"{file_stem}.jpg"))

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


if __name__ == "__main__":
    main()