Merge pull request #140 from damian0815/fix_samples_clip_skip

Respect clip_skip when generating samples
This commit is contained in:
Victor Hall 2023-04-14 21:30:23 -04:00 committed by GitHub
commit e574805326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 17 additions and 4 deletions

View File

@ -1,6 +1,7 @@
aiohttp==3.8.4
bitsandbytes==0.37.2
colorama==0.4.6
compel~=1.1.3
ftfy==6.1.1
ipyevents
ipywidgets

View File

@ -651,7 +651,9 @@ def main(args):
config_file_path=args.sample_prompts,
batch_size=max(1,args.batch_size//2),
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers)
use_xformers=is_xformers_available() and not args.disable_xformers,
use_penultimate_clip_layer=(args.clip_skip >= 2)
)
"""
Train the model

View File

@ -9,10 +9,12 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
from torch import FloatTensor
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.auto import tqdm
from compel import Compel
def clean_filename(filename):
@ -84,7 +86,8 @@ class SampleGenerator:
batch_size: int,
default_seed: int,
default_sample_steps: int,
use_xformers: bool):
use_xformers: bool,
use_penultimate_clip_layer: bool):
self.log_folder = log_folder
self.log_writer = log_writer
self.batch_size = batch_size
@ -92,6 +95,7 @@ class SampleGenerator:
self.use_xformers = use_xformers
self.show_progress_bars = False
self.generate_pretrain_samples = False
self.use_penultimate_clip_layer = use_penultimate_clip_layer
self.default_resolution = default_resolution
self.default_seed = default_seed
@ -198,6 +202,9 @@ class SampleGenerator:
compatibility_test=sample_compatibility_test))
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
compel = Compel(tokenizer=pipe.tokenizer,
text_encoder=pipe.text_encoder,
use_penultimate_clip_layer=self.use_penultimate_clip_layer)
for batch in batches:
prompts = [p.prompt for p in batch]
negative_prompts = [p.negative_prompt for p in batch]
@ -211,8 +218,10 @@ class SampleGenerator:
for cfg in self.cfgs:
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
images = pipe(prompt=prompts,
negative_prompt=negative_prompts,
prompt_embeds = compel(prompts)
negative_prompt_embeds = compel(negative_prompts)
images = pipe(prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_inference_steps=self.num_inference_steps,
num_images_per_prompt=1,
guidance_scale=cfg,

View File

@ -22,6 +22,7 @@ pip install OmegaConf==2.2.3
pip install numpy==1.23.5
pip install keyboard
pip install lion-pytorch
pip install compel~=1.1.3
python utils/patch_bnb.py
python utils/get_yamls.py
GOTO :eof