diff --git a/modules/processing.py b/modules/processing.py index aae39866b..544667a4d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack +from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -30,7 +31,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType -decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -203,7 +203,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -216,7 +216,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) return conditioning_image @@ -255,7 +255,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = images_tensor_to_samples(conditioning_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -1099,9 +1099,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = torch.from_numpy(np.array(batch_images)) decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method)) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) @@ -1339,7 +1338,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) - from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) devices.torch_gc() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7269514f8..42a29fc9c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -75,7 +75,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): if approximation == 3: image = image.to(devices.device, devices.dtype) - x_latent = sd_vae_taesd.encoder_model()(image) / 1.5 + x_latent = sd_vae_taesd.encoder_model()(image) else: if model is None: model = shared.sd_model diff --git a/run.ps1 b/run.ps1 new file mode 100644 index 000000000..82c1660bb --- /dev/null +++ b/run.ps1 @@ -0,0 +1 @@ +.\venv\Scripts\accelerate-launch.exe --num_cpu_threads_per_process=6 --api .\launch.py --listen --port 17415 --xformers --opt-channelslast \ No newline at end of file diff --git a/run_local.ps1 b/run_local.ps1 new file mode 100644 index 000000000..e2ac43dbd --- /dev/null +++ b/run_local.ps1 @@ -0,0 +1,3 @@ +.\venv\Scripts\Activate.ps1 +python .\launch.py --xformers --opt-channelslast --api +. $PSCommandPath \ No newline at end of file diff --git a/update.ps1 b/update.ps1 new file mode 100644 index 000000000..9960bead4 --- /dev/null +++ b/update.ps1 @@ -0,0 +1 @@ +git stash push && git pull --rebase && git stash pop \ No newline at end of file