## Install

In [None]:
!git clone https://github.com/FuouM/stable-diffusion-hidamari stable-diffusion
%cd stable-diffusion
!git pull

!pip install albumentations==0.4.3
!pip install opencv-python==4.1.2.30
!pip install pudb==2019.2
!pip install imageio==2.9.0
!pip install imageio-ffmpeg==0.4.2
#!pip install pytorch-lightning==1.4.2
!pip install pytorch-lightning 
!pip install omegaconf==2.1.1
!pip install test-tube>=0.7.5
!pip install streamlit>=0.73.1
!pip install einops==0.3.0
!pip install torch-fidelity==0.3.0
# !pip install pilmoji

!pip install transformers==4.19.2

!mkdir -p '/notebooks/stable-diffusion/Source'
!mkdir -p '/notebooks/stable-diffusion/Output'

In [None]:
# Fetch the git-hosted dependencies into src/ and install them.
# For each repo, `git clone` fails harmlessly if the directory already exists and the
# following `git pull` then refreshes the existing checkout.
!mkdir -p /notebooks/stable-diffusion/src/
%cd /notebooks/stable-diffusion/src/
# taming-transformers, installed in editable (-e) mode from the checkout.
!git clone https://github.com/CompVis/taming-transformers.git
%cd /notebooks/stable-diffusion/src/taming-transformers
!git pull
!pip install -e .
import taming # import right after installing: per the original author, importing these freshly installed packages only in a later cell fails to find them (cause unknown)

# OpenAI CLIP, also installed in editable mode.
%cd /notebooks/stable-diffusion/src/
!git clone https://github.com/openai/CLIP.git
%cd /notebooks/stable-diffusion/src/CLIP
!git pull
!pip install -e .
import clip

# k-diffusion (regular, non-editable install) plus kornia.
# NOTE(review): kornia is not version-pinned, unlike most other dependencies.
%cd /notebooks/stable-diffusion/src/
!git clone https://github.com/crowsonkb/k-diffusion.git
%cd /notebooks/stable-diffusion/src/k-diffusion
!git pull
!pip install .
!pip install kornia
import kornia

## Download the model

In [None]:
# Download the checkpoint to /notebooks/stable-diffusion/model.ckpt, the file the Prepare
# section loads (as "model.ckpt" after `%cd /notebooks/stable-diffusion`).
# The file name suggests Waifu Diffusion v1.2 with full EMA weights — unverified.
!wget https://storage.googleapis.com/ws-store2/wd-v1-2-full-ema.ckpt -O /notebooks/stable-diffusion/model.ckpt

# Optimized SD + K-diffusion (Updated as of 8/28)

## Prepare

In [None]:
%cd /notebooks/stable-diffusion

import argparse, os, sys, glob, random
import torch
import numpy as np
from random import randint
import math

import time

from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice

from einops import rearrange, repeat
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config

def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    The last tuple may be shorter than ``size``. Iteration ends as soon as an
    empty tuple would be produced. ``iter(it)`` is called eagerly, so a
    non-iterable argument fails immediately.
    """
    source = iter(it)

    def _pieces():
        while True:
            piece = tuple(islice(source, size))
            if piece == ():
                return
            yield piece

    return _pieces()

def load_model_from_config(ckpt, verbose=False):
    """Load a checkpoint file and return the model state dict inside it.

    Args:
        ckpt: Path to a checkpoint readable by ``torch.load``.
        verbose: Unused; kept so existing callers keep working.

    Returns:
        The checkpoint's ``"state_dict"`` entry when present. Otherwise it returns
        the loaded object itself, which is assumed to already be a bare state dict.
    """
    print(f"Loading model from {ckpt}")
    # SECURITY: torch.load unpickles the file and can execute arbitrary code;
    # only load checkpoints from sources you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Some checkpoints are saved as a bare state dict rather than wrapped under
    # "state_dict"; accept both instead of failing with KeyError.
    return pl_sd.get("state_dict", pl_sd)

def torch_gc():
    """Free unreferenced Python objects and return cached CUDA memory to the driver.

    ``gc.collect()`` runs first so tensors kept alive only by reference cycles are
    released before the CUDA caching allocator is emptied. On machines without
    CUDA, only the garbage collection runs; the original version raised there,
    because ``torch.cuda.ipc_collect`` initializes CUDA.
    """
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    
def load_img(init_image, h0, w0):
    """Convert a PIL image into a normalized tensor batch for the first-stage encoder.

    Args:
        init_image: PIL image; converted to RGB first.
        h0, w0: Target height and width. When both are given they replace the
            image's own size. Either way, each side is rounded down to a multiple
            of 64, with a minimum of 64.

    Returns:
        Float tensor of shape (1, 3, H, W) with values scaled to [-1, 1].
    """
    image = init_image.convert("RGB")
    w, h = image.size

    if h0 is not None and w0 is not None:
        h, w = h0, w0

    # Round each side down to a multiple of 64 (the old comment said 32, but the
    # code has always used 64). Clamp to at least 64 so small sizes don't become
    # 0, which PIL's resize rejects.
    w, h = (max(64, x - x % 64) for x in (w, h))

    print(f"New image size ({w}, {h})")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0   # (H, W, 3) in [0, 1]
    image = image[None].transpose(0, 3, 1, 2)            # -> (1, 3, H, W)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

# Lanczos resampling filter. Newer Pillow exposes filters under Image.Resampling;
# older releases only have the module-level Image.LANCZOS constant.
LANCZOS = getattr(Image, "Resampling", Image).LANCZOS
# Characters not allowed in output file names (Windows-reserved set plus newline).
# The backslash is escaped explicitly: the previous '\|' was an invalid escape
# sequence (a SyntaxWarning on Python 3.12+), though it produced the same string.
invalid_filename_chars = '<>:"/\\|?*\n'

def resize_image(resize_mode, im, width, height):
    """Fit ``im`` to exactly ``width`` x ``height`` pixels.

    Args:
        resize_mode:
            - 0: stretch to the target size ("Just resize").
            - 1: scale so the image covers the target, center it, and crop the
              overflow ("Crop and resize").
            - anything else: scale so the image fits inside the target, center
              it, and fill the leftover bands by stretching the image's outermost
              rows/columns ("Resize and fill").
        im: Source PIL image.
        width, height: Target size in pixels.

    Returns:
        A PIL image of the target size. Modes 1 and "fill" build a new RGB image.
        In fill mode, an image that already has the target size is returned
        unchanged.
    """
    if resize_mode == 0:
        return im.resize((width, height), resample=LANCZOS)

    if resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    # "Resize and fill"
    if im.width == width and im.height == height:
        return im

    ratio = width / height
    src_ratio = im.width / im.height

    src_w = width if ratio < src_ratio else im.width * height // im.height
    src_h = height if ratio >= src_ratio else im.height * width // im.width

    resized = im.resize((src_w, src_h), resample=LANCZOS)
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    # The fill bands can be 0 px wide/tall, e.g. when the source already has the
    # target aspect ratio but a different size. PIL raises ValueError when resizing
    # to a zero dimension, so skip the bands in that case.
    if ratio < src_ratio:
        fill_height = height // 2 - src_h // 2
        if fill_height > 0:
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
                      box=(0, fill_height + src_h))
    else:
        fill_width = width // 2 - src_w // 2
        if fill_width > 0:
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
                      box=(fill_width + src_w, 0))

    return res


import PIL
from PIL import Image, ImageFont, ImageDraw 

def add_margin(pil_img, top, right, bottom, left, color):
    """Return a new image with ``pil_img`` padded on each side by the given pixel counts.

    The padding is filled with ``color``, and the result keeps ``pil_img``'s mode.
    """
    src_w, src_h = pil_img.size
    canvas_size = (left + src_w + right, top + src_h + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas

def text_wrap(text, font, max_width):
    """Greedily split ``text`` on spaces into lines no wider than ``max_width`` pixels.

    If the whole text fits, it is returned as a single line, unchanged. Otherwise
    each wrapped line keeps a trailing space after its last word. A single word
    wider than ``max_width`` is placed alone on its own line, with no trailing
    space.

    Args:
        text: Text to wrap.
        font: PIL font used to measure text width.
        max_width: Maximum line width in pixels.

    Returns:
        List of line strings.
    """

    def width_of(s):
        # FreeTypeFont.getsize() was removed in Pillow 10; prefer getlength()
        # and fall back to getsize() on older Pillow versions that lack it.
        if hasattr(font, "getlength"):
            return font.getlength(s)
        return font.getsize(s)[0]

    if width_of(text) <= max_width:
        return [text]

    words = text.split(' ')
    lines = []
    i = 0
    while i < len(words):
        line = ''
        while i < len(words) and width_of(line + words[i]) <= max_width:
            line += words[i] + " "
            i += 1
        if not line:
            # The next word alone is too wide: put it on its own line.
            line = words[i]
            i += 1
        lines.append(line)
    return lines

def caption(image, prompt, info,
            font_path="/notebooks/stable-diffusion/NotoSansJP-Bold.otf", font_size=20):
    """Return a copy of ``image`` with a white caption band added below it.

    The band contains ``prompt``, wrapped to the image width, followed by one
    extra line containing ``info``.

    Args:
        image: PIL image to caption.
        prompt: Caption text; wrapped with ``text_wrap``.
        info: Appended as its own line (formatted with an f-string).
        font_path: TrueType/OpenType font file to draw with.
        font_size: Font size passed to ``ImageFont.truetype``.

    Returns:
        A new PIL image with the same mode as ``image``.
    """
    height = image.size[1]

    font = ImageFont.truetype(font_path, font_size, encoding='utf-8')
    lines = text_wrap(prompt, font, image.size[0])
    lines.append(f"{info}")

    # Height of a line containing both an ascender and a descender. getsize() was
    # removed in Pillow 10; getbbox()[3] is the equivalent bottom extent.
    if hasattr(font, "getbbox"):
        line_height = font.getbbox('hg')[3]
    else:
        line_height = font.getsize('hg')[1]

    # One extra line of height is reserved as bottom padding.
    cap_img = add_margin(image, 0, 0, line_height * (len(lines) + 1), 0, (255, 255, 255))
    draw = ImageDraw.Draw(cap_img)
    pad = 2
    x = pad * 2
    y = height + pad
    for line in lines:
        draw.text((x, y), line, fill=(0, 0, 0), font=font)
        y += line_height
    return cap_img

def get_concat_h_blank(im1, im2, color=(255, 255, 255)):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is as tall as the taller image; any uncovered area keeps ``color``.
    """
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    canvas = Image.new('RGB', canvas_size, color)
    for img, origin in ((im1, (0, 0)), (im2, (im1.width, 0))):
        canvas.paste(img, origin)
    return canvas

def get_concat_v_blank(im1, im2, color=(255, 255, 255)):
    """Stack ``im2`` below ``im1`` on a new RGB canvas.

    The canvas is as wide as the wider image; any uncovered area keeps ``color``.
    """
    canvas_size = (max(im1.width, im2.width), im1.height + im2.height)
    canvas = Image.new('RGB', canvas_size, color)
    for img, top in ((im1, 0), (im2, im1.height)):
        canvas.paste(img, (0, top))
    return canvas

def image_grid(imgs, batch_size, n_rows: int):
    """Tile ``imgs`` row-major into one image on a black background.

    Args:
        imgs: Sequence of PIL images. Every cell is sized like ``imgs[0]``.
        batch_size: Number of rows used when ``n_rows == 0``.
        n_rows: Number of rows if positive; ``batch_size`` rows if 0; otherwise
            ``round(sqrt(len(imgs)))`` rows.

    Returns:
        The grid as a new RGB image. The column count is ``ceil(len(imgs) / rows)``.
    """
    count = len(imgs)
    if n_rows > 0:
        rows = n_rows
    elif n_rows == 0:
        rows = batch_size
    else:
        rows = round(math.sqrt(count))

    cols = math.ceil(count / rows)
    cell_w, cell_h = imgs[0].size

    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h), color='black')
    for idx, img in enumerate(imgs):
        row, col = divmod(idx, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid

class User_OSD1:
    """Settings for a single txt2img request.

    The constructor's argument names are mapped to the attribute names read by
    the generator (``samples`` -> ``n_samples``, ``steps`` -> ``ddim_steps``,
    ``scale`` -> ``cfg_scale``, ``rows`` -> ``n_rows``, ``iter`` -> ``n_iter``).
    """

    def __init__(self, prompt: str, seed: int, samples: int, steps: int, scale: float, height: int, width: int,
                 rows: int, iter: int, skip_grid: bool, skip_save: bool):
        # What to generate, from which seed, and how many images per batch.
        self.prompt, self.seed, self.n_samples = prompt, seed, samples
        # Sampler settings.
        self.ddim_steps, self.cfg_scale = steps, scale
        # Output size in pixels.
        self.height, self.width = height, width
        # Grid layout and number of sampling passes.
        self.n_rows, self.n_iter = rows, iter
        # Output switches.
        self.skip_grid, self.skip_save = skip_grid, skip_save
    
    
class User_OSD2:
    """Settings for a single img2img request.

    Like the txt2img settings, plus ``strength``, which is expected to lie in
    [0, 1] (the img2img generator asserts this).
    """

    def __init__(self, prompt: str, seed: int, samples: int, steps: int, scale: float, strength: float,
                 height: int, width: int, rows: int, iter: int, skip_grid: bool, skip_save: bool):
        # What to generate, from which seed, and how many images per batch.
        self.prompt, self.seed, self.n_samples = prompt, seed, samples
        # Sampler settings; strength controls how far the init image is re-noised.
        self.ddim_steps, self.cfg_scale, self.strength = steps, scale, strength
        # Output size in pixels.
        self.height, self.width = height, width
        # Grid layout and number of sampling passes.
        self.n_rows, self.n_iter = rows, iter
        # Output switches.
        self.skip_grid, self.skip_save = skip_grid, skip_save

config = "optimizedSD/v1-inference.yaml"
ckpt = "model.ckpt"
device = "cuda"

sd = load_model_from_config(ckpt)

# Rename UNet weights for the optimizedSD layout:
#   - "model.*" keys inside input_blocks / middle_block / time_embed become "model1.*";
#   - every other "model.*" key becomes "model2.*".
# This presumably matches the split UNet declared in optimizedSD/v1-inference.yaml.
# Keys are collected first and renamed afterwards because the dict can't change
# size while it is being iterated.
li, lo = [], []
for key in sd:
    parts = key.split('.')
    if parts[0] != 'model':
        continue
    if any(block in parts for block in ('input_blocks', 'middle_block', 'time_embed')):
        li.append(key)
    else:
        lo.append(key)

prefix_len = len('model.')
for key in li:
    sd['model1.' + key[prefix_len:]] = sd.pop(key)
for key in lo:
    sd['model2.' + key[prefix_len:]] = sd.pop(key)

config = OmegaConf.load(config)


def _load_submodel(section, state_dict, label):
    """Instantiate the model described by ``section``, load its weights, and set eval mode.

    ``strict=False`` is used because ``state_dict`` also holds the other sub-models'
    weights. Missing keys are reported rather than silently ignored. The load
    result is kept local instead of being assigned to ``_``: in IPython that
    would shadow the last-output variable for the rest of the session.
    """
    submodel = instantiate_from_config(section)
    missing, _unexpected = submodel.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning: {label}: {len(missing)} parameter(s) not found in {ckpt}")
    submodel.eval()
    return submodel


model = _load_submodel(config.modelUNet, sd, "modelUNet")

modelCS = _load_submodel(config.modelCondStage, sd, "modelCondStage")
modelCS.cond_stage_model.device = device

modelFS = _load_submodel(config.modelFirstStage, sd, "modelFirstStage")

# Runtime switches read by the optimizedSD UNet (semantics defined in that package).
model.unet_bs = True
model.cdevice = device
model.turbo = True

# The weights now live in the sub-models; drop the CPU copy of the checkpoint.
del sd

def txt2img_generate(user: User_OSD1, out_name: str,
                     out_dir: str = "/notebooks/stable-diffusion/Output"):
    """Generate images from ``user.prompt`` with the module-level optimized SD models.

    Runs ``user.n_iter`` sampling passes of ``user.n_samples`` images each.
    ``modelCS`` (prompt conditioning) and ``modelFS`` (decoding samples to pixels)
    are moved to the GPU only while they are needed and are moved back to the
    CPU afterwards.

    Args:
        user: Generation settings. A ``seed`` of -1 picks a random seed. This
            object is mutated: ``user.seed`` ends at the start seed plus the
            number of images generated.
        out_name: Filename prefix for saved images.
        out_dir: Directory that images are written to, unless ``user.skip_save``.

    Returns:
        ``(images, init_seed)``. ``images`` is the list of generated PIL images,
        with a grid image prepended unless ``user.skip_grid``. ``init_seed`` is
        the seed the run started from.
    """
    torch_gc()

    device = "cuda"
    C = 4  # channel count of the sampled tensor
    f = 8  # pixel size -> sample size downscale factor
    ddim_eta = 0.0
    start_code = None

    model.half()
    modelCS.half()

    batch_size = user.n_samples

    if user.seed == -1:
        user.seed = randint(0, 1000000)
    init_seed = user.seed
    seed_everything(user.seed)

    # Read the prompt and size from `user`, not from notebook globals, so the
    # arguments passed to txt2img() actually take effect.
    prompt = user.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    shape = [C, user.height // f, user.width // f]

    precision_scope = autocast

    all_samples = []
    with torch.no_grad():
        for _ in trange(user.n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if user.cfg_scale != 1.0:
                        # Empty-prompt conditioning used for classifier-free guidance.
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = modelCS.get_learned_conditioning(prompts)
                    modelCS.to("cpu")

                    samples_ddim = model.sample(S=user.ddim_steps,
                                                conditioning=c,
                                                batch_size=batch_size,
                                                seed=user.seed,
                                                shape=shape,
                                                verbose=False,
                                                unconditional_guidance_scale=user.cfg_scale,
                                                unconditional_conditioning=uc,
                                                eta=ddim_eta,
                                                x_T=start_code)

                    modelFS.to(device)
                    for i in range(batch_size):
                        decoded = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                        decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
                        pixels = 255. * rearrange(decoded[0].cpu().numpy(), 'c h w -> h w c')
                        del decoded

                        out = Image.fromarray(pixels.astype(np.uint8))
                        if not user.skip_save:
                            # Number files across the whole run (not per batch) so later
                            # iterations don't overwrite earlier files; with n_iter == 1
                            # this equals the per-batch index.
                            out.save(os.path.join(out_dir, f"{out_name}_{init_seed}[{len(all_samples)}].png"))

                        all_samples.append(out)
                        user.seed += 1

                    modelFS.to("cpu")
                    del samples_ddim

    # image_grid needs at least one image.
    if not user.skip_grid and all_samples:
        grid = image_grid(all_samples, batch_size, user.n_rows)
        all_samples.insert(0, grid)

    torch_gc()
    return all_samples, init_seed

def img2img_generate(user: User_OSD2, input_image, out_name: str,
                     out_dir: str = "/notebooks/stable-diffusion/Output"):
    """Re-generate ``input_image`` guided by ``user.prompt`` (image-to-image).

    The input is encoded with ``modelFS`` and then stochastically encoded to step
    ``t_enc = int(user.strength * user.ddim_steps)``. It is then decoded back
    with the prompt conditioning from ``modelCS``. Sub-models are moved to the
    GPU only while they are needed.

    Args:
        user: Generation settings. ``strength`` must lie in [0, 1]. A ``seed`` of
            -1 picks a random seed. This object is mutated: ``user.seed`` ends at
            the start seed plus the number of images generated.
        input_image: PIL image, resized by ``load_img`` to ``user.height`` x ``user.width``.
        out_name: Filename prefix for saved images.
        out_dir: Directory that images are written to, unless ``user.skip_save``.

    Returns:
        ``(images, init_seed)``. ``images`` is the list of generated PIL images,
        with a grid image prepended unless ``user.skip_grid``. ``init_seed`` is
        the seed the run started from.
    """
    torch_gc()
    device = "cuda"
    batch_size = user.n_samples
    model.small_batch = False

    # Validate before doing any expensive model work.
    assert 0. <= user.strength <= 1., 'can only work with strength in [0.0, 1.0]'

    init_image = load_img(input_image, user.height, user.width).to(device).half()

    model.half()
    modelCS.half()
    modelFS.half()

    if user.seed == -1:
        user.seed = randint(0, 1000000)
    init_seed = user.seed
    seed_everything(user.seed)

    # Read the prompt from `user`, not from the notebook global, so the argument
    # passed to img2img() actually takes effect.
    prompt = user.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    # Encode the init image, replicated across the batch, into latent space.
    modelFS.to(device)
    init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
    init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))
    modelFS.to("cpu")

    t_enc = int(user.strength * user.ddim_steps)
    print(f"target t_enc is {t_enc} steps")

    precision_scope = autocast

    all_samples = []
    with torch.no_grad():
        for _ in trange(user.n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if user.cfg_scale != 1.0:
                        # Empty-prompt conditioning used for classifier-free guidance.
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = modelCS.get_learned_conditioning(prompts)
                    modelCS.to("cpu")

                    # Encode the (scaled) latent to step t_enc, then decode it with the prompt.
                    z_enc = model.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device),
                                                    user.seed, ddim_steps=user.ddim_steps, ddim_eta=0.0)
                    samples_ddim = model.decode(z_enc, c, t_enc, unconditional_guidance_scale=user.cfg_scale,
                                                unconditional_conditioning=uc)

                    modelFS.to(device)
                    for i in range(batch_size):
                        decoded = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                        decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
                        pixels = 255. * rearrange(decoded[0].cpu().numpy(), 'c h w -> h w c')
                        del decoded

                        out = Image.fromarray(pixels.astype(np.uint8))
                        if not user.skip_save:
                            # Number files across the whole run (not per batch) so later
                            # iterations don't overwrite earlier files; with n_iter == 1
                            # this equals the per-batch index.
                            out.save(os.path.join(out_dir, f"{out_name}_{init_seed}[{len(all_samples)}].png"))

                        all_samples.append(out)
                        user.seed += 1

                    modelFS.to("cpu")
                    del samples_ddim, z_enc

    # image_grid needs at least one image.
    if not user.skip_grid and all_samples:
        grid = image_grid(all_samples, batch_size, user.n_rows)
        all_samples.insert(0, grid)

    torch_gc()
    return all_samples, init_seed

def txt2img(prompt, seed, samples, steps, scale, height, width, rows, iter, skip_grid, skip_save, out_name: str):
    """Build a User_OSD1 from plain arguments and run ``txt2img_generate``.

    ``rows`` is capped at ``samples`` first. Returns whatever ``txt2img_generate`` returns.
    """
    settings = User_OSD1(prompt, seed, samples, steps, scale, height, width,
                         min(rows, samples), iter, skip_grid, skip_save)
    return txt2img_generate(settings, out_name)

def img2img(prompt, seed, samples, steps, scale, strength, height, width, rows,
            iter, skip_grid, skip_save, mode, init_image, out_name):
    """Fit ``init_image`` to the target size, then run ``img2img_generate``.

    ``mode`` is one of "Just resize" (0) or "Crop and resize" (1); any other value
    selects "Resize and fill" (2). ``rows`` is capped at ``samples``.

    Returns:
        The tuple from ``img2img_generate`` with the numeric resize mode appended.
    """
    mode_codes = (("Just resize", 0), ("Crop and resize", 1))
    resize_mode = next((code for label, code in mode_codes if mode == label), 2)

    settings = User_OSD2(prompt, seed, samples, steps, scale, strength, height, width,
                         min(rows, samples), iter, skip_grid, skip_save)
    fitted = resize_image(resize_mode, init_image, width, height)

    return img2img_generate(settings, fitted, out_name) + (resize_mode,)

# Inference

### Text 2 Image



In [None]:
prompt = "a cute young girl"
samples = 2
# NOTE: `sampler` is not an argument of txt2img(), so this selection currently has no effect.
sampler = 'k_dpm_2' # ["k_euler_a","k-diffusion", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_heun"]

scale = 12 # min:1, max:30, step:0.5
steps = 120 # min:1, max:150, step:1

seed = -1

# Don't change these if you don't know what you're doing
width = 512
height = 512

skip_grid = True
rows = 2

skip_save = False

out_name = "out" + str(int(time.time()))

# ===================================================================================================================

images, seed_new = txt2img(prompt, seed, samples,
                            steps, scale, height, width,
                            rows, 1, skip_grid, skip_save,
                            out_name)

path = "/notebooks/stable-diffusion/Output/"

# When skip_save is False, txt2img already wrote each sample to `path`; this
# writes additional copies (including the grid image, if one was made).
save_all = True

# Use `out_name` from this cell. The previous `name` is only defined in the
# Image 2 Image cell, so this raised NameError on a fresh kernel.
if save_all:
    for k, img in enumerate(images):
        img.save(f'{path}{out_name}_{k}.png')
else:
    index = 1
    images[index].save(f'{path}{out_name}_{index}.png')

### Image 2 Image



In [None]:
prompt = "" #@param {type:"string"}
# NOTE: `sampler` is not an argument of img2img(), so this selection currently has no effect.
sampler = 'k_dpm_2' #@param ["k_euler_a","k-diffusion", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_heun"] {allow-input: false}
init_image_path = "/notebooks/stable-diffusion/Source/794_1000.jpg" #@param {type: 'string'}

resize_mode = "Resize and fill" #@param ["Just resize", "Crop and resize", "Resize and fill"] {allow-input: false}


width = 512 #@param {type:"integer"}
height = 512 #@param {type:"integer"}

scale = 7.5 #@param {type:"slider", min:1, max:30, step:0.5}
steps = 64 #@param {type:"slider", min:1, max:150, step:1}
strength = 0.7 #@param {type: "slider", min:0.00, max:1.00, step:0.01}

samples = 2 #@param {type:'integer'}
skip_grid = True #@param {type:"boolean"}
rows = 2 #@param {type:'integer'}
# Defined here so this cell doesn't depend on the Text 2 Image cell having run.
skip_save = False #@param {type:"boolean"}

seed = -1 #@param {type:'integer'}

init_image = Image.open(init_image_path)

path = "/notebooks/stable-diffusion/Output/"
# Defined before the img2img() call, which uses it as the output filename prefix.
name = "out" + str(int(time.time()))

# Keyword arguments matching img2img()'s signature. The previous call passed 13
# positional arguments in a different order to this 15-parameter function and
# raised TypeError.
images, seed_new, mode = img2img(prompt=prompt, seed=seed, samples=samples, steps=steps,
                                 scale=scale, strength=strength, height=height, width=width,
                                 rows=rows, iter=1, skip_grid=skip_grid, skip_save=skip_save,
                                 mode=resize_mode, init_image=init_image, out_name=name)

save_all = True #@param {type:"boolean"}

if save_all:
    for k, img in enumerate(images):
        img.save(f'{path}{name}_{k}.png')
else:
    index = 1 #@param {type:"integer"}
    images[index].save(f'{path}{name}_{index}.png')

In [None]:
import os
os.kill(os.getpid(), 9) # Kill this kernel process (signal 9) to release all GPU memory after out-of-memory or other unrecoverable errors; afterwards, re-run from the setup cells

# Saving

In [None]:
!rm output.zip
!zip -r ./output.zip ./Output/*.png
from google.colab import files
files.download("./output.zip")
!rm ./Output/*