diff --git a/docker/requirements-runtime.txt b/docker/requirements-runtime.txt index 512cadd..0dd1687 100644 --- a/docker/requirements-runtime.txt +++ b/docker/requirements-runtime.txt @@ -1,6 +1,7 @@ aiohttp==3.8.4 bitsandbytes==0.37.2 colorama==0.4.6 +compel~=1.1.3 ftfy==6.1.1 ipyevents ipywidgets diff --git a/train.py b/train.py index 67e741b..e1d90f4 100644 --- a/train.py +++ b/train.py @@ -651,7 +651,9 @@ def main(args): config_file_path=args.sample_prompts, batch_size=max(1,args.batch_size//2), default_sample_steps=args.sample_steps, - use_xformers=is_xformers_available() and not args.disable_xformers) + use_xformers=is_xformers_available() and not args.disable_xformers, + use_penultimate_clip_layer=(args.clip_skip >= 2) + ) """ Train the model diff --git a/utils/sample_generator.py b/utils/sample_generator.py index f053dfa..6607b7b 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -9,10 +9,12 @@ import torch from PIL import Image, ImageDraw, ImageFont from colorama import Fore, Style from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler +from torch import FloatTensor from torch.cuda.amp import autocast from torch.utils.tensorboard import SummaryWriter from torchvision import transforms from tqdm.auto import tqdm +from compel import Compel def clean_filename(filename): @@ -84,7 +86,8 @@ class SampleGenerator: batch_size: int, default_seed: int, default_sample_steps: int, - use_xformers: bool): + use_xformers: bool, + use_penultimate_clip_layer: bool): self.log_folder = log_folder self.log_writer = log_writer self.batch_size = batch_size @@ -92,6 +95,7 @@ class SampleGenerator: self.use_xformers = use_xformers self.show_progress_bars = False self.generate_pretrain_samples = False + self.use_penultimate_clip_layer = use_penultimate_clip_layer self.default_resolution = default_resolution self.default_seed = default_seed @@ -198,6 +202,9 @@ class SampleGenerator: compatibility_test=sample_compatibility_test)) pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False, desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}") + compel = Compel(tokenizer=pipe.tokenizer, + text_encoder=pipe.text_encoder, + use_penultimate_clip_layer=self.use_penultimate_clip_layer) for batch in batches: prompts = [p.prompt for p in batch] negative_prompts = [p.negative_prompt for p in batch] @@ -211,8 +218,10 @@ class SampleGenerator: for cfg in self.cfgs: pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False, desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}") - images = pipe(prompt=prompts, - negative_prompt=negative_prompts, + prompt_embeds = compel(prompts) + negative_prompt_embeds = compel(negative_prompts) + images = pipe(prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, num_inference_steps=self.num_inference_steps, num_images_per_prompt=1, guidance_scale=cfg, diff --git a/windows_setup.cmd b/windows_setup.cmd index 880b50b..abce383 100644 --- a/windows_setup.cmd +++ b/windows_setup.cmd @@ -22,6 +22,7 @@ pip install OmegaConf==2.2.3 pip install numpy==1.23.5 pip install keyboard pip install lion-pytorch +pip install compel~=1.1.3 python utils/patch_bnb.py python utils/get_yamls.py GOTO :eof