Merge pull request #140 from damian0815/fix_samples_clip_skip
Respect clip_skip when generating samples
This commit is contained in:
commit
e574805326
|
@ -1,6 +1,7 @@
|
||||||
aiohttp==3.8.4
|
aiohttp==3.8.4
|
||||||
bitsandbytes==0.37.2
|
bitsandbytes==0.37.2
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
|
compel~=1.1.3
|
||||||
ftfy==6.1.1
|
ftfy==6.1.1
|
||||||
ipyevents
|
ipyevents
|
||||||
ipywidgets
|
ipywidgets
|
||||||
|
|
4
train.py
4
train.py
|
@ -651,7 +651,9 @@ def main(args):
|
||||||
config_file_path=args.sample_prompts,
|
config_file_path=args.sample_prompts,
|
||||||
batch_size=max(1,args.batch_size//2),
|
batch_size=max(1,args.batch_size//2),
|
||||||
default_sample_steps=args.sample_steps,
|
default_sample_steps=args.sample_steps,
|
||||||
use_xformers=is_xformers_available() and not args.disable_xformers)
|
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||||
|
use_penultimate_clip_layer=(args.clip_skip >= 2)
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Train the model
|
Train the model
|
||||||
|
|
|
@ -9,10 +9,12 @@ import torch
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
|
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
|
||||||
|
from torch import FloatTensor
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
from compel import Compel
|
||||||
|
|
||||||
|
|
||||||
def clean_filename(filename):
|
def clean_filename(filename):
|
||||||
|
@ -84,7 +86,8 @@ class SampleGenerator:
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
default_seed: int,
|
default_seed: int,
|
||||||
default_sample_steps: int,
|
default_sample_steps: int,
|
||||||
use_xformers: bool):
|
use_xformers: bool,
|
||||||
|
use_penultimate_clip_layer: bool):
|
||||||
self.log_folder = log_folder
|
self.log_folder = log_folder
|
||||||
self.log_writer = log_writer
|
self.log_writer = log_writer
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
@ -92,6 +95,7 @@ class SampleGenerator:
|
||||||
self.use_xformers = use_xformers
|
self.use_xformers = use_xformers
|
||||||
self.show_progress_bars = False
|
self.show_progress_bars = False
|
||||||
self.generate_pretrain_samples = False
|
self.generate_pretrain_samples = False
|
||||||
|
self.use_penultimate_clip_layer = use_penultimate_clip_layer
|
||||||
|
|
||||||
self.default_resolution = default_resolution
|
self.default_resolution = default_resolution
|
||||||
self.default_seed = default_seed
|
self.default_seed = default_seed
|
||||||
|
@ -198,6 +202,9 @@ class SampleGenerator:
|
||||||
compatibility_test=sample_compatibility_test))
|
compatibility_test=sample_compatibility_test))
|
||||||
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
pbar = tqdm(total=len(batches), disable=disable_progress_bars, position=1, leave=False,
|
||||||
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
desc=f"{Fore.YELLOW}Image samples (batches of {self.batch_size}){Style.RESET_ALL}")
|
||||||
|
compel = Compel(tokenizer=pipe.tokenizer,
|
||||||
|
text_encoder=pipe.text_encoder,
|
||||||
|
use_penultimate_clip_layer=self.use_penultimate_clip_layer)
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
prompts = [p.prompt for p in batch]
|
prompts = [p.prompt for p in batch]
|
||||||
negative_prompts = [p.negative_prompt for p in batch]
|
negative_prompts = [p.negative_prompt for p in batch]
|
||||||
|
@ -211,8 +218,10 @@ class SampleGenerator:
|
||||||
for cfg in self.cfgs:
|
for cfg in self.cfgs:
|
||||||
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
pipe.set_progress_bar_config(disable=disable_progress_bars, position=2, leave=False,
|
||||||
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
desc=f"{Fore.LIGHTYELLOW_EX}CFG scale {cfg}{Style.RESET_ALL}")
|
||||||
images = pipe(prompt=prompts,
|
prompt_embeds = compel(prompts)
|
||||||
negative_prompt=negative_prompts,
|
negative_prompt_embeds = compel(negative_prompts)
|
||||||
|
images = pipe(prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
num_inference_steps=self.num_inference_steps,
|
num_inference_steps=self.num_inference_steps,
|
||||||
num_images_per_prompt=1,
|
num_images_per_prompt=1,
|
||||||
guidance_scale=cfg,
|
guidance_scale=cfg,
|
||||||
|
|
|
@ -22,6 +22,7 @@ pip install OmegaConf==2.2.3
|
||||||
pip install numpy==1.23.5
|
pip install numpy==1.23.5
|
||||||
pip install keyboard
|
pip install keyboard
|
||||||
pip install lion-pytorch
|
pip install lion-pytorch
|
||||||
|
pip install compel~=1.1.3
|
||||||
python utils/patch_bnb.py
|
python utils/patch_bnb.py
|
||||||
python utils/get_yamls.py
|
python utils/get_yamls.py
|
||||||
GOTO :eof
|
GOTO :eof
|
||||||
|
|
Loading…
Reference in New Issue