Add files via upload
This commit is contained in:
parent
982f87035d
commit
5a54e79c5f
|
@ -0,0 +1,820 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "jwZ0GT0eObBW",
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Install"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "hTjD9Ij7Nuh4",
|
||||||
|
"scrolled": true,
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!git clone https://github.com/FuouM/stable-diffusion-hidamari stable-diffusion\n",
|
||||||
|
"%cd stable-diffusion\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"\n",
|
||||||
|
"!pip install albumentations==0.4.3\n",
|
||||||
|
"!pip install opencv-python==4.1.2.30\n",
|
||||||
|
"!pip install pudb==2019.2\n",
|
||||||
|
"!pip install imageio==2.9.0\n",
|
||||||
|
"!pip install imageio-ffmpeg==0.4.2\n",
|
||||||
|
"#!pip install pytorch-lightning==1.4.2\n",
|
||||||
|
"!pip install pytorch-lightning \n",
|
||||||
|
"!pip install omegaconf==2.1.1\n",
|
||||||
|
"!pip install test-tube>=0.7.5\n",
|
||||||
|
"!pip install streamlit>=0.73.1\n",
|
||||||
|
"!pip install einops==0.3.0\n",
|
||||||
|
"!pip install torch-fidelity==0.3.0\n",
|
||||||
|
"# !pip install pilmoji\n",
|
||||||
|
"\n",
|
||||||
|
"!pip install transformers==4.19.2\n",
|
||||||
|
"\n",
|
||||||
|
"!mkdir -p '/notebooks/stable-diffusion/Source'\n",
|
||||||
|
"!mkdir -p '/notebooks/stable-diffusion/Output'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true,
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!mkdir -p /notebooks/stable-diffusion/src/\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/\n",
|
||||||
|
"!git clone https://github.com/CompVis/taming-transformers.git\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/taming-transformers\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"!pip install -e .\n",
|
||||||
|
"import taming # for some reason these new packages have to be imported here and not later on or else python fails to find them\n",
|
||||||
|
"\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/\n",
|
||||||
|
"!git clone https://github.com/openai/CLIP.git\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/CLIP\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"!pip install -e .\n",
|
||||||
|
"import clip\n",
|
||||||
|
"\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/\n",
|
||||||
|
"!git clone https://github.com/crowsonkb/k-diffusion.git\n",
|
||||||
|
"%cd /notebooks/stable-diffusion/src/k-diffusion\n",
|
||||||
|
"!git pull\n",
|
||||||
|
"!pip install .\n",
|
||||||
|
"!pip install kornia\n",
|
||||||
|
"import kornia"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "XvHTXI7KOnu4",
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Download the model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "3BQoPx_8Hj08"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!wget https://storage.googleapis.com/ws-store2/wd-v1-2-full-ema.ckpt -O /notebooks/stable-diffusion/model.ckpt"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "CszrKJDe-66T",
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Optimized SD + K-diffusion (Updated as of 8/28)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "0PXSWiHjOROA",
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Prepare"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "L70j4gz6_Aq1",
|
||||||
|
"scrolled": true,
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%cd /notebooks/stable-diffusion\n",
|
||||||
|
"\n",
|
||||||
|
"import argparse, os, sys, glob, random\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from random import randint\n",
|
||||||
|
"import math\n",
|
||||||
|
"\n",
|
||||||
|
"import time\n",
|
||||||
|
"\n",
|
||||||
|
"from omegaconf import OmegaConf\n",
|
||||||
|
"from PIL import Image\n",
|
||||||
|
"from tqdm import tqdm, trange\n",
|
||||||
|
"from itertools import islice\n",
|
||||||
|
"\n",
|
||||||
|
"from einops import rearrange, repeat\n",
|
||||||
|
"import time\n",
|
||||||
|
"from pytorch_lightning import seed_everything\n",
|
||||||
|
"from torch import autocast\n",
|
||||||
|
"from contextlib import contextmanager, nullcontext\n",
|
||||||
|
"from ldm.util import instantiate_from_config\n",
|
||||||
|
"\n",
|
||||||
|
"def chunk(it, size):\n",
|
||||||
|
" it = iter(it)\n",
|
||||||
|
" return iter(lambda: tuple(islice(it, size)), ())\n",
|
||||||
|
"\n",
|
||||||
|
"def load_model_from_config(ckpt, verbose=False):\n",
|
||||||
|
" print(f\"Loading model from {ckpt}\")\n",
|
||||||
|
" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
|
||||||
|
" if \"global_step\" in pl_sd:\n",
|
||||||
|
" print(f\"Global Step: {pl_sd['global_step']}\")\n",
|
||||||
|
" sd = pl_sd[\"state_dict\"]\n",
|
||||||
|
" return sd\n",
|
||||||
|
"\n",
|
||||||
|
"def torch_gc():\n",
|
||||||
|
" torch.cuda.empty_cache()\n",
|
||||||
|
" torch.cuda.ipc_collect()\n",
|
||||||
|
" \n",
|
||||||
|
"def load_img(init_image, h0, w0):\n",
|
||||||
|
" \n",
|
||||||
|
" image = init_image.convert(\"RGB\")\n",
|
||||||
|
" w, h = image.size\n",
|
||||||
|
"\n",
|
||||||
|
" # print(f\"loaded input image of size ({w}, {h}) from {path}\") \n",
|
||||||
|
" if(h0 is not None and w0 is not None):\n",
|
||||||
|
" h, w = h0, w0\n",
|
||||||
|
" \n",
|
||||||
|
" w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"New image size ({w}, {h})\")\n",
|
||||||
|
" image = image.resize((w, h), resample = Image.LANCZOS)\n",
|
||||||
|
" image = np.array(image).astype(np.float32) / 255.0\n",
|
||||||
|
" image = image[None].transpose(0, 3, 1, 2)\n",
|
||||||
|
" image = torch.from_numpy(image)\n",
|
||||||
|
" return 2.*image - 1.\n",
|
||||||
|
"\n",
|
||||||
|
"LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)\n",
|
||||||
|
"invalid_filename_chars = '<>:\"/\\|?*\\n'\n",
|
||||||
|
"\n",
|
||||||
|
"def resize_image(resize_mode, im, width, height):\n",
|
||||||
|
" if resize_mode == 0:\n",
|
||||||
|
" res = im.resize((width, height), resample=LANCZOS)\n",
|
||||||
|
" elif resize_mode == 1:\n",
|
||||||
|
" ratio = width / height\n",
|
||||||
|
" src_ratio = im.width / im.height\n",
|
||||||
|
"\n",
|
||||||
|
" src_w = width if ratio > src_ratio else im.width * height // im.height\n",
|
||||||
|
" src_h = height if ratio <= src_ratio else im.height * width // im.width\n",
|
||||||
|
"\n",
|
||||||
|
" resized = im.resize((src_w, src_h), resample=LANCZOS)\n",
|
||||||
|
" res = Image.new(\"RGB\", (width, height))\n",
|
||||||
|
" res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n",
|
||||||
|
" else:\n",
|
||||||
|
" if im.width != width or im.height != height:\n",
|
||||||
|
" ratio = width / height\n",
|
||||||
|
" src_ratio = im.width / im.height\n",
|
||||||
|
"\n",
|
||||||
|
" src_w = width if ratio < src_ratio else im.width * height // im.height\n",
|
||||||
|
" src_h = height if ratio >= src_ratio else im.height * width // im.width\n",
|
||||||
|
"\n",
|
||||||
|
" resized = im.resize((src_w, src_h), resample=LANCZOS)\n",
|
||||||
|
" res = Image.new(\"RGB\", (width, height))\n",
|
||||||
|
" res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\n",
|
||||||
|
"\n",
|
||||||
|
" if ratio < src_ratio:\n",
|
||||||
|
" fill_height = height // 2 - src_h // 2\n",
|
||||||
|
" res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))\n",
|
||||||
|
" res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))\n",
|
||||||
|
" else:\n",
|
||||||
|
" fill_width = width // 2 - src_w // 2\n",
|
||||||
|
" res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))\n",
|
||||||
|
" res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))\n",
|
||||||
|
" else:\n",
|
||||||
|
" return im\n",
|
||||||
|
"\n",
|
||||||
|
" return res\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"import PIL\n",
|
||||||
|
"from PIL import Image, ImageFont, ImageDraw \n",
|
||||||
|
"\n",
|
||||||
|
"def add_margin(pil_img, top, right, bottom, left, color):\n",
|
||||||
|
" width, height = pil_img.size\n",
|
||||||
|
" new_width = width + right + left\n",
|
||||||
|
" new_height = height + top + bottom\n",
|
||||||
|
" result = Image.new(pil_img.mode, (new_width, new_height), color)\n",
|
||||||
|
" result.paste(pil_img, (left, top))\n",
|
||||||
|
" return result\n",
|
||||||
|
"\n",
|
||||||
|
"def text_wrap(text, font, max_width):\n",
|
||||||
|
" lines = []\n",
|
||||||
|
" if font.getsize(text)[0] <= max_width:\n",
|
||||||
|
" lines.append(text)\n",
|
||||||
|
" else:\n",
|
||||||
|
" words = text.split(' ')\n",
|
||||||
|
" i = 0\n",
|
||||||
|
" while i < len(words):\n",
|
||||||
|
" line = ''\n",
|
||||||
|
" while i < len(words) and font.getsize(line + words[i])[0] <= max_width:\n",
|
||||||
|
" line = line + words[i]+ \" \"\n",
|
||||||
|
" i += 1\n",
|
||||||
|
" if not line:\n",
|
||||||
|
" line = words[i]\n",
|
||||||
|
" i += 1\n",
|
||||||
|
" lines.append(line)\n",
|
||||||
|
" return lines\n",
|
||||||
|
"\n",
|
||||||
|
"def caption(image, prompt, info):\n",
|
||||||
|
" width, height = image.size\n",
|
||||||
|
"\n",
|
||||||
|
" font = ImageFont.truetype(\"/notebooks/stable-diffusion/NotoSansJP-Bold.otf\", 20, encoding='utf-8')\n",
|
||||||
|
" lines = text_wrap(prompt, font, image.size[0])\n",
|
||||||
|
" lines.append(f\"{info}\")\n",
|
||||||
|
" line_height = font.getsize('hg')[1]\n",
|
||||||
|
" cap_img = add_margin(image, 0, 0, line_height * (len(lines) + 1), 0, (255, 255, 255))\n",
|
||||||
|
" draw = ImageDraw.Draw(cap_img)\n",
|
||||||
|
" pad = 2\n",
|
||||||
|
" x = pad * 2\n",
|
||||||
|
" y = height + pad\n",
|
||||||
|
" for line in lines:\n",
|
||||||
|
" draw.text((x,y), line, fill=(0, 0, 0), font=font)\n",
|
||||||
|
" y = y + line_height\n",
|
||||||
|
" return cap_img\n",
|
||||||
|
"\n",
|
||||||
|
"def get_concat_h_blank(im1, im2, color=(255, 255, 255)):\n",
|
||||||
|
" dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)), color)\n",
|
||||||
|
" dst.paste(im1, (0, 0))\n",
|
||||||
|
" dst.paste(im2, (im1.width, 0))\n",
|
||||||
|
" return dst\n",
|
||||||
|
"\n",
|
||||||
|
"def get_concat_v_blank(im1, im2, color=(255, 255, 255)):\n",
|
||||||
|
" dst = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height), color)\n",
|
||||||
|
" dst.paste(im1, (0, 0))\n",
|
||||||
|
" dst.paste(im2, (0, im1.height))\n",
|
||||||
|
" return dst\n",
|
||||||
|
"\n",
|
||||||
|
"def image_grid(imgs, batch_size, n_rows:int):\n",
|
||||||
|
" if n_rows > 0:\n",
|
||||||
|
" rows = n_rows\n",
|
||||||
|
" elif n_rows == 0:\n",
|
||||||
|
" rows = batch_size\n",
|
||||||
|
" else:\n",
|
||||||
|
" rows = math.sqrt(len(imgs))\n",
|
||||||
|
" rows = round(rows)\n",
|
||||||
|
"\n",
|
||||||
|
" cols = math.ceil(len(imgs) / rows)\n",
|
||||||
|
"\n",
|
||||||
|
" w, h = imgs[0].size\n",
|
||||||
|
" grid = Image.new('RGB', size=(cols * w, rows * h), color='black')\n",
|
||||||
|
"\n",
|
||||||
|
" for i, img in enumerate(imgs):\n",
|
||||||
|
" grid.paste(img, box=(i % cols * w, i // cols * h))\n",
|
||||||
|
"\n",
|
||||||
|
" return grid\n",
|
||||||
|
"\n",
|
||||||
|
"class User_OSD1:\n",
|
||||||
|
" def __init__(self, prompt: str, seed: int, samples: int, steps: int, scale: float, height:int, width: int,\n",
|
||||||
|
" rows: int, iter: int, skip_grid: bool, skip_save: bool):\n",
|
||||||
|
" self.prompt = prompt\n",
|
||||||
|
" self.seed = seed\n",
|
||||||
|
" self.n_samples = samples\n",
|
||||||
|
"\n",
|
||||||
|
" self.ddim_steps = steps\n",
|
||||||
|
" self.cfg_scale = scale\n",
|
||||||
|
" \n",
|
||||||
|
" self.height = height\n",
|
||||||
|
" self.width = width\n",
|
||||||
|
"\n",
|
||||||
|
" self.n_rows = rows\n",
|
||||||
|
"\n",
|
||||||
|
" self.n_iter = iter\n",
|
||||||
|
"\n",
|
||||||
|
" self.skip_grid = skip_grid\n",
|
||||||
|
" self.skip_save = skip_save\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"class User_OSD2:\n",
|
||||||
|
" def __init__(self, prompt: str, seed: int, samples: int, steps: int, scale: float, strength: float,\n",
|
||||||
|
" height:int, width: int, rows: int, iter: int, skip_grid: bool, skip_save: bool):\n",
|
||||||
|
" self.prompt = prompt\n",
|
||||||
|
" self.seed = seed\n",
|
||||||
|
"\n",
|
||||||
|
" self.n_samples = samples\n",
|
||||||
|
"\n",
|
||||||
|
" self.ddim_steps = steps\n",
|
||||||
|
" self.cfg_scale = scale\n",
|
||||||
|
" self.strength = strength\n",
|
||||||
|
"\n",
|
||||||
|
" self.height = height\n",
|
||||||
|
" self.width = width\n",
|
||||||
|
"\n",
|
||||||
|
" self.n_rows = rows\n",
|
||||||
|
" self.n_iter = iter\n",
|
||||||
|
"\n",
|
||||||
|
" self.skip_grid = skip_grid\n",
|
||||||
|
" self.skip_save = skip_save\n",
|
||||||
|
"\n",
|
||||||
|
"config = \"optimizedSD/v1-inference.yaml\"\n",
|
||||||
|
"ckpt = f\"model.ckpt\"\n",
|
||||||
|
"device = \"cuda\"\n",
|
||||||
|
"\n",
|
||||||
|
"sd = load_model_from_config(f\"{ckpt}\")\n",
|
||||||
|
"li, lo = [], []\n",
|
||||||
|
"\n",
|
||||||
|
"for key, value in sd.items():\n",
|
||||||
|
" sp = key.split('.')\n",
|
||||||
|
" if(sp[0]) == 'model':\n",
|
||||||
|
" if('input_blocks' in sp):\n",
|
||||||
|
" li.append(key)\n",
|
||||||
|
" elif('middle_block' in sp):\n",
|
||||||
|
" li.append(key)\n",
|
||||||
|
" elif('time_embed' in sp):\n",
|
||||||
|
" li.append(key)\n",
|
||||||
|
" else:\n",
|
||||||
|
" lo.append(key)\n",
|
||||||
|
" \n",
|
||||||
|
"for key in li:\n",
|
||||||
|
" sd['model1.' + key[6:]] = sd.pop(key)\n",
|
||||||
|
"for key in lo:\n",
|
||||||
|
" sd['model2.' + key[6:]] = sd.pop(key)\n",
|
||||||
|
"\n",
|
||||||
|
"config = OmegaConf.load(f\"{config}\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"model = instantiate_from_config(config.modelUNet)\n",
|
||||||
|
"_, _ = model.load_state_dict(sd, strict=False)\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
"modelCS = instantiate_from_config(config.modelCondStage)\n",
|
||||||
|
"_, _ = modelCS.load_state_dict(sd, strict=False)\n",
|
||||||
|
"modelCS.cond_stage_model.device = device\n",
|
||||||
|
"modelCS.eval()\n",
|
||||||
|
" \n",
|
||||||
|
"modelFS = instantiate_from_config(config.modelFirstStage)\n",
|
||||||
|
"_, _ = modelFS.load_state_dict(sd, strict=False)\n",
|
||||||
|
"modelFS.eval()\n",
|
||||||
|
"\n",
|
||||||
|
"model.unet_bs = True\n",
|
||||||
|
"model.cdevice = device\n",
|
||||||
|
"model.turbo = True\n",
|
||||||
|
"\n",
|
||||||
|
"del sd\n",
|
||||||
|
"\n",
|
||||||
|
"def txt2img_generate(user: User_OSD1, out_name: str):\n",
|
||||||
|
" torch_gc()\n",
|
||||||
|
" \n",
|
||||||
|
" device = \"cuda\"\n",
|
||||||
|
" C = 4\n",
|
||||||
|
" f = 8\n",
|
||||||
|
" ddim_eta = 0.0\n",
|
||||||
|
" start_code = None\n",
|
||||||
|
"\n",
|
||||||
|
" model.half()\n",
|
||||||
|
" modelCS.half()\n",
|
||||||
|
"\n",
|
||||||
|
" batch_size = user.n_samples\n",
|
||||||
|
" \n",
|
||||||
|
" if user.seed == -1:\n",
|
||||||
|
" user.seed = randint(0, 1000000)\n",
|
||||||
|
"\n",
|
||||||
|
" init_seed = user.seed\n",
|
||||||
|
"\n",
|
||||||
|
" seed_everything(user.seed)\n",
|
||||||
|
"\n",
|
||||||
|
" assert prompt is not None\n",
|
||||||
|
" data = [batch_size * [prompt]]\n",
|
||||||
|
"\n",
|
||||||
|
" precision_scope = autocast\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
"\n",
|
||||||
|
" all_samples = list()\n",
|
||||||
|
" for _ in trange(user.n_iter, desc=\"Sampling\"):\n",
|
||||||
|
" for prompts in tqdm(data, desc=\"data\"):\n",
|
||||||
|
" with precision_scope(\"cuda\"):\n",
|
||||||
|
" modelCS.to(device)\n",
|
||||||
|
" uc = None\n",
|
||||||
|
" if user.cfg_scale != 1.0:\n",
|
||||||
|
" uc = modelCS.get_learned_conditioning(batch_size * [\"\"])\n",
|
||||||
|
" if isinstance(prompts, tuple):\n",
|
||||||
|
" prompts = list(prompts)\n",
|
||||||
|
" \n",
|
||||||
|
" c = modelCS.get_learned_conditioning(prompts) \n",
|
||||||
|
"\n",
|
||||||
|
" shape = [C, height // f, width // f]\n",
|
||||||
|
" modelCS.to(\"cpu\") \n",
|
||||||
|
"\n",
|
||||||
|
" samples_ddim = model.sample(S=user.ddim_steps,\n",
|
||||||
|
" conditioning=c,\n",
|
||||||
|
" batch_size=batch_size,\n",
|
||||||
|
" seed = user.seed,\n",
|
||||||
|
" shape=shape,\n",
|
||||||
|
" verbose=False,\n",
|
||||||
|
" unconditional_guidance_scale=user.cfg_scale,\n",
|
||||||
|
" unconditional_conditioning=uc,\n",
|
||||||
|
" eta=ddim_eta,\n",
|
||||||
|
" x_T=start_code)\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
" for i in range(batch_size):\n",
|
||||||
|
" \n",
|
||||||
|
" x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))\n",
|
||||||
|
" x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
|
||||||
|
" x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')\n",
|
||||||
|
"\n",
|
||||||
|
" out = Image.fromarray(x_sample.astype(np.uint8))\n",
|
||||||
|
" if not user.skip_save:\n",
|
||||||
|
" out.save(f\"/notebooks/stable-diffusion/Output/{out_name}_{init_seed}[{i}].png\")\n",
|
||||||
|
"\n",
|
||||||
|
" all_samples.append(out)\n",
|
||||||
|
" user.seed+=1\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(\"cpu\")\n",
|
||||||
|
"\n",
|
||||||
|
" del samples_ddim\n",
|
||||||
|
" del x_sample\n",
|
||||||
|
" del x_samples_ddim\n",
|
||||||
|
"\n",
|
||||||
|
" if not user.skip_grid:\n",
|
||||||
|
" grid = image_grid(all_samples, batch_size, user.n_rows)\n",
|
||||||
|
" all_samples.insert(0, grid)\n",
|
||||||
|
"\n",
|
||||||
|
" torch_gc()\n",
|
||||||
|
" return all_samples, init_seed\n",
|
||||||
|
"\n",
|
||||||
|
"def img2img_generate(user: User_OSD2, input_image, out_name: str):\n",
|
||||||
|
" torch_gc()\n",
|
||||||
|
" device = \"cuda\"\n",
|
||||||
|
" batch_size = user.n_samples\n",
|
||||||
|
" model.small_batch = False\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
" init_image = load_img(input_image, user.height, user.width).to(device).half()\n",
|
||||||
|
"\n",
|
||||||
|
" model.half()\n",
|
||||||
|
" modelCS.half()\n",
|
||||||
|
" modelFS.half()\n",
|
||||||
|
" \n",
|
||||||
|
" if user.seed == -1:\n",
|
||||||
|
" user.seed = randint(0, 1000000)\n",
|
||||||
|
"\n",
|
||||||
|
" init_seed = user.seed\n",
|
||||||
|
"\n",
|
||||||
|
" seed_everything(user.seed)\n",
|
||||||
|
"\n",
|
||||||
|
" assert prompt is not None\n",
|
||||||
|
" data = [batch_size * [prompt]]\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(device)\n",
|
||||||
|
"\n",
|
||||||
|
" init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n",
|
||||||
|
" init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(\"cpu\")\n",
|
||||||
|
"\n",
|
||||||
|
" assert 0. <= user.strength <= 1., 'can only work with strength in [0.0, 1.0]'\n",
|
||||||
|
" t_enc = int(user.strength * user.ddim_steps)\n",
|
||||||
|
" print(f\"target t_enc is {t_enc} steps\")\n",
|
||||||
|
"\n",
|
||||||
|
" precision_scope = autocast\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" all_samples = list()\n",
|
||||||
|
" for _ in trange(user.n_iter, desc=\"Sampling\"):\n",
|
||||||
|
" for prompts in tqdm(data, desc=\"data\"):\n",
|
||||||
|
" with precision_scope(\"cuda\"):\n",
|
||||||
|
" modelCS.to(device)\n",
|
||||||
|
" uc = None\n",
|
||||||
|
" if user.cfg_scale != 1.0:\n",
|
||||||
|
" uc = modelCS.get_learned_conditioning(batch_size * [\"\"])\n",
|
||||||
|
" if isinstance(prompts, tuple):\n",
|
||||||
|
" prompts = list(prompts)\n",
|
||||||
|
" \n",
|
||||||
|
" c = modelCS.get_learned_conditioning(prompts)\n",
|
||||||
|
"\n",
|
||||||
|
" modelCS.to(\"cpu\")\n",
|
||||||
|
"\n",
|
||||||
|
" # encode (scaled latent)\n",
|
||||||
|
" z_enc = model.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device), user.seed,ddim_steps=user.ddim_steps, ddim_eta=0.0)\n",
|
||||||
|
" # decode it\n",
|
||||||
|
" samples_ddim = model.decode(z_enc, c, t_enc, unconditional_guidance_scale=user.cfg_scale,\n",
|
||||||
|
" unconditional_conditioning=uc,)\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(device)\n",
|
||||||
|
" # print(\"saving images\")\n",
|
||||||
|
" for i in range(batch_size):\n",
|
||||||
|
" \n",
|
||||||
|
" x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))\n",
|
||||||
|
" x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
|
||||||
|
" x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')\n",
|
||||||
|
"\n",
|
||||||
|
" # all_samples.append(x_sample.to(\"cpu\"))\n",
|
||||||
|
" # all_samples.append(Image.fromarray(x_sample.astype(np.uint8)))\n",
|
||||||
|
"\n",
|
||||||
|
" out = Image.fromarray(x_sample.astype(np.uint8))\n",
|
||||||
|
" if not user.skip_save:\n",
|
||||||
|
" out.save(f\"/notebooks/stable-diffusion/Output/{out_name}_{init_seed}[{i}].png\")\n",
|
||||||
|
" all_samples.append(out)\n",
|
||||||
|
"\n",
|
||||||
|
" user.seed+=1\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" modelFS.to(\"cpu\")\n",
|
||||||
|
"\n",
|
||||||
|
" del samples_ddim\n",
|
||||||
|
" del x_sample\n",
|
||||||
|
" del x_samples_ddim\n",
|
||||||
|
"\n",
|
||||||
|
" if not user.skip_grid:\n",
|
||||||
|
" grid = image_grid(all_samples, batch_size, user.n_rows)\n",
|
||||||
|
" all_samples.insert(0, grid)\n",
|
||||||
|
" torch_gc()\n",
|
||||||
|
" return all_samples, init_seed\n",
|
||||||
|
"\n",
|
||||||
|
"def txt2img(prompt, seed, samples, steps, scale, height, width, rows, iter, skip_grid, skip_save, out_name: str):\n",
|
||||||
|
" if(rows > samples):\n",
|
||||||
|
" rows = samples\n",
|
||||||
|
" user = User_OSD1(prompt, seed, samples, steps, scale, height, width, rows, iter, skip_grid, skip_save)\n",
|
||||||
|
" return txt2img_generate(user, out_name)\n",
|
||||||
|
"\n",
|
||||||
|
"def img2img(prompt, seed, samples, steps, scale, strength, height, width, rows,\n",
|
||||||
|
" iter, skip_grid, skip_save, mode, init_image, out_name):\n",
|
||||||
|
" if mode == \"Just resize\":\n",
|
||||||
|
" resize_mode = 0\n",
|
||||||
|
" elif mode == \"Crop and resize\":\n",
|
||||||
|
" resize_mode = 1\n",
|
||||||
|
" else:\n",
|
||||||
|
" resize_mode = 2\n",
|
||||||
|
" if(rows > samples):\n",
|
||||||
|
" rows = samples\n",
|
||||||
|
" user = User_OSD2(prompt, seed, samples, steps, scale, strength, height, width, rows, iter, skip_grid, skip_save)\n",
|
||||||
|
" init_image = resize_image(resize_mode, init_image, width, height)\n",
|
||||||
|
"\n",
|
||||||
|
" return img2img_generate(user, init_image, out_name) + (resize_mode,)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "7v-FQVQ9OYtk",
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Inference"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "9lnxMQBZ_asY"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Text 2 Image\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"cellView": "form",
|
||||||
|
"id": "ImGOidHaAOZn"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"prompt = \"a cute young girl\"\n",
|
||||||
|
"samples = 2\n",
|
||||||
|
"sampler = 'k_dpm_2' # [\"k_euler_a\",\"k-diffusion\", \"k_dpm_2\", \"k_dpm_2_a\", \"k_euler\", \"k_heun\"]\n",
|
||||||
|
"\n",
|
||||||
|
"scale = 12 # min:1, max:30, step:0.5\n",
|
||||||
|
"steps = 120 # min:1, max:150, step:1\n",
|
||||||
|
"\n",
|
||||||
|
"seed = -1\n",
|
||||||
|
"\n",
|
||||||
|
"# Don't change these if you don't know what you're doing\n",
|
||||||
|
"width = 512\n",
|
||||||
|
"height = 512\n",
|
||||||
|
"\n",
|
||||||
|
"skip_grid = True \n",
|
||||||
|
"rows = 2\n",
|
||||||
|
"\n",
|
||||||
|
"skip_save = False\n",
|
||||||
|
"\n",
|
||||||
|
"out_name = \"out\" + str(int(time.time()))\n",
|
||||||
|
"\n",
|
||||||
|
"# ===================================================================================================================\n",
|
||||||
|
"\n",
|
||||||
|
"images, seed_new = txt2img(prompt, seed, samples,\n",
|
||||||
|
" steps, scale, height, width,\n",
|
||||||
|
" rows, 1, skip_grid, skip_save,\n",
|
||||||
|
" out_name)\n",
|
||||||
|
"\n",
|
||||||
|
"path = \"/notebooks/stable-diffusion/Output/\"\n",
|
||||||
|
"\n",
|
||||||
|
"save_all = True\n",
|
||||||
|
"\n",
|
||||||
|
"if save_all:\n",
|
||||||
|
" k = 0\n",
|
||||||
|
" for i in images:\n",
|
||||||
|
" i.save(f'{path}{name}_{k}.png')\n",
|
||||||
|
" k += 1\n",
|
||||||
|
"else:\n",
|
||||||
|
" index = 1\n",
|
||||||
|
" images[index].save(f'{path}{name}_{index}.png')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "L7jnzh7O_vQT",
|
||||||
|
"jp-MarkdownHeadingCollapsed": true,
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"### Image 2 Image\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"cellView": "form",
|
||||||
|
"id": "xOcjG58T_wzp"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"prompt = \"\" #@param {type:\"string\"}\n",
|
||||||
|
"sampler = 'k_dpm_2' #@param [\"k_euler_a\",\"k-diffusion\", \"k_dpm_2\", \"k_dpm_2_a\", \"k_euler\", \"k_heun\"] {allow-input: false}\n",
|
||||||
|
"init_image_path = \"/notebooks/stable-diffusion/Source/794_1000.jpg\" #@param {type: 'string'}\n",
|
||||||
|
"\n",
|
||||||
|
"resize_mode = \"Resize and fill\" #@param [\"Just resize\", \"Crop and resize\", \"Resize and fill\"] {allow-input: false}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"width = 512 #@param {type:\"integer\"}\n",
|
||||||
|
"height = 512 #@param {type:\"integer\"}\n",
|
||||||
|
"\n",
|
||||||
|
"scale = 7.5 #@param {type:\"slider\", min:1, max:30, step:0.5}\n",
|
||||||
|
"steps = 64 #@param {type:\"slider\", min:1, max:150, step:1}\n",
|
||||||
|
"strength = 0.7 #@param {type: \"slider\", min:0.00, max:1.00, step:0.01}\n",
|
||||||
|
"\n",
|
||||||
|
"samples = 2 #@param {type:'integer'}\n",
|
||||||
|
"skip_grid = True #@param {type:\"boolean\"}\n",
|
||||||
|
"rows = 2 #@param {type:'integer'}\n",
|
||||||
|
"\n",
|
||||||
|
"seed = -1 #@param {type:'integer'}\n",
|
||||||
|
"\n",
|
||||||
|
"init_image = Image.open(init_image_path)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"images, seed_new, mode = img2img(prompt, init_image_path, seed, sampler, steps, scale, strength, samples, rows, height, width, skip_grid, resize_mode)\n",
|
||||||
|
"\n",
|
||||||
|
"path = \"/notebooks/stable-diffusion/Output/\"\n",
|
||||||
|
"name = \"out\" + str(int(time.time()))\n",
|
||||||
|
"\n",
|
||||||
|
"save_all = True #@param {type:\"boolean\"}\n",
|
||||||
|
"\n",
|
||||||
|
"if save_all:\n",
|
||||||
|
" k = 0\n",
|
||||||
|
" for i in images:\n",
|
||||||
|
" i.save(f'{path}{name}_{k}.png')\n",
|
||||||
|
" k += 1\n",
|
||||||
|
"else:\n",
|
||||||
|
" index = 1 #@param {type:\"integer\"}\n",
|
||||||
|
" images[index].save(f'{path}{name}_{index}.png')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "0QJG1W0-fXlI"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"os.kill(os.getpid(), 9) # Crash colab if runs out of gpu memory / Funny errors (Run from Set up again)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "JAlvbY3mxLuc",
|
||||||
|
"jp-MarkdownHeadingCollapsed": true,
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Saving"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "8m_6DufvxMkF"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!rm output.zip\n",
|
||||||
|
"!zip -r ./output.zip ./Output/*.png\n",
|
||||||
|
"from google.colab import files\n",
|
||||||
|
"files.download(\"./output.zip\")\n",
|
||||||
|
"!rm ./Output/*"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"collapsed_sections": [
|
||||||
|
"jwZ0GT0eObBW",
|
||||||
|
"XvHTXI7KOnu4",
|
||||||
|
"yaHfjOHyJgdC",
|
||||||
|
"N-ObhCcRHhFv",
|
||||||
|
"Oudkt8AEKw05",
|
||||||
|
"-7syo80FPa4T",
|
||||||
|
"ihQD4BD1P1Vr",
|
||||||
|
"ie9xTKWi8b1e",
|
||||||
|
"P95iTXcZQbnk",
|
||||||
|
"IDQmNv_9ZTGc",
|
||||||
|
"NMj_JbQ9Rrxq",
|
||||||
|
"iIBkjkG_dHAu",
|
||||||
|
"5kQ2bDBrPe6Q",
|
||||||
|
"CszrKJDe-66T",
|
||||||
|
"5deJM1EQ_BIC",
|
||||||
|
"4y4c6LWDAAVB",
|
||||||
|
"9lnxMQBZ_asY",
|
||||||
|
"NK8ZJXL3DOkY",
|
||||||
|
"L7jnzh7O_vQT",
|
||||||
|
"NBtpXi1JDXCv",
|
||||||
|
"BTpsvKnCQyqj",
|
||||||
|
"YbgZNNTxLOEo",
|
||||||
|
"CcHuKU13TdZ8",
|
||||||
|
"fVPP3LONLmdN",
|
||||||
|
"gkLTqrwyRcAy",
|
||||||
|
"FxoQNQIQTBR2",
|
||||||
|
"odc0JtzUTglD",
|
||||||
|
"thSWMltDLYNy",
|
||||||
|
"UeAcuHx-Em0U",
|
||||||
|
"NwsSUAbMLe04",
|
||||||
|
"7fplwcZ_EZaZ",
|
||||||
|
"4D_g45M-F9e2",
|
||||||
|
"uZ13WzesoXdq"
|
||||||
|
],
|
||||||
|
"private_outputs": true,
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"gpuClass": "standard",
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
Loading…
Reference in New Issue