Merge branch 'main' into inference-option

This commit is contained in:
Anthony Mercurio 2022-11-16 16:20:57 -05:00 committed by GitHub
commit dc5849b235
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 141 additions and 24 deletions

8
.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
*_cropped/
**/nsfw-ids.txt
**/*.image
**/*.caption
**/dataset*.tar
**/*.json
**/*.png
**/*.jpg

View File

@ -2,8 +2,8 @@
# `nvcc --version` to get CUDA version. # `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA. # `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage: # Example Usage:
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True # Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True # Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse import argparse
import socket import socket
@ -26,6 +26,7 @@ import numpy as np
import json import json
import re import re
import traceback import traceback
import shutil
try: try:
pynvml.nvmlInit() pynvml.nvmlInit()
@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.Image import Image as Img
from typing import Dict, List, Generator, Tuple from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
@ -86,6 +88,10 @@ parser.add_argument('--resize', type=bool_t, default='False', help="Resizes data
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting') parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting')
parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
args = parser.parse_args() args = parser.parse_args()
def setup(): def setup():
@ -145,33 +151,137 @@ def _sort_by_ratio(bucket: tuple) -> float:
def _sort_by_area(bucket: tuple) -> float: def _sort_by_area(bucket: tuple) -> float:
return bucket[0] * bucket[1] return bucket[0] * bucket[1]
class Validation():
def __init__(self, is_skipped: bool, is_extended: bool) -> None:
if is_skipped:
self.validate = self.__no_op
return print("Validation: Skipped")
if is_extended:
self.validate = self.__extended_validate
return print("Validation: Extended")
self.validate = self.__validate
print("Validation: Standard")
def __validate(self, fp: str) -> bool:
try:
Image.open(fp)
return True
except:
print(f'WARNING: Image cannot be opened: {fp}')
return False
def __extended_validate(self, fp: str) -> bool:
try:
Image.open(fp).load()
return True
except (OSError) as error:
if 'truncated' in str(error):
print(f'WARNING: Image truncated: {error}')
return False
print(f'WARNING: Image cannot be opened: {error}')
return False
except:
print(f'WARNING: Image cannot be opened: {error}')
return False
def __no_op(self, fp: str) -> bool:
return True
class Resize():
def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
if not is_resizing:
self.resize = self.__no_op
return
if not is_not_migrating:
self.resize = self.__migration
dataset_path = os.path.split(args.dataset)
self.__directory = os.path.join(
dataset_path[0],
f'{dataset_path[1]}_cropped'
)
os.makedirs(self.__directory, exist_ok=True)
return print(f"Resizing: Performing migration to '{self.__directory}'.")
self.resize = self.__no_migration
def __no_migration(self, image_path: str, w: int, h: int) -> Img:
return ImageOps.fit(
Image.open(image_path),
(w, h),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
).convert(mode='RGB')
def __migration(self, image_path: str, w: int, h: int) -> Img:
filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1])
image = ImageOps.fit(
Image.open(image_path),
(w, h),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
).convert(mode='RGB')
image.save(
os.path.join(f'{self.__directory}', f'{filename}.jpg'),
optimize=True
)
try:
shutil.copy(
os.path.join(args.dataset, f'{filename}.txt'),
os.path.join(self.__directory, f'{filename}.txt'),
follow_symlinks=False
)
except (FileNotFoundError):
f = open(
os.path.join(self.__directory, f'{filename}.txt'),
'w',
encoding='UTF-8'
)
f.close()
return image
def __no_op(self, image_path: str, w: int, h: int) -> Img:
return Image.open(image_path)
class ImageStore: class ImageStore:
def __init__(self, data_dir: str) -> None: def __init__(self, data_dir: str) -> None:
self.data_dir = data_dir self.data_dir = data_dir
self.image_files = [] self.image_files = []
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']] [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
self.image_files = [x for x in self.image_files if self.__valid_file(x)]
self.validator = Validation(
args.skip_validation,
args.extended_validation
).validate
self.resizer = Resize(args.resize, args.no_migration).resize
self.image_files = [x for x in self.image_files if self.validator(x)]
def __len__(self) -> int: def __len__(self) -> int:
return len(self.image_files) return len(self.image_files)
def __valid_file(self, f) -> bool:
try:
Image.open(f)
return True
except:
print(f'WARNING: Unable to open file: {f}')
return False
# iterator returns images as PIL images and their index in the store # iterator returns images as PIL images and their index in the store
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
for f in range(len(self)): for f in range(len(self)):
yield Image.open(self.image_files[f]).convert(mode='RGB'), f yield Image.open(self.image_files[f]), f
# get image by index # get image by index
def get_image(self, ref: Tuple[int, int, int]) -> Image.Image: def get_image(self, ref: Tuple[int, int, int]) -> Img:
return Image.open(self.image_files[ref[0]]).convert(mode='RGB') return self.resizer(
self.image_files[ref[0]],
ref[1],
ref[2]
)
# gets caption by removing the extension from the filename and replacing it with .txt # gets caption by removing the extension from the filename and replacing it with .txt
def get_caption(self, ref: Tuple[int, int, int]) -> str: def get_caption(self, ref: Tuple[int, int, int]) -> str:
@ -399,15 +509,6 @@ class AspectDataset(torch.utils.data.Dataset):
image_file = self.store.get_image(item) image_file = self.store.get_image(item)
if args.resize:
image_file = ImageOps.fit(
image_file,
(item[1], item[2]),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
)
return_dict['pixel_values'] = self.transforms(image_file) return_dict['pixel_values'] = self.transforms(image_file)
if random.random() > self.ucg: if random.random() > self.ucg:
caption_file = self.store.get_caption(item) caption_file = self.store.get_caption(item)
@ -603,6 +704,7 @@ def main():
beta_end=0.012, beta_end=0.012,
beta_schedule='scaled_linear', beta_schedule='scaled_linear',
num_train_timesteps=1000, num_train_timesteps=1000,
clip_sample=False
) )
# load dataset # load dataset
@ -624,6 +726,13 @@ def main():
collate_fn=dataset.collate_fn collate_fn=dataset.collate_fn
) )
# Migrate dataset
if args.resize and not args.no_migration:
for _, batch in enumerate(train_dataloader):
continue
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
exit(0)
weight_dtype = torch.float16 if args.fp16 else torch.float32 weight_dtype = torch.float16 if args.fp16 else torch.float32
# move models to device # move models to device