diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index c12dcce..bdc712b 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -86,6 +86,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
 parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
 parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
+parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable Weights & Biases reporting')
+parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable inference during training (consumes 2GB of VRAM)')
 parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
 parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with _cropped.')
 parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
@@ -621,7 +623,12 @@ def main():
 
     if rank == 0:
         os.makedirs(args.output_path, exist_ok=True)
-        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
+
+        mode = 'disabled'
+        if args.enablewandb:
+            mode = 'online'
+
+        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
 
         # Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name) @@ -634,9 +641,16 @@ def main(): print("FP16:", args.fp16) print("RESOLUTION:", args.resolution) - if args.hf_token is None: - args.hf_token = os.environ['HF_API_TOKEN'] + + if args.hf_token is not None: print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') + else: + try: + args.hf_token = os.environ['HF_API_TOKEN'] + print("HF Token set via enviroment variable") + except Exception: + print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)") + args.hf_token = "none" device = torch.device('cuda') @@ -853,49 +867,68 @@ def main(): if global_step % args.save_steps == 0: save_checkpoint(global_step) - if global_step % args.image_log_steps == 0: - if rank == 0: - # get prompt from random batch - prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) + if args.enableinference: + if global_step % args.image_log_steps == 0: + if rank == 0: + # get prompt from random batch + prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) - if args.image_log_scheduler == 'DDIMScheduler': - print('using DDIMScheduler scheduler') - scheduler = DDIMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) - else: - print('using PNDMScheduler scheduler') - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ) + if args.image_log_scheduler == 'DDIMScheduler': + print('using DDIMScheduler scheduler') + scheduler = DDIMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + else: + print('using PNDMScheduler scheduler') + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) - pipeline = StableDiffusionPipeline( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - safety_checker=None, # disable safety checker to save memory - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), - ).to(device) - # inference - images = [] - with torch.no_grad(): - with torch.autocast('cuda', enabled=args.fp16): - for _ in range(args.image_log_amount): - images.append( - wandb.Image(pipeline( - prompt, num_inference_steps=args.image_log_inference_steps - ).images[0], - caption=prompt) - ) - # log images under single caption - run.log({'images': images}, step=global_step) + pipeline = StableDiffusionPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + safety_checker=None, # disable safety checker to save memory + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ).to(device) + # inference + if args.enablewandb: + images = [] + else: + saveInferencePath = args.output_path + "/inference" + os.makedirs(saveInferencePath, exist_ok=True) + with torch.no_grad(): + with torch.autocast('cuda', enabled=args.fp16): + for _ in range(args.image_log_amount): + if args.enablewandb: + images.append( + wandb.Image(pipeline( + prompt, num_inference_steps=args.image_log_inference_steps + ).images[0], + caption=prompt) + ) + else: + from datetime import datetime + images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0] + filenameImg = str(time.time_ns()) + ".png" + filenameTxt = 
str(time.time_ns()) + ".txt" + images.save(saveInferencePath + "/" + filenameImg) + with open(saveInferencePath + "/" + filenameTxt, 'a') as f: + f.write('Used prompt: ' + prompt + '\n') + f.write('Generated Image Filename: ' + filenameImg + '\n') + f.write('Generated at: ' + str(global_step) + ' steps' + '\n') + f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n') - # cleanup so we don't run out of memory - del pipeline - gc.collect() - torch.distributed.barrier() + # log images under single caption + if args.enablewandb: + run.log({'images': images}, step=global_step) + + # cleanup so we don't run out of memory + del pipeline + gc.collect() + torch.distributed.barrier() except Exception as e: print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}') pass
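
A quick way to exercise the new flags after applying this patch (a sketch, not part of the diff: the launcher and the --model/--dataset/--run_name flags are assumptions based on the trainer's existing argument parser, only --wandb and --inference are introduced here, both bool_t values defaulting to 'True'):

    torchrun --nproc_per_node=1 trainer/diffusers_trainer.py \
        --model=<model> --dataset=<dataset> --run_name=smoke-test \
        --wandb=False --inference=True

With --wandb=False the run is initialized with mode='disabled', and inference previews are written to <output_path>/inference as paired PNG/TXT files instead of being logged to Weights & Biases.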