diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca49..3b740dbd4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -235,11 +235,13 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -297,6 +299,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index bada804e6..f509a85f7 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 0069d8b0e..d7f9e9a9c 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps): cond_schedule = [] for i, (end_at_step, _) in enumerate(prompt_schedule): - cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) + if isinstance(conds, dict): + cond = {k: v[i] for k, v in conds.items()} + else: + cond = conds[i] + + cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) cache[prompt] = cond_schedule res.append(cond_schedule) @@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) +class DictWithShape(dict): + def __init__(self, x, shape): + super().__init__() + self.update(x) + + @property + def shape(self): + return self["crossattn"].shape + + def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond - res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + is_dict = isinstance(param, dict) + + if is_dict: + dict_cond = param + res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} + res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) + else: + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, entry in enumerate(cond_schedule): if current_step <= entry.end_at_step: target_index = current break - res[i] = cond_schedule[target_index].cond + + if is_dict: + for k, param in cond_schedule[target_index].cond.items(): + res[k][i] = param + else: + res[i] = cond_schedule[target_index].cond return res +def stack_conds(tensors): + # if prompts have wildly different lengths above the limit we'll get tensors of different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + + return torch.stack(tensors) + + + def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): param = c.batch[0][0].schedules[0].cond @@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) - # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes - # and won't be able to torch.stack them. So this fixes that. - token_count = max([x.shape[0] for x in tensors]) - for i in range(len(tensors)): - if tensors[i].shape[0] != token_count: - last_vector = tensors[i][-1:] - last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) - tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + if isinstance(tensors[0], dict): + keys = list(tensors[0].keys()) + stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} + stacked = DictWithShape(stacked, stacked['crossattn'].shape) + else: + stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) - return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + return conds_list, stacked re_attention = re.compile(r""" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce2..c4b9211fd 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -166,6 +166,15 @@ class StableDiffusionModelHijack: undo_optimizations() def hijack(self, m): + conditioner = getattr(m, 'conditioner', None) + if conditioner: + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index f733e8529..6ac5bda6e 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,6 +16,10 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007c..8d639583c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if hasattr(model, 'conditioner'): + sd_models_xl.extend_sdxl(model) + model.load_state_dict(state_dict, strict=False) del state_dict timer.record("apply weights to model") @@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training + if hasattr(model, 'logvar'): + model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9bfe1237d..96501569b 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -6,11 +6,13 @@ from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 000000000..d43b88686 --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + c = self.conditioner({'txt': batch}) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + return self.model(x, t, cond) + + +def extend_sdxl(model): + dtype = next(model.model.diffusion_model.parameters()).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + + model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_key = model.cond_stage_model.input_key + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b763..73289ce47 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { } +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) @@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: - tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: - uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: - cond_in = torch.cat([tensor, uncond]) + cond_in = catenate_conds([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: - c_crossattn = [tensor[a:b]] + c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: