diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4e2865587..593abfef3 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -307,6 +307,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Schedule rho" not in res: res["Schedule rho"] = 0 + if "VAE Encoder" not in res: + res["VAE Encoder"] = "Full" + + if "VAE Decoder" not in res: + res["VAE Decoder"] = "Full" + return res @@ -332,6 +338,8 @@ infotext_to_setting_name_mapping = [ ('RNG', 'randn_source'), ('NGMS', 's_min_uncond'), ('Pad conds', 'pad_cond_uncond'), + ('VAE Encoder', 'sd_vae_encode_method'), + ('VAE Decoder', 'sd_vae_decode_method'), ] diff --git a/modules/processing.py b/modules/processing.py index ae58b108a..43cb763fe 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack +from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -30,7 +31,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType -decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -203,7 +203,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -216,7 +216,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) return conditioning_image @@ -795,6 +795,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: + p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1135,11 +1136,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): batch_images.append(image) decoded_samples = torch.from_numpy(np.array(batch_images)) - decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae) - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method)) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) @@ -1374,10 +1374,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) - image = 2. * image - 1. image = image.to(shared.device, dtype=devices.dtype_vae) - - self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) devices.torch_gc() if self.resize_mode == 3: diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index b3d344e77..42a29fc9c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None): approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} -def single_sample_to_image(sample, approximation=None): +def samples_to_images_tensor(sample, approximation=None, model=None): + '''latents -> images [-1, 1]''' if approximation is None: approximation = approximation_indexes.get(opts.show_progress_type, 0) if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 + x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: x_sample = sample * 1.5 - x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach() + x_sample = x_sample * 2 - 1 else: - x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 + if model is None: + model = shared.sd_model + x_sample = model.decode_first_stage(sample) + + return x_sample + + +def single_sample_to_image(sample, approximation=None): + x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -45,9 +55,9 @@ def single_sample_to_image(sample, approximation=None): def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x + x = x.to(devices.dtype_vae) + approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0) + return samples_to_images_tensor(x, approx_index, model) def sample_to_image(samples, index=0, approximation=None): @@ -58,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None): return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) +def images_tensor_to_samples(image, approximation=None, model=None): + '''image[0, 1] -> latent''' + if approximation is None: + approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0) + + if approximation == 3: + image = image.to(devices.device, devices.dtype) + x_latent = sd_vae_taesd.encoder_model()(image) + else: + if model is None: + model = shared.sd_model + image = image.to(shared.device, dtype=devices.dtype_vae) + image = image * 2 - 1 + x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) + + return x_latent + + def store_latent(decoded): state.current_latent = decoded diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 86bd658ad..3965e223e 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -81,6 +81,6 @@ def cheap_approximation(sample): coefs = torch.tensor(coeffs).to(sample.device) - x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) + x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs) return x_sample diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5bf7c76e1..808eb3624 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -44,7 +44,17 @@ def decoder(): ) -class TAESD(nn.Module): +def encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + + +class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 @@ -55,21 +65,28 @@ class TAESD(nn.Module): self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - @staticmethod - def unscale_latents(x): - """[0, 1] -> raw latents""" - return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + +class TAESDEncoder(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = encoder() + self.encoder.load_state_dict( + torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) - print(f'Downloading TAESD decoder to: {model_path}') + print(f'Downloading TAESD model to: {model_path}') torch.hub.download_url_to_file(model_url, model_path) -def model(): +def decoder_model(): model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name) @@ -78,7 +95,7 @@ def model(): download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - loaded_model = TAESD(model_path) + loaded_model = TAESDDecoder(model_path) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) sd_vae_taesd_models[model_name] = loaded_model @@ -86,3 +103,22 @@ def model(): raise FileNotFoundError('TAESD model not found') return loaded_model.decoder + + +def encoder_model(): + model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) + + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) + + if os.path.exists(model_path): + loaded_model = TAESDEncoder(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model + else: + raise FileNotFoundError('TAESD model not found') + + return loaded_model.encoder diff --git a/modules/shared.py b/modules/shared.py index 8245250a5..516ad7e80 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -435,6 +435,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), + "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {