support for sdxl-inpaint model
parent cf2772fab0
commit 9feb034e34
configs/sd_xl_inpaint.yaml (new file)
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
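The config mirrors sgm's sd_xl_base.yaml; the functional change is in_channels: 9 on the UNet, which makes room for the noisy latent (4 channels) plus the inpainting conditioning (1 mask channel and 4 masked-image latent channels). For reference, a minimal sketch of how an sgm config like this is typically instantiated; the load path here is an assumption, not something this commit adds:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Build the DiffusionEngine described by the YAML above (hypothetical usage).
config = OmegaConf.load("configs/sd_xl_inpaint.yaml")
model = instantiate_from_config(config.model)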
modules/processing.py
@@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height):
         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

     else:
+        sd = sd_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
+            # The "masked image" in this case is just all 0.5, since the entire image is masked.
+            image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+            image_conditioning = images_tensor_to_samples(image_conditioning,
+                                                          approximation_indexes.get(opts.sd_vae_encode_method))
+
+            # Add the fake full-1s mask to the first dimension.
+            image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+            image_conditioning = image_conditioning.to(x.dtype)
+
+            return image_conditioning
+
         # Dummy zero conditioning if we're not using inpainting or unclip models.
         # Still takes up a bit of memory, but no encoder call.
         # Pretty sure we can just make this a 1x1 image, since it's not going to be used besides its batch size.
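For context, the pad above prepends a constant all-ones mask channel to the four masked-image latent channels, giving the five conditioning channels that, concatenated with the noisy latent, make up the UNet's nine input channels. A shape-only sketch (batch size and latent resolution are assumptions):

import torch

batch, h, w = 2, 64, 64
latent = torch.randn(batch, 4, h, w)      # noisy image latent
masked = torch.randn(batch, 4, h, w)      # VAE-encoded masked image
# Prepend a full-1s mask channel, exactly like the pad in the diff above.
cond = torch.nn.functional.pad(masked, (0, 0, 0, 0, 1, 0), value=1.0)
assert cond.shape[1] == 5                 # 1 mask + 4 masked-image channels
unet_input = torch.cat([latent, cond], dim=1)
assert unet_input.shape[1] == 9           # matches in_channels: 9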
@@ -362,6 +376,11 @@ class StableDiffusionProcessing:
         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)

+        sd = self.sampler.model_wrap.inner_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
modules/sd_models_config.py
@@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
-        return config_sdxl
+        if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
+            return config_sdxl_inpainting
+        else:
+            return config_sdxl
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
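Taken together, the detection logic above amounts to a small predicate. A hedged sketch (the helper name is hypothetical and not part of the commit):

def is_sdxl_inpaint(state_dict):
    # SDXL-base checkpoints carry a second text encoder (OpenCLIP ViT-bigG),
    # and inpainting variants widen the first UNet conv to 9 input channels.
    has_second_encoder = state_dict.get('conditioner.embedders.1.model.ln_final.weight') is not None
    first_conv = state_dict.get('model.diffusion_model.input_blocks.0.0.weight')
    return has_second_encoder and first_conv is not None and first_conv.shape[1] == 9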
modules/sd_models_xl.py
@@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:


 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    sd = self.model.state_dict()
+    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+    if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
+        x = torch.cat([x] + cond['c_concat'], dim=1)
+
     return self.model(x, t, cond)

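For illustration, a sketch of what apply_model receives for a nine-channel checkpoint; the shapes and the exact cond keys are assumptions for illustration, not guaranteed by this commit:

import torch

x = torch.randn(1, 4, 64, 64)                 # noisy latent from the sampler
t = torch.tensor([999])                       # diffusion timestep
cond = {
    'crossattn': torch.randn(1, 77, 2048),    # CLIP-L + OpenCLIP-bigG token embeddings
    'vector': torch.randn(1, 2816),           # pooled text + size/crop embeddings
    'c_concat': [torch.randn(1, 5, 64, 64)],  # mask (1) + masked-image latent (4)
}
# apply_model concatenates c_concat onto x, giving the UNet its 9 input channels.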