Merge pull request #42 from chavinlo/inference-option
Add options and local inference
This commit is contained in:
commit
c8eeaaf353
|
@ -86,6 +86,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
|
||||||
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
|
||||||
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
|
||||||
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
|
||||||
|
parser.add_argument('--wandb', dest='enablewandb', type=bool_t, default='True', help='Enable WeightsAndBiases Reporting')
|
||||||
|
parser.add_argument('--inference', dest='enableinference', type=bool_t, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
|
||||||
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
||||||
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
||||||
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
||||||
|
@ -621,7 +623,12 @@ def main():
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
|
|
||||||
|
mode = 'enabled'
|
||||||
|
if args.enablewandb:
|
||||||
|
mode = 'disabled'
|
||||||
|
|
||||||
|
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode=mode)
|
||||||
|
|
||||||
# Inform the user of host, and various versions -- useful for debugging issues.
|
# Inform the user of host, and various versions -- useful for debugging issues.
|
||||||
print("RUN_NAME:", args.run_name)
|
print("RUN_NAME:", args.run_name)
|
||||||
|
@ -634,9 +641,16 @@ def main():
|
||||||
print("FP16:", args.fp16)
|
print("FP16:", args.fp16)
|
||||||
print("RESOLUTION:", args.resolution)
|
print("RESOLUTION:", args.resolution)
|
||||||
|
|
||||||
if args.hf_token is None:
|
|
||||||
args.hf_token = os.environ['HF_API_TOKEN']
|
if args.hf_token is not None:
|
||||||
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
|
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
args.hf_token = os.environ['HF_API_TOKEN']
|
||||||
|
print("HF Token set via enviroment variable")
|
||||||
|
except Exception:
|
||||||
|
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
|
||||||
|
args.hf_token = "none"
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
|
|
||||||
|
@ -853,49 +867,68 @@ def main():
|
||||||
if global_step % args.save_steps == 0:
|
if global_step % args.save_steps == 0:
|
||||||
save_checkpoint(global_step)
|
save_checkpoint(global_step)
|
||||||
|
|
||||||
if global_step % args.image_log_steps == 0:
|
if args.enableinference:
|
||||||
if rank == 0:
|
if global_step % args.image_log_steps == 0:
|
||||||
# get prompt from random batch
|
if rank == 0:
|
||||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
# get prompt from random batch
|
||||||
|
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
||||||
|
|
||||||
if args.image_log_scheduler == 'DDIMScheduler':
|
if args.image_log_scheduler == 'DDIMScheduler':
|
||||||
print('using DDIMScheduler scheduler')
|
print('using DDIMScheduler scheduler')
|
||||||
scheduler = DDIMScheduler(
|
scheduler = DDIMScheduler(
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print('using PNDMScheduler scheduler')
|
print('using PNDMScheduler scheduler')
|
||||||
scheduler=PNDMScheduler(
|
scheduler=PNDMScheduler(
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionPipeline(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=None, # disable safety checker to save memory
|
safety_checker=None, # disable safety checker to save memory
|
||||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
).to(device)
|
).to(device)
|
||||||
# inference
|
# inference
|
||||||
images = []
|
if args.enablewandb:
|
||||||
with torch.no_grad():
|
images = []
|
||||||
with torch.autocast('cuda', enabled=args.fp16):
|
else:
|
||||||
for _ in range(args.image_log_amount):
|
saveInferencePath = args.output_path + "/inference"
|
||||||
images.append(
|
os.makedirs(saveInferencePath, exist_ok=True)
|
||||||
wandb.Image(pipeline(
|
with torch.no_grad():
|
||||||
prompt, num_inference_steps=args.image_log_inference_steps
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
).images[0],
|
for _ in range(args.image_log_amount):
|
||||||
caption=prompt)
|
if args.enablewandb:
|
||||||
)
|
images.append(
|
||||||
# log images under single caption
|
wandb.Image(pipeline(
|
||||||
run.log({'images': images}, step=global_step)
|
prompt, num_inference_steps=args.image_log_inference_steps
|
||||||
|
).images[0],
|
||||||
|
caption=prompt)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from datetime import datetime
|
||||||
|
images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0]
|
||||||
|
filenameImg = str(time.time_ns()) + ".png"
|
||||||
|
filenameTxt = str(time.time_ns()) + ".txt"
|
||||||
|
images.save(saveInferencePath + "/" + filenameImg)
|
||||||
|
with open(saveInferencePath + "/" + filenameTxt, 'a') as f:
|
||||||
|
f.write('Used prompt: ' + prompt + '\n')
|
||||||
|
f.write('Generated Image Filename: ' + filenameImg + '\n')
|
||||||
|
f.write('Generated at: ' + str(global_step) + ' steps' + '\n')
|
||||||
|
f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n')
|
||||||
|
|
||||||
# cleanup so we don't run out of memory
|
# log images under single caption
|
||||||
del pipeline
|
if args.enablewandb:
|
||||||
gc.collect()
|
run.log({'images': images}, step=global_step)
|
||||||
torch.distributed.barrier()
|
|
||||||
|
# cleanup so we don't run out of memory
|
||||||
|
del pipeline
|
||||||
|
gc.collect()
|
||||||
|
torch.distributed.barrier()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
|
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue