From 051116a7d9fd0f0e0556d7a2da2c05bec6d0d954 Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Thu, 29 Dec 2022 21:11:06 -0500
Subject: [PATCH] add gpu id support

---
 data/data_loader.py | 15 ++++++++++++++-
 train.py            | 34 +++++++++++++++++++++++++---------
 windows_setup.cmd   |  2 +-
 3 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index 75b45d4..d2f6834 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -21,6 +21,7 @@ import random
 from data.image_train_item import ImageTrainItem
 import data.aspects as aspects
 from colorama import Fore, Style
+import zipfile
 
 class DataLoaderMultiAspect():
     """
@@ -40,12 +41,24 @@ class DataLoaderMultiAspect():
         logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
         logging.info(" Preloading images...")
 
+        self.unzip_all(data_root)
+
         self.__recurse_data_root(self=self, recurse_root=data_root)
         random.Random(seed).shuffle(self.image_paths)
         prepared_train_data = self.__prescan_images(self.image_paths, flip_p) # ImageTrainItem[]
         self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level)
-        #if debug_level > 0: print(f" * DLMA Example: {self.image_caption_pairs[0]} images")
-
+
+    def unzip_all(self, path):
+        # recursively walk path and extract any .zip archives found
+        try:
+            for root, dirs, files in os.walk(path):
+                for file in files:
+                    if file.endswith('.zip'):
+                        logging.info(f"Unzipping {file}")
+                        with zipfile.ZipFile(os.path.join(root, file), 'r') as zip_ref:
+                            zip_ref.extractall(root)
+        except Exception as e:
+            logging.error(f"Error unzipping files: {e}")
 
     def get_all_images(self):
         return self.image_caption_pairs
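
Note on the hunk above: each archive has to be opened via its full os.path.join(root, file) path and extracted into the directory it was found in, otherwise ZipFile would be handed the data root directory itself. A minimal standalone sketch of the same walk-and-extract pattern, runnable outside the trainer; the 'input' default mirrors the trainer's data_root, and narrowing the catch to zipfile.BadZipFile is a choice of this sketch, not the patch's behavior:

    import logging
    import os
    import zipfile

    def unzip_all(path):
        # walk the tree and extract each .zip next to where it was found
        for root, dirs, files in os.walk(path):
            for file in files:
                if file.endswith('.zip'):
                    archive = os.path.join(root, file)
                    logging.info(f"Unzipping {archive}")
                    try:
                        with zipfile.ZipFile(archive, 'r') as zip_ref:
                            zip_ref.extractall(root)
                    except zipfile.BadZipFile as e:
                        logging.error(f"Skipping bad archive {archive}: {e}")

    if __name__ == '__main__':
        logging.basicConfig(level=logging.INFO)
        unzip_all('input')
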
diff --git a/train.py b/train.py
index 9671ce8..2172289 100644
--- a/train.py
+++ b/train.py
@@ -56,6 +56,19 @@ _GRAD_ACCUM_STEPS = 1 # future use...
 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
 
+# def is_notebook() -> bool:
+#     try:
+#         from IPython import get_ipython
+#         shell = get_ipython().__class__.__name__
+#         if shell == 'ZMQInteractiveShell':
+#             return True   # Jupyter notebook or qtconsole
+#         elif shell == 'TerminalInteractiveShell':
+#             return False  # Terminal running IPython
+#         else:
+#             return False  # Other type (?)
+#     except NameError:
+#         return False      # Probably standard Python interpreter
+
 def clean_filename(filename):
     """
     removes all non-alphanumeric characters from a string so it is safe to use as a filename
@@ -163,10 +176,12 @@ def main(args):
     Main entry point
     """
     log_time = setup_local_logger(args)
+    #notebook = is_notebook()
 
     seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
     set_seed(seed)
     gpu = GPU()
+    device = torch.device(f"cuda:{args.gpuid}")
     torch.backends.cudnn.benchmark = False
 
     args.clip_skip = max(min(3, args.clip_skip), 0)
@@ -197,7 +212,7 @@ def main(args):
         args.lr = args.lr * (total_batch_size**0.5)
         logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
 
-    log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
+    log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
     logging.info(f"Logging to {log_folder}")
     if not os.path.exists(log_folder):
         os.makedirs(log_folder)
@@ -368,9 +383,9 @@ def main(args):
     default_lr = 2e-6
     curr_lr = args.lr if args.lr is not None else default_lr
 
-    vae = vae.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
-    unet = unet.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
-    text_encoder = text_encoder.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
+    vae = vae.to(device, dtype=torch.float32 if not args.amp else torch.float16)
+    unet = unet.to(device, dtype=torch.float32 if not args.amp else torch.float16)
+    text_encoder = text_encoder.to(device, dtype=torch.float32 if not args.amp else torch.float16)
 
     if args.disable_textenc_training:
         logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@@ -550,10 +565,10 @@ def main(args):
     #logging.info(f" {Fore.GREEN}total_batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{total_batch_size}")
     logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
 
-    epoch_pbar = tqdm(range(args.max_epochs), position=0)
+    epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
 
-    steps_pbar = tqdm(range(epoch_len), position=1)
+    steps_pbar = tqdm(range(epoch_len), position=1, leave=True)
     steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
 
     epoch_times = []
@@ -566,7 +581,7 @@ def main(args):
         append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
         torch.cuda.empty_cache()
 
-    loss = torch.tensor(0.0, device=torch.device("cuda"), dtype=torch.float32)
+    loss = torch.tensor(0.0, device=device, dtype=torch.float32)
 
     try:
         for epoch in range(args.max_epochs):
@@ -659,7 +674,7 @@ def main(args):
                 if (global_step + 1) % args.sample_steps == 0:
                     #(unet, text_encoder, tokenizer, scheduler):
                     pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, vae=vae)
-                    pipe = pipe.to(torch.device("cuda"))
+                    pipe = pipe.to(device)
 
                     with torch.no_grad():
                         if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1:
@@ -732,7 +747,7 @@ if __name__ == "__main__":
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
     argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
     argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
-    argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3])
+    argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3, 4])
     argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
     argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
@@ -744,6 +759,7 @@ if __name__ == "__main__":
     argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5")
+    argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
 
     args = argparser.parse_args()
     main(args)
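
The train.py changes all funnel through one idea: build a single torch.device from the new flag, then hand it to every .to(...) call and tensor constructor instead of hardcoding "cuda", which always resolves to device 0. A minimal sketch of that pattern in isolation, with a stand-in nn.Linear in place of the real unet/vae/text_encoder and a CPU fallback the patch itself does not add:

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training")
    args = parser.parse_args()

    # build the device once; pass it everywhere instead of the string "cuda"
    device = torch.device(f"cuda:{args.gpuid}" if torch.cuda.is_available() else "cpu")

    model = torch.nn.Linear(4, 4).to(device)   # stand-in for unet/vae/text_encoder
    loss = torch.tensor(0.0, device=device)    # tensors created on the same device
    print(next(model.parameters()).device, loss.device)
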
diff --git a/windows_setup.cmd b/windows_setup.cmd
index fc145c9..b3ba63e 100644
--- a/windows_setup.cmd
+++ b/windows_setup.cmd
@@ -3,7 +3,7 @@ call "venv\Scripts\activate.bat"
 echo should be in venv here
 cd .
 python -m pip install --upgrade pip
-pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
+pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
 pip install transformers==4.25.1
 pip install diffusers[torch]==0.10.2
 pip install pynvml==11.4.1
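
One caveat the patch leaves implicit: an out-of-range --gpuid only surfaces as a CUDA error once the first module is moved to the device. A fail-fast guard along these lines (a sketch, not part of the patch) could run before the device is built; note that CUDA_VISIBLE_DEVICES renumbers devices, so --gpuid indexes the visible set, not physical slots:

    import torch

    def resolve_device(gpuid: int) -> torch.device:
        # validate the requested id against the devices actually visible
        count = torch.cuda.device_count()
        if not 0 <= gpuid < count:
            raise ValueError(f"--gpuid {gpuid} is invalid: {count} CUDA device(s) visible")
        return torch.device(f"cuda:{gpuid}")
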