This commit is contained in:
dogewanwan 2022-08-24 21:38:14 +03:00
commit 781f054a20
3 changed files with 88 additions and 25 deletions

View File

@ -146,3 +146,12 @@ to get otherwise.
Example: (cherrypicked result; original picture by anon) Example: (cherrypicked result; original picture by anon)
![](images/loopback.jpg) ![](images/loopback.jpg)
### Png info
Adds information about generation parameters to PNG as a text chunk. You
can view this information later using any software that supports viewing
PNG chunk info, for example: https://www.nayuki.io/page/png-file-chunk-inspector
This can be disabled using the `--disable-pnginfo` command line option.
![](images/pnginfo.png)

BIN
images/pnginfo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

100
webui.py
View File

@ -4,7 +4,7 @@ import torch.nn as nn
import numpy as np import numpy as np
import gradio as gr import gradio as gr
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import autocast from torch import autocast
@ -12,6 +12,8 @@ from contextlib import contextmanager, nullcontext
import mimetypes import mimetypes
import random import random
import math import math
import html
import time
import k_diffusion as K import k_diffusion as K
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
@ -50,7 +52,12 @@ parser.add_argument("--no-verify-input", action='store_true', help="do not verif
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--save-format", type=str, default='png', help="file format for saved indiviual samples; can be png or jpg")
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg") parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
parser.add_argument("--grid-extended-filename", action='store_true', help="save grid images to filenames with extended info: seed, prompt")
parser.add_argument("--jpeg-quality", type=int, default=80, help="quality for saved jpeg images")
parser.add_argument("--disable-pnginfo", action='store_true', help="disable saving text information about generation parameters as chunks to png files")
parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo") parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo")
opt = parser.parse_args() opt = parser.parse_args()
@ -130,6 +137,37 @@ def create_random_tensors(shape, seeds):
return x return x
def torch_gc():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def sanitize_filename_part(text):
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
def save_image(image, path, basename, seed, prompt, extension, info=None, short_filename=False):
prompt = sanitize_filename_part(prompt)
if short_filename:
filename = f"{basename}.{extension}"
else:
filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
if extension == 'png' and not opt.disable_pnginfo:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", info)
else:
pnginfo = None
image.save(os.path.join(path, filename), quality=opt.jpeg_quality, pnginfo=pnginfo)
def plaintext_to_html(text):
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
return text
def load_GFPGAN(): def load_GFPGAN():
model_name = 'GFPGANv1.3' model_name = 'GFPGANv1.3'
model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
@ -301,11 +339,25 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def wrap_gradio_call(func):
def f(*p1, **p2):
t = time.perf_counter()
res = list(func(*p1, **p2))
elapsed = time.perf_counter() - t
# last item is always HTML
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
return tuple(res)
return f
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False): def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None assert prompt is not None
torch.cuda.empty_cache() torch_gc()
if seed == -1: if seed == -1:
seed = random.randrange(4294967294) seed = random.randrange(4294967294)
@ -351,6 +403,11 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
all_prompts = batch_size * n_iter * [prompt] all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))] all_seeds = [seed + x for x in range(len(all_prompts))]
info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip() + "".join(["\n\n" + x for x in comments])
precision_scope = autocast if opt.precision == "autocast" else nullcontext precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = [] output_images = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
@ -385,9 +442,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
x_sample = restored_img x_sample = restored_img
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png" save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opt.save_format, info=info)
image.save(os.path.join(sample_path, filename))
output_images.append(image) output_images.append(image)
base_count += 1 base_count += 1
@ -406,17 +461,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
output_images.insert(0, grid) output_images.insert(0, grid)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}')) save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
grid_count += 1 grid_count += 1
info = f""" torch_gc()
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
for comment in comments:
info += "\n\n" + comment
return output_images, seed, info return output_images, seed, info
@ -465,7 +513,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
del sampler del sampler
return output_images, seed, info return output_images, seed, plaintext_to_html(info)
class Flagging(gr.FlaggingCallback): class Flagging(gr.FlaggingCallback):
@ -510,7 +558,7 @@ class Flagging(gr.FlaggingCallback):
txt2img_interface = gr.Interface( txt2img_interface = gr.Interface(
txt2img, wrap_gradio_call(txt2img),
inputs=[ inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1), gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
@ -529,7 +577,7 @@ txt2img_interface = gr.Interface(
outputs=[ outputs=[
gr.Gallery(label="Images"), gr.Gallery(label="Images"),
gr.Number(label='Seed'), gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"), gr.HTML(),
], ],
title="Stable Diffusion Text-to-Image K", title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)", description="Generate images from text with Stable Diffusion (using K-LMS)",
@ -608,7 +656,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1) grid = image_grid(history, batch_size, force_n_rows=1)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
output_images = history output_images = history
seed = initial_seed seed = initial_seed
@ -633,14 +682,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
del sampler del sampler
return output_images, seed, info return output_images, seed, plaintext_to_html(info)
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
img2img_interface = gr.Interface( img2img_interface = gr.Interface(
img2img, wrap_gradio_call(img2img),
inputs=[ inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1), gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"), gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
@ -661,7 +710,7 @@ img2img_interface = gr.Interface(
outputs=[ outputs=[
gr.Gallery(), gr.Gallery(),
gr.Number(label='Seed'), gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"), gr.HTML(),
], ],
title="Stable Diffusion Image-to-Image", title="Stable Diffusion Image-to-Image",
description="Generate images from images with Stable Diffusion", description="Generate images from images with Stable Diffusion",
@ -682,7 +731,7 @@ def run_GFPGAN(image, strength):
if strength < 1.0: if strength < 1.0:
res = Image.blend(image, res, strength) res = Image.blend(image, res, strength)
return res return res, 0, ''
if GFPGAN is not None: if GFPGAN is not None:
@ -694,6 +743,8 @@ if GFPGAN is not None:
], ],
outputs=[ outputs=[
gr.Image(label="Result"), gr.Image(label="Result"),
gr.Number(label='Seed', visible=False),
gr.HTML(),
], ],
title="GFPGAN", title="GFPGAN",
description="Fix faces on images", description="Fix faces on images",
@ -704,7 +755,10 @@ if GFPGAN is not None:
demo = gr.TabbedInterface( demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces], interface_list=[x[0] for x in interfaces],
tab_names=[x[1] for x in interfaces], tab_names=[x[1] for x in interfaces],
css=("" if opt.no_progressbar_hiding else css_hide_progressbar) css=("" if opt.no_progressbar_hiding else css_hide_progressbar) + """
.output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
"""
) )
demo.launch() demo.launch()