From 085427de0efc9e9e7a6e9a5aebc6b5a69f0365e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 09:37:33 +0300 Subject: [PATCH] make it possible for extensions/scripts to add their own embedding directories --- modules/sd_hijack.py | 7 +- .../textual_inversion/textual_inversion.py | 166 +++++++++++------- 2 files changed, 106 insertions(+), 67 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfdb09d6d..6b0d95af9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -83,10 +83,12 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() + + def __init__(self): + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def hijack(self, m): - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) @@ -117,7 +119,6 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e85dd5498..217fe9eb1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -66,17 +66,41 @@ class Embedding: return self.cached_checksum +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + class EmbeddingDatabase: - def __init__(self, embeddings_dir): + def __init__(self): self.ids_lookup = {} self.word_embeddings = {} self.skipped_embeddings = {} - self.dir_mtime = None - self.embeddings_dir = embeddings_dir self.expected_shape = -1 + self.embedding_dirs = {} + + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() def register_embedding(self, embedding, model): - self.word_embeddings[embedding.name] = embedding ids = model.cond_stage_model.tokenize([embedding.name])[0] @@ -93,69 +117,62 @@ class EmbeddingDatabase: vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] - def load_textual_inversion_embeddings(self, force_reload = False): - mt = os.path.getmtime(self.embeddings_dir) - if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: - return + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - self.skipped_embeddings.clear() - self.expected_shape = self.get_expected_shape() - - def process_file(path, filename): - name, ext = os.path.splitext(filename) - ext = ext.upper() - - if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - _, second_ext = os.path.splitext(name) - if second_ext.upper() == '.PREVIEW': - return - - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name', name) - else: - data = extract_image_data_embed(embed_image) - name = data.get('name', name) - elif ext in ['.BIN', '.PT']: - data = torch.load(path, map_location="cpu") - else: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': return - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + else: + return - vec = emb.detach().to(devices.device, dtype=torch.float32) - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vec.shape[0] - embedding.shape = vec.shape[-1] + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - if self.expected_shape == -1 or self.expected_shape == embedding.shape: - self.register_embedding(embedding, shared.sd_model) - else: - self.skipped_embeddings[name] = embedding + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - for root, dirs, fns in os.walk(self.embeddings_dir): + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding + + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path): for fn in fns: try: fullfn = os.path.join(root, fn) @@ -163,12 +180,32 @@ class EmbeddingDatabase: if os.stat(fullfn).st_size == 0: continue - process_file(fullfn, fn) + self.load_from_file(fullfn, fn) except Exception: print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") @@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert os.path.isfile(template_file), "Prompt template file doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" - assert steps > 0 , "Max steps must be positive" + assert steps > 0, "Max steps must be positive" assert isinstance(save_model_every, int), "Save {name} must be integer" - assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert save_model_every >= 0, "Save {name} must be positive or 0" assert isinstance(create_image_every, int), "Create image must be integer" - assert create_image_every >= 0 , "Create image must be positive or 0" + assert create_image_every >= 0, "Create image must be positive or 0" if save_model_every or create_image_every: assert log_directory, "Log directory is empty" + def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0