make CLIP interrogator download original text files if the directory does not exist
remove random artist built-in extension (to re-added as a normal extension on demand) remove artists.csv (but what does it mean????????????????????) make interrogate buttons show Loading... when you click them
This commit is contained in:
parent
40ff6db532
commit
6d805b669e
|
@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||
- Mouseover hints for most UI elements
|
||||
- Possible to change defaults/mix/max/step values for UI elements via text config
|
||||
- Random artist button
|
||||
- Tiling support, a checkbox to create images that can be tiled like textures
|
||||
- Progress bar and live image generation preview
|
||||
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
||||
|
|
3041
artists.csv
3041
artists.csv
File diff suppressed because it is too large
Load Diff
|
@ -1,50 +0,0 @@
|
|||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
|
@ -14,7 +14,6 @@ titles = {
|
|||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
|
|
|
@ -126,8 +126,6 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
|
@ -390,12 +388,6 @@ class Api:
|
|||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
import os.path
|
||||
import csv
|
||||
from collections import namedtuple
|
||||
|
||||
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
|
||||
|
||||
|
||||
class ArtistsDatabase:
|
||||
def __init__(self, filename):
|
||||
self.cats = set()
|
||||
self.artists = []
|
||||
|
||||
if not os.path.exists(filename):
|
||||
return
|
||||
|
||||
with open(filename, "r", newline='', encoding="utf8") as file:
|
||||
reader = csv.DictReader(file)
|
||||
|
||||
for row in reader:
|
||||
artist = Artist(row["artist"], float(row["score"]), row["category"])
|
||||
self.artists.append(artist)
|
||||
self.cats.add(artist.category)
|
||||
|
||||
def categories(self):
|
||||
return sorted(self.cats)
|
|
@ -5,12 +5,13 @@ from collections import namedtuple
|
|||
import re
|
||||
|
||||
import torch
|
||||
import torch.hub
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, lowvram, modelloader
|
||||
from modules import devices, paths, lowvram, modelloader, errors
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
|||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
|
||||
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
|
||||
tmpdir = content_dir + "_tmp"
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
|
||||
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
|
||||
|
||||
os.rename(tmpdir, content_dir)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.remove(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
blip_model = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
self.loaded_categories = None
|
||||
self.content_dir = content_dir
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
def categories(self):
|
||||
if self.loaded_categories is not None:
|
||||
return self.loaded_categories
|
||||
|
||||
self.loaded_categories = []
|
||||
|
||||
if not os.path.exists(self.content_dir):
|
||||
download_default_clip_interrogate_categories(self.content_dir)
|
||||
|
||||
if os.path.exists(self.content_dir):
|
||||
for filename in os.listdir(self.content_dir):
|
||||
m = re_topn.search(filename)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
|
||||
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
|
||||
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
|
||||
|
||||
return self.loaded_categories
|
||||
|
||||
def load_blip_model(self):
|
||||
import models.blip
|
||||
|
@ -139,7 +172,6 @@ class InterrogateModels:
|
|||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
@ -159,12 +191,7 @@ class InterrogateModels:
|
|||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if shared.opts.interrogate_use_builtin_artists:
|
||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
||||
|
||||
res += ", " + artist[0]
|
||||
|
||||
for name, topn, items in self.categories:
|
||||
for name, topn, items in self.categories():
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for match, score in matches:
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
|
|
|
@ -9,7 +9,6 @@ from PIL import Image
|
|||
import gradio as gr
|
||||
import tqdm
|
||||
|
||||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
|
@ -254,8 +253,6 @@ class State:
|
|||
state = State()
|
||||
state.server_start = time.time()
|
||||
|
||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
||||
|
||||
styles_filename = cmd_opts.styles_file
|
||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
||||
|
||||
|
@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
|
@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
|
|
|
@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
|
|||
left, _ = os.path.splitext(filename)
|
||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
|
||||
|
||||
return [gr_show(True), None]
|
||||
return [gr.update(), None]
|
||||
|
||||
|
||||
def interrogate(image):
|
||||
prompt = shared.interrogator.interrogate(image.convert("RGB"))
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
return gr.update() if prompt is None else prompt
|
||||
|
||||
|
||||
def interrogate_deepbooru(image):
|
||||
prompt = deepbooru.model.tag(image)
|
||||
return gr_show(True) if prompt is None else prompt
|
||||
return gr.update() if prompt is None else prompt
|
||||
|
||||
|
||||
def create_seed_inputs(target_interface):
|
||||
|
@ -1039,19 +1039,18 @@ def create_ui():
|
|||
init_img_inpaint,
|
||||
],
|
||||
outputs=[img2img_prompt, dummy_component],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
img2img_prompt.submit(**img2img_args)
|
||||
submit.click(**img2img_args)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args : process_interrogate(interrogate, *args),
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
**interrogate_args,
|
||||
)
|
||||
|
||||
img2img_deepbooru.click(
|
||||
fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
|
||||
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
|
||||
**interrogate_args,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue