Merge pull request #12311 from AUTOMATIC1111/efficient-vae-methods

Add TAESD(or more) options for all the VAE encode/decode operation
This commit is contained in:
AUTOMATIC1111 2023-08-05 09:24:26 +03:00 committed by GitHub
commit f879cac1e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 100 additions and 27 deletions

View File

@ -307,6 +307,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Schedule rho" not in res: if "Schedule rho" not in res:
res["Schedule rho"] = 0 res["Schedule rho"] = 0
if "VAE Encoder" not in res:
res["VAE Encoder"] = "Full"
if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"
return res return res
@ -332,6 +338,8 @@ infotext_to_setting_name_mapping = [
('RNG', 'randn_source'), ('RNG', 'randn_source'),
('NGMS', 's_min_uncond'), ('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'), ('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'),
] ]

View File

@ -16,6 +16,7 @@ from typing import Any, Dict, List
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.paths as paths import modules.paths as paths
@ -30,7 +31,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType from blendmodes.blend import blendLayers, BlendType
decode_first_stage = sd_samplers_common.decode_first_stage
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4 opt_C = 4
@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
# The "masked-image" in this case will just be all zeros since the entire image is masked. # The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension. # Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@ -203,7 +203,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate( conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in), self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:], size=conditioning_image.shape[2:],
@ -216,7 +216,7 @@ class StableDiffusionProcessing:
return conditioning return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = self.sd_model.encode_first_stage(source_image).mode() conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
return conditioning_image return conditioning_image
@ -795,6 +795,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False): if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim x_samples_ddim = samples_ddim
else: else:
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
@ -1135,11 +1136,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
batch_images.append(image) batch_images.append(image)
decoded_samples = torch.from_numpy(np.array(batch_images)) decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1.
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae) decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
@ -1374,10 +1374,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images) image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(shared.device, dtype=devices.dtype_vae) image = image.to(shared.device, dtype=devices.dtype_vae)
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
devices.torch_gc() devices.torch_gc()
if self.resize_mode == 3: if self.resize_mode == 3:

View File

@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None): def samples_to_images_tensor(sample, approximation=None, model=None):
'''latents -> images [-1, 1]'''
if approximation is None: if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0) approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2: if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1: elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3: elif approximation == 3:
x_sample = sample * 1.5 x_sample = sample * 1.5
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else: else:
x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 if model is None:
model = shared.sd_model
x_sample = model.decode_first_stage(sample)
return x_sample
def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@ -45,9 +55,9 @@ def single_sample_to_image(sample, approximation=None):
def decode_first_stage(model, x): def decode_first_stage(model, x):
x = model.decode_first_stage(x.to(devices.dtype_vae)) x = x.to(devices.dtype_vae)
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
return x return samples_to_images_tensor(x, approx_index, model)
def sample_to_image(samples, index=0, approximation=None): def sample_to_image(samples, index=0, approximation=None):
@ -58,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def images_tensor_to_samples(image, approximation=None, model=None):
'''image[0, 1] -> latent'''
if approximation is None:
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
if approximation == 3:
image = image.to(devices.device, devices.dtype)
x_latent = sd_vae_taesd.encoder_model()(image)
else:
if model is None:
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
def store_latent(decoded): def store_latent(decoded):
state.current_latent = decoded state.current_latent = decoded

View File

@ -81,6 +81,6 @@ def cheap_approximation(sample):
coefs = torch.tensor(coeffs).to(sample.device) coefs = torch.tensor(coeffs).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
return x_sample return x_sample

View File

@ -44,7 +44,17 @@ def decoder():
) )
class TAESD(nn.Module): def encoder():
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
)
class TAESDDecoder(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
latent_shift = 0.5 latent_shift = 0.5
@ -55,21 +65,28 @@ class TAESD(nn.Module):
self.decoder.load_state_dict( self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@staticmethod
def unscale_latents(x): class TAESDEncoder(nn.Module):
"""[0, 1] -> raw latents""" latent_magnitude = 3
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = encoder()
self.encoder.load_state_dict(
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
def download_model(model_path, model_url): def download_model(model_path, model_url):
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True) os.makedirs(os.path.dirname(model_path), exist_ok=True)
print(f'Downloading TAESD decoder to: {model_path}') print(f'Downloading TAESD model to: {model_path}')
torch.hub.download_url_to_file(model_url, model_path) torch.hub.download_url_to_file(model_url, model_path)
def model(): def decoder_model():
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name) loaded_model = sd_vae_taesd_models.get(model_name)
@ -78,7 +95,7 @@ def model():
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path): if os.path.exists(model_path):
loaded_model = TAESD(model_path) loaded_model = TAESDDecoder(model_path)
loaded_model.eval() loaded_model.eval()
loaded_model.to(devices.device, devices.dtype) loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model sd_vae_taesd_models[model_name] = loaded_model
@ -86,3 +103,22 @@ def model():
raise FileNotFoundError('TAESD model not found') raise FileNotFoundError('TAESD model not found')
return loaded_model.decoder return loaded_model.decoder
def encoder_model():
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None:
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path):
loaded_model = TAESDEncoder(model_path)
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')
return loaded_model.encoder

View File

@ -435,6 +435,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
})) }))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {