From d60007800876af7973ee3b40e1a8726d4c5b6cd2 Mon Sep 17 00:00:00 2001
From: Carlos Chavez <85657083+chavinlo@users.noreply.github.com>
Date: Mon, 14 Nov 2022 22:08:16 -0500
Subject: [PATCH 1/4] Add options and local inference

Added options to:
- Disable inference (it consumes about 2 GB of VRAM even when not active)
- Disable wandb

and:
- If no HF token is provided, fill it with the string "none" so the trainer
  doesn't complain
- If wandb is not enabled, save the inference outputs to a local folder along
  with information about them
---
 trainer/diffusers_trainer.py | 115 ++++++++++++++++++++++-------------
 1 file changed, 72 insertions(+), 43 deletions(-)

diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index cc645a7..b5698b4 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -84,6 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
 parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
 parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
+parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting')
+parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
 args = parser.parse_args()
 
 def setup():
@@ -520,7 +522,11 @@ def main():
 
     if rank == 0:
         os.makedirs(args.output_path, exist_ok=True)
-        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
+
+        if args.enablewandb:
+            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
+        else:
+            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
 
         # Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name) @@ -534,8 +540,12 @@ def main(): print("RESOLUTION:", args.resolution) if args.hf_token is None: - args.hf_token = os.environ['HF_API_TOKEN'] - print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') + try: + args.hf_token = os.environ['HF_API_TOKEN'] + print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') + except Exception: + print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)") + args.hf_token = "none" device = torch.device('cuda') @@ -744,49 +754,68 @@ def main(): if global_step % args.save_steps == 0: save_checkpoint(global_step) - if global_step % args.image_log_steps == 0: - if rank == 0: - # get prompt from random batch - prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) + if args.enableinference: + if global_step % args.image_log_steps == 0: + if rank == 0: + # get prompt from random batch + prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) - if args.image_log_scheduler == 'DDIMScheduler': - print('using DDIMScheduler scheduler') - scheduler = DDIMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - else: - print('using PNDMScheduler scheduler') - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ) + if args.image_log_scheduler == 'DDIMScheduler': + print('using DDIMScheduler scheduler') + scheduler = DDIMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + else: + print('using PNDMScheduler scheduler') + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) - pipeline = StableDiffusionPipeline( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - safety_checker=None, # disable safety checker to save memory - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), - ).to(device) - # inference - images = [] - with torch.no_grad(): - with torch.autocast('cuda', enabled=args.fp16): - for _ in range(args.image_log_amount): - images.append( - wandb.Image(pipeline( - prompt, num_inference_steps=args.image_log_inference_steps - ).images[0], - caption=prompt) - ) - # log images under single caption - run.log({'images': images}, step=global_step) + pipeline = StableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=None, # disable safety checker to save memory + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ).to(device) + # inference + if args.enablewandb: + images = [] + else: + saveInferencePath = args.output_path + "/inference" + os.makedirs(saveInferencePath, exist_ok=True) + with torch.no_grad(): + with torch.autocast('cuda', enabled=args.fp16): + for _ in range(args.image_log_amount): + if args.enablewandb: + images.append( + wandb.Image(pipeline( + prompt, num_inference_steps=args.image_log_inference_steps + ).images[0], + caption=prompt) + ) + else: + from datetime import datetime + images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0] + filenameImg = 
str(time.time_ns()) + ".png" + filenameTxt = str(time.time_ns()) + ".txt" + images.save(saveInferencePath + "/" + filenameImg) + with open(saveInferencePath + "/" + filenameTxt, 'a') as f: + f.write('Used prompt: ' + prompt + '\n') + f.write('Generated Image Filename: ' + filenameImg + '\n') + f.write('Generated at: ' + str(global_step) + ' steps' + '\n') + f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n') - # cleanup so we don't run out of memory - del pipeline - gc.collect() - torch.distributed.barrier() + # log images under single caption + if args.enablewandb: + run.log({'images': images}, step=global_step) + + # cleanup so we don't run out of memory + del pipeline + gc.collect() + torch.distributed.barrier() except Exception as e: print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}') pass From 80e24229674a45ad4b18336b4a8b815c893116f3 Mon Sep 17 00:00:00 2001 From: Carlos Chavez <85657083+chavinlo@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:39:20 -0500 Subject: [PATCH 2/4] sync trainer with main branch --- trainer/diffusers_trainer.py | 157 +++++++++++++++++++++++++++++------ 1 file changed, 133 insertions(+), 24 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b5698b4..b99f83c 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -2,8 +2,8 @@ # `nvcc --version` to get CUDA version. # `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA. # Example Usage: -# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True -# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True import argparse import socket @@ -26,6 +26,7 @@ import numpy as np import json import re import traceback +import shutil try: pynvml.nvmlInit() @@ -38,6 +39,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.optimization import get_scheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from PIL import Image, ImageOps +from PIL.Image import Image as Img from typing import Dict, List, Generator, Tuple from scipy.interpolate import interp1d @@ -84,8 +86,12 @@ 
parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') +parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.') +parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with _cropped.') +parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') + args = parser.parse_args() def setup(): @@ -145,33 +151,137 @@ def _sort_by_ratio(bucket: tuple) -> float: def _sort_by_area(bucket: tuple) -> float: return bucket[0] * bucket[1] +class Validation(): + def __init__(self, is_skipped: bool, is_extended: bool) -> None: + if is_skipped: + self.validate = self.__no_op + return print("Validation: Skipped") + + if is_extended: + self.validate = self.__extended_validate + return print("Validation: Extended") + + self.validate = self.__validate + print("Validation: Standard") + + def __validate(self, fp: str) -> bool: + try: + Image.open(fp) + return True + except: + print(f'WARNING: Image cannot be opened: {fp}') + return False + + def __extended_validate(self, fp: str) -> bool: + try: + Image.open(fp).load() + return True + except (OSError) as error: + if 'truncated' in str(error): + print(f'WARNING: Image truncated: {error}') + return False + print(f'WARNING: Image cannot be opened: {error}') + return False + except: + print(f'WARNING: Image cannot be opened: {error}') + return False + + def __no_op(self, fp: str) -> bool: + return True + +class Resize(): + def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None: + if not is_resizing: + self.resize = self.__no_op + return + + if not is_not_migrating: + self.resize = self.__migration + dataset_path = os.path.split(args.dataset) + self.__directory = os.path.join( + dataset_path[0], + f'{dataset_path[1]}_cropped' + ) + os.makedirs(self.__directory, exist_ok=True) + return print(f"Resizing: Performing migration to '{self.__directory}'.") + + self.resize = self.__no_migration + + def __no_migration(self, image_path: str, w: int, h: int) -> Img: + return ImageOps.fit( + Image.open(image_path), + (w, h), + bleed=0.0, + centering=(0.5, 0.5), + method=Image.Resampling.LANCZOS + ).convert(mode='RGB') + + def __migration(self, image_path: str, w: int, h: int) -> Img: + filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1]) + + image = ImageOps.fit( + Image.open(image_path), + (w, h), + bleed=0.0, + centering=(0.5, 0.5), + method=Image.Resampling.LANCZOS + ).convert(mode='RGB') + + image.save( + os.path.join(f'{self.__directory}', f'{filename}.jpg'), + optimize=True + ) + + try: + shutil.copy( + os.path.join(args.dataset, 
f'{filename}.txt'), + os.path.join(self.__directory, f'{filename}.txt'), + follow_symlinks=False + ) + except (FileNotFoundError): + f = open( + os.path.join(self.__directory, f'{filename}.txt'), + 'w', + encoding='UTF-8' + ) + f.close() + + return image + + def __no_op(self, image_path: str, w: int, h: int) -> Img: + return Image.open(image_path) + class ImageStore: def __init__(self, data_dir: str) -> None: self.data_dir = data_dir self.image_files = [] [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']] - self.image_files = [x for x in self.image_files if self.__valid_file(x)] + + self.validator = Validation( + args.skip_validation, + args.extended_validation + ).validate + + self.resizer = Resize(args.resize, args.no_migration).resize + + self.image_files = [x for x in self.image_files if self.validator(x)] def __len__(self) -> int: return len(self.image_files) - def __valid_file(self, f) -> bool: - try: - Image.open(f) - return True - except: - print(f'WARNING: Unable to open file: {f}') - return False - # iterator returns images as PIL images and their index in the store - def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: + def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]: for f in range(len(self)): - yield Image.open(self.image_files[f]).convert(mode='RGB'), f + yield Image.open(self.image_files[f]), f # get image by index - def get_image(self, ref: Tuple[int, int, int]) -> Image.Image: - return Image.open(self.image_files[ref[0]]).convert(mode='RGB') + def get_image(self, ref: Tuple[int, int, int]) -> Img: + return self.resizer( + self.image_files[ref[0]], + ref[1], + ref[2] + ) # gets caption by removing the extension from the filename and replacing it with .txt def get_caption(self, ref: Tuple[int, int, int]) -> str: @@ -399,15 +509,6 @@ class AspectDataset(torch.utils.data.Dataset): image_file = self.store.get_image(item) - if args.resize: - image_file = ImageOps.fit( - image_file, - (item[1], item[2]), - bleed=0.0, - centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS - ) - return_dict['pixel_values'] = self.transforms(image_file) if random.random() > self.ucg: caption_file = self.store.get_caption(item) @@ -599,6 +700,7 @@ def main(): beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000, + clip_sample=False ) # load dataset @@ -620,6 +722,13 @@ def main(): collate_fn=dataset.collate_fn ) + # Migrate dataset + if args.resize and not args.no_migration: + for _, batch in enumerate(train_dataloader): + continue + print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") + exit(0) + weight_dtype = torch.float16 if args.fp16 else torch.float32 # move models to device From fed3431f0319b3177ff6d872766b5d508408d76f Mon Sep 17 00:00:00 2001 From: chavinlo <85657083+chavinlo@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:44:39 -0500 Subject: [PATCH 3/4] Revert "sync trainer with main branch" This reverts commit 80e24229674a45ad4b18336b4a8b815c893116f3. --- trainer/diffusers_trainer.py | 157 ++++++----------------------------- 1 file changed, 24 insertions(+), 133 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b99f83c..b5698b4 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -2,8 +2,8 @@ # `nvcc --version` to get CUDA version. 
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA. # Example Usage: -# Single GPU: torchrun --nproc_per_node=1 trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=1 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True -# Multiple GPUs: torchrun --nproc_per_node=N trainer/diffusers_trainer.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True +# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True import argparse import socket @@ -26,7 +26,6 @@ import numpy as np import json import re import traceback -import shutil try: pynvml.nvmlInit() @@ -39,7 +38,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.optimization import get_scheduler from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from PIL import Image, ImageOps -from PIL.Image import Image as Img from typing import Dict, List, Generator, Tuple from scipy.interpolate import interp1d @@ -86,12 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') -parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.') -parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. 
Migration creates an adjacent folder to the dataset with _cropped.') -parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') - args = parser.parse_args() def setup(): @@ -151,137 +145,33 @@ def _sort_by_ratio(bucket: tuple) -> float: def _sort_by_area(bucket: tuple) -> float: return bucket[0] * bucket[1] -class Validation(): - def __init__(self, is_skipped: bool, is_extended: bool) -> None: - if is_skipped: - self.validate = self.__no_op - return print("Validation: Skipped") - - if is_extended: - self.validate = self.__extended_validate - return print("Validation: Extended") - - self.validate = self.__validate - print("Validation: Standard") - - def __validate(self, fp: str) -> bool: - try: - Image.open(fp) - return True - except: - print(f'WARNING: Image cannot be opened: {fp}') - return False - - def __extended_validate(self, fp: str) -> bool: - try: - Image.open(fp).load() - return True - except (OSError) as error: - if 'truncated' in str(error): - print(f'WARNING: Image truncated: {error}') - return False - print(f'WARNING: Image cannot be opened: {error}') - return False - except: - print(f'WARNING: Image cannot be opened: {error}') - return False - - def __no_op(self, fp: str) -> bool: - return True - -class Resize(): - def __init__(self, is_resizing: bool, is_not_migrating: bool) -> None: - if not is_resizing: - self.resize = self.__no_op - return - - if not is_not_migrating: - self.resize = self.__migration - dataset_path = os.path.split(args.dataset) - self.__directory = os.path.join( - dataset_path[0], - f'{dataset_path[1]}_cropped' - ) - os.makedirs(self.__directory, exist_ok=True) - return print(f"Resizing: Performing migration to '{self.__directory}'.") - - self.resize = self.__no_migration - - def __no_migration(self, image_path: str, w: int, h: int) -> Img: - return ImageOps.fit( - Image.open(image_path), - (w, h), - bleed=0.0, - centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS - ).convert(mode='RGB') - - def __migration(self, image_path: str, w: int, h: int) -> Img: - filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1]) - - image = ImageOps.fit( - Image.open(image_path), - (w, h), - bleed=0.0, - centering=(0.5, 0.5), - method=Image.Resampling.LANCZOS - ).convert(mode='RGB') - - image.save( - os.path.join(f'{self.__directory}', f'{filename}.jpg'), - optimize=True - ) - - try: - shutil.copy( - os.path.join(args.dataset, f'{filename}.txt'), - os.path.join(self.__directory, f'{filename}.txt'), - follow_symlinks=False - ) - except (FileNotFoundError): - f = open( - os.path.join(self.__directory, f'{filename}.txt'), - 'w', - encoding='UTF-8' - ) - f.close() - - return image - - def __no_op(self, image_path: str, w: int, h: int) -> Img: - return Image.open(image_path) - class ImageStore: def __init__(self, data_dir: str) -> None: self.data_dir = data_dir self.image_files = [] [self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' 
+ e)) for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']] - - self.validator = Validation( - args.skip_validation, - args.extended_validation - ).validate - - self.resizer = Resize(args.resize, args.no_migration).resize - - self.image_files = [x for x in self.image_files if self.validator(x)] + self.image_files = [x for x in self.image_files if self.__valid_file(x)] def __len__(self) -> int: return len(self.image_files) + def __valid_file(self, f) -> bool: + try: + Image.open(f) + return True + except: + print(f'WARNING: Unable to open file: {f}') + return False + # iterator returns images as PIL images and their index in the store - def entries_iterator(self) -> Generator[Tuple[Img, int], None, None]: + def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]: for f in range(len(self)): - yield Image.open(self.image_files[f]), f + yield Image.open(self.image_files[f]).convert(mode='RGB'), f # get image by index - def get_image(self, ref: Tuple[int, int, int]) -> Img: - return self.resizer( - self.image_files[ref[0]], - ref[1], - ref[2] - ) + def get_image(self, ref: Tuple[int, int, int]) -> Image.Image: + return Image.open(self.image_files[ref[0]]).convert(mode='RGB') # gets caption by removing the extension from the filename and replacing it with .txt def get_caption(self, ref: Tuple[int, int, int]) -> str: @@ -509,6 +399,15 @@ class AspectDataset(torch.utils.data.Dataset): image_file = self.store.get_image(item) + if args.resize: + image_file = ImageOps.fit( + image_file, + (item[1], item[2]), + bleed=0.0, + centering=(0.5, 0.5), + method=Image.Resampling.LANCZOS + ) + return_dict['pixel_values'] = self.transforms(image_file) if random.random() > self.ucg: caption_file = self.store.get_caption(item) @@ -700,7 +599,6 @@ def main(): beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000, - clip_sample=False ) # load dataset @@ -722,13 +620,6 @@ def main(): collate_fn=dataset.collate_fn ) - # Migrate dataset - if args.resize and not args.no_migration: - for _, batch in enumerate(train_dataloader): - continue - print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") - exit(0) - weight_dtype = torch.float16 if args.fp16 else torch.float32 # move models to device From a2772fc668bd9b9a924efdf8b70dd50d0b163a4e Mon Sep 17 00:00:00 2001 From: chavinlo <85657083+chavinlo@users.noreply.github.com> Date: Wed, 16 Nov 2022 10:55:38 -0500 Subject: [PATCH 4/4] fixes --- trainer/diffusers_trainer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b5698b4..8ed36e4 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -84,8 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits') parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.") parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention') -parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting') -parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)') 
+parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting')
+parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
 args = parser.parse_args()
 
 def setup():
@@ -523,10 +523,11 @@ def main():
 
     if rank == 0:
         os.makedirs(args.output_path, exist_ok=True)
+        mode = 'disabled'
         if args.enablewandb:
-            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
-        else:
-            run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
+            mode = 'online'
+
+        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
 
         # Inform the user of host, and various versions -- useful for debugging issues.
         print("RUN_NAME:", args.run_name)
@@ -539,10 +540,13 @@ def main():
     print("FP16:", args.fp16)
     print("RESOLUTION:", args.resolution)
 
-    if args.hf_token is None:
+
+    if args.hf_token is not None:
+        print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
+    else:
         try:
             args.hf_token = os.environ['HF_API_TOKEN']
-            print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
+            print("HF Token set via environment variable")
         except Exception:
             print("No HF token detected in arguments or environment variable, setting it to none (as a string)")
             args.hf_token = "none"
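
Notes on the series:

On the wandb toggle that patch 4 converges on: wandb.init() accepts the modes
"online", "offline", and "disabled", and a disabled run turns every subsequent
wandb call into a no-op, which is why keeping a single init call and switching
only the mode lets the rest of the logging code stay unconditional. Below is a
minimal standalone sketch of that pattern; the trainer's bool_t parser is not
shown in this diff, so the version here is a guess at its behavior, and the
project and run names are placeholders:

    import argparse
    import wandb

    def bool_t(x: str) -> bool:
        # assumed behavior of the trainer's bool_t helper (not shown in the
        # diff): parse "True"/"False"-style strings into booleans
        return x.lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser()
    parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True')
    args = parser.parse_args()

    mode = 'online' if args.enablewandb else 'disabled'
    run = wandb.init(project='example-project', name='example-run', mode=mode)
    run.log({'loss': 0.0})  # a no-op when mode='disabled'

With this in place, calls like run.log({'images': images}, step=global_step)
in the training loop need no enablewandb guard of their own.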
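
The HF token fallback in patch 1 (kept by patch 4) wraps a direct os.environ
lookup in try/except. An equivalent and arguably clearer formulation uses
os.environ.get with a default; this is a suggestion, not what the patch does:

    import os

    # one-line equivalent of the try/except fallback to the string "none"
    args.hf_token = os.environ.get('HF_API_TOKEN', 'none')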
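
On the local inference outputs added in patch 1: filenameImg and filenameTxt
are built from two separate time.time_ns() calls, so the image and its sidecar
text file get different basenames (the text file still records the correct
image filename in its body). A standalone sketch that reuses a single
timestamp stem for both files could look like the following; save_inference_output
is a hypothetical helper, not a function in the trainer:

    import os
    import time
    from datetime import datetime
    from PIL import Image

    def save_inference_output(image: Image.Image, prompt: str, global_step: int, output_path: str) -> None:
        # hypothetical helper mirroring the local-save branch of the trainer
        save_dir = os.path.join(output_path, 'inference')
        os.makedirs(save_dir, exist_ok=True)
        stem = str(time.time_ns())  # nanosecond timestamp keeps filenames unique
        image.save(os.path.join(save_dir, stem + '.png'))
        with open(os.path.join(save_dir, stem + '.txt'), 'w') as f:
            f.write('Used prompt: ' + prompt + '\n')
            f.write('Generated Image Filename: ' + stem + '.png\n')
            f.write('Generated at: ' + str(global_step) + ' steps\n')
            f.write('Generated on: ' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + '\n')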