parent
80e2422967
commit
fed3431f03
|
@ -2,8 +2,8 @@
|
||||||
# `nvcc --version` to get CUDA version.
|
# `nvcc --version` to get CUDA version.
|
||||||
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
|
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
|
||||||
# Example Usage:
|
# Example Usage:
|
||||||
# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import socket
|
||||||
|
@ -26,7 +26,6 @@ import numpy as np
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
import shutil
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
|
@ -39,7 +38,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from PIL.Image import Image as Img
|
|
||||||
|
|
||||||
from typing import Dict, List, Generator, Tuple
|
from typing import Dict, List, Generator, Tuple
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
@ -86,12 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
|
||||||
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
||||||
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
||||||
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
||||||
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
|
||||||
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
|
||||||
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
|
||||||
parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting')
|
parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting')
|
||||||
parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
|
parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
|
@ -151,137 +145,33 @@ def _sort_by_ratio(bucket: tuple) -> float:
|
||||||
def _sort_by_area(bucket: tuple) -> float:
|
def _sort_by_area(bucket: tuple) -> float:
|
||||||
return bucket[0] * bucket[1]
|
return bucket[0] * bucket[1]
|
||||||
|
|
||||||
class Validation():
|
|
||||||
def __init__(self, is_skipped: bool, is_extended: bool) -> None:
|
|
||||||
if is_skipped:
|
|
||||||
self.validate = self.__no_op
|
|
||||||
return print("Validation: Skipped")
|
|
||||||
|
|
||||||
if is_extended:
|
|
||||||
self.validate = self.__extended_validate
|
|
||||||
return print("Validation: Extended")
|
|
||||||
|
|
||||||
self.validate = self.__validate
|
|
||||||
print("Validation: Standard")
|
|
||||||
|
|
||||||
def __validate(self, fp: str) -> bool:
|
|
||||||
try:
|
|
||||||
Image.open(fp)
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
print(f'WARNING: Image cannot be opened: {fp}')
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __extended_validate(self, fp: str) -> bool:
|
|
||||||
try:
|
|
||||||
Image.open(fp).load()
|
|
||||||
return True
|
|
||||||
except (OSError) as error:
|
|
||||||
if 'truncated' in str(error):
|
|
||||||
print(f'WARNING: Image truncated: {error}')
|
|
||||||
return False
|
|
||||||
print(f'WARNING: Image cannot be opened: {error}')
|
|
||||||
return False
|
|
||||||
except:
|
|
||||||
print(f'WARNING: Image cannot be opened: {error}')
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __no_op(self, fp: str) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
class Resize():
|
|
||||||
def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None:
|
|
||||||
if not is_resizing:
|
|
||||||
self.resize = self.__no_op
|
|
||||||
return
|
|
||||||
|
|
||||||
if not is_not_migrating:
|
|
||||||
self.resize = self.__migration
|
|
||||||
dataset_path = os.path.split(args.dataset)
|
|
||||||
self.__directory = os.path.join(
|
|
||||||
dataset_path[0],
|
|
||||||
f'{dataset_path[1]}_cropped'
|
|
||||||
)
|
|
||||||
os.makedirs(self.__directory, exist_ok=True)
|
|
||||||
return print(f"Resizing: Performing migration to '{self.__directory}'.")
|
|
||||||
|
|
||||||
self.resize = self.__no_migration
|
|
||||||
|
|
||||||
def __no_migration(self, image_path: str, w: int, h: int) -> Img:
|
|
||||||
return ImageOps.fit(
|
|
||||||
Image.open(image_path),
|
|
||||||
(w, h),
|
|
||||||
bleed=0.0,
|
|
||||||
centering=(0.5, 0.5),
|
|
||||||
method=Image.Resampling.LANCZOS
|
|
||||||
).convert(mode='RGB')
|
|
||||||
|
|
||||||
def __migration(self, image_path: str, w: int, h: int) -> Img:
|
|
||||||
filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1])
|
|
||||||
|
|
||||||
image = ImageOps.fit(
|
|
||||||
Image.open(image_path),
|
|
||||||
(w, h),
|
|
||||||
bleed=0.0,
|
|
||||||
centering=(0.5, 0.5),
|
|
||||||
method=Image.Resampling.LANCZOS
|
|
||||||
).convert(mode='RGB')
|
|
||||||
|
|
||||||
image.save(
|
|
||||||
os.path.join(f'{self.__directory}', f'{filename}.jpg'),
|
|
||||||
optimize=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
shutil.copy(
|
|
||||||
os.path.join(args.dataset, f'{filename}.txt'),
|
|
||||||
os.path.join(self.__directory, f'{filename}.txt'),
|
|
||||||
follow_symlinks=False
|
|
||||||
)
|
|
||||||
except (FileNotFoundError):
|
|
||||||
f = open(
|
|
||||||
os.path.join(self.__directory, f'{filename}.txt'),
|
|
||||||
'w',
|
|
||||||
encoding='UTF-8'
|
|
||||||
)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
def __no_op(self, image_path: str, w: int, h: int) -> Img:
|
|
||||||
return Image.open(image_path)
|
|
||||||
|
|
||||||
class ImageStore:
|
class ImageStore:
|
||||||
def __init__(self, data_dir: str) -> None:
|
def __init__(self, data_dir: str) -> None:
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
|
|
||||||
self.image_files = []
|
self.image_files = []
|
||||||
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
|
[self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']]
|
||||||
|
self.image_files = [x for x in self.image_files if self.__valid_file(x)]
|
||||||
self.validator = Validation(
|
|
||||||
args.skip_validation,
|
|
||||||
args.extended_validation
|
|
||||||
).validate
|
|
||||||
|
|
||||||
self.resizer = Resize(args.resize, args.no_migration).resize
|
|
||||||
|
|
||||||
self.image_files = [x for x in self.image_files if self.validator(x)]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.image_files)
|
return len(self.image_files)
|
||||||
|
|
||||||
|
def __valid_file(self, f) -> bool:
|
||||||
|
try:
|
||||||
|
Image.open(f)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
print(f'WARNING: Unable to open file: {f}')
|
||||||
|
return False
|
||||||
|
|
||||||
# iterator returns images as PIL images and their index in the store
|
# iterator returns images as PIL images and their index in the store
|
||||||
def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]:
|
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||||
for f in range(len(self)):
|
for f in range(len(self)):
|
||||||
yield Image.open(self.image_files[f]), f
|
yield Image.open(self.image_files[f]).convert(mode='RGB'), f
|
||||||
|
|
||||||
# get image by index
|
# get image by index
|
||||||
def get_image(self, ref: Tuple[int, int, int]) -> Img:
|
def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
|
||||||
return self.resizer(
|
return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
|
||||||
self.image_files[ref[0]],
|
|
||||||
ref[1],
|
|
||||||
ref[2]
|
|
||||||
)
|
|
||||||
|
|
||||||
# gets caption by removing the extension from the filename and replacing it with .txt
|
# gets caption by removing the extension from the filename and replacing it with .txt
|
||||||
def get_caption(self, ref: Tuple[int, int, int]) -> str:
|
def get_caption(self, ref: Tuple[int, int, int]) -> str:
|
||||||
|
@ -509,6 +399,15 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
image_file = self.store.get_image(item)
|
image_file = self.store.get_image(item)
|
||||||
|
|
||||||
|
if args.resize:
|
||||||
|
image_file = ImageOps.fit(
|
||||||
|
image_file,
|
||||||
|
(item[1], item[2]),
|
||||||
|
bleed=0.0,
|
||||||
|
centering=(0.5, 0.5),
|
||||||
|
method=Image.Resampling.LANCZOS
|
||||||
|
)
|
||||||
|
|
||||||
return_dict['pixel_values'] = self.transforms(image_file)
|
return_dict['pixel_values'] = self.transforms(image_file)
|
||||||
if random.random() > self.ucg:
|
if random.random() > self.ucg:
|
||||||
caption_file = self.store.get_caption(item)
|
caption_file = self.store.get_caption(item)
|
||||||
|
@ -700,7 +599,6 @@ def main():
|
||||||
beta_end=0.012,
|
beta_end=0.012,
|
||||||
beta_schedule='scaled_linear',
|
beta_schedule='scaled_linear',
|
||||||
num_train_timesteps=1000,
|
num_train_timesteps=1000,
|
||||||
clip_sample=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
|
@ -722,13 +620,6 @@ def main():
|
||||||
collate_fn=dataset.collate_fn
|
collate_fn=dataset.collate_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
# Migrate dataset
|
|
||||||
if args.resize and not args.no_migration:
|
|
||||||
for _, batch in enumerate(train_dataloader):
|
|
||||||
continue
|
|
||||||
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||||
|
|
||||||
# move models to device
|
# move models to device
|
||||||
|
|
Loading…
Reference in New Issue