From f0c1063a707a4a43823b0ed00e2a8eeb22a9ed0a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 09:09:09 +0300 Subject: [PATCH] resolve some of circular import issues for kohaku --- modules/hypernetworks/hypernetwork.py | 5 ++--- modules/processing.py | 7 +------ modules/sd_hijack.py | 6 +++--- modules/sd_samplers_common.py | 10 ++++++++-- modules/textual_inversion/textual_inversion.py | 4 +++- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21a..70f1cbd26 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images allows training previews to have infotext. Importing it at the top causes a circular import problem. - from modules import images + from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/processing.py b/modules/processing.py index 8f34c8b4c..8086a2b02 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType +decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -572,12 +573,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): return samples -def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x - - def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfa5f0ebb..609fd56c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,6 @@ import torch from torch.nn.functional import silu from types import MethodType -import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts @@ -164,12 +163,13 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def __init__(self): + import modules.textual_inversion.textual_inversion + self.extra_generation_params = {} self.comments = [] + self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5deda7616..b3d344e77 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -35,7 +35,7 @@ def single_sample_to_image(sample, approximation=None): x_sample = sample * 1.5 x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 + x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -44,6 +44,12 @@ def single_sample_to_image(sample, approximation=None): return Image.fromarray(x_sample) +def decode_first_stage(model, x): + x = model.decode_first_stage(x.to(devices.dtype_vae)) + + return x + + def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4713bc2d9..aa79dc098 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes +from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + from modules import processing + save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None)