Add options and local inference

Added options to:
- Disable Inference (it consumes about 2gb of VRAM even when not active)
- Disable wandb

and:
- if no hftoken is provided it just fills it with nothing so it doesn't argues
- if wandb is not enabled, save the inference outputs to a local folder along with information about it
This commit is contained in:
Carlos Chavez 2022-11-14 22:08:16 -05:00 committed by GitHub
parent ae561d19f7
commit d600078008
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 72 additions and 43 deletions

View File

@ -84,6 +84,8 @@ parser.add_argument('--clip_penultimate', type=bool_t, default='False', help='Us
parser.add_argument('--output_bucket_info', type=bool_t, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=bool_t, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=bool_t, default='False', help='Use memory efficient attention')
parser.add_argument('--wandb', dest='enablewandb', type=str, default='True', help='Enable WeightsAndBiases Reporting')
parser.add_argument('--inference', dest='enableinference', type=str, default='True', help='Enable Inference during training (Consumes 2GB of VRAM)')
args = parser.parse_args()
def setup():
@ -520,7 +522,11 @@ def main():
if rank == 0:
os.makedirs(args.output_path, exist_ok=True)
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
if args.enablewandb:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
else:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
# Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
@ -534,8 +540,12 @@ def main():
print("RESOLUTION:", args.resolution)
if args.hf_token is None:
args.hf_token = os.environ['HF_API_TOKEN']
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
try:
args.hf_token = os.environ['HF_API_TOKEN']
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
except Exception:
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
args.hf_token = "none"
device = torch.device('cuda')
@ -744,49 +754,68 @@ def main():
if global_step % args.save_steps == 0:
save_checkpoint(global_step)
if global_step % args.image_log_steps == 0:
if rank == 0:
# get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
if args.enableinference:
if global_step % args.image_log_steps == 0:
if rank == 0:
# get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=None, # disable safety checker to save memory
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device)
# inference
images = []
with torch.no_grad():
with torch.autocast('cuda', enabled=args.fp16):
for _ in range(args.image_log_amount):
images.append(
wandb.Image(pipeline(
prompt, num_inference_steps=args.image_log_inference_steps
).images[0],
caption=prompt)
)
# log images under single caption
run.log({'images': images}, step=global_step)
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=None, # disable safety checker to save memory
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device)
# inference
if args.enablewandb:
images = []
else:
saveInferencePath = args.output_path + "/inference"
os.makedirs(saveInferencePath, exist_ok=True)
with torch.no_grad():
with torch.autocast('cuda', enabled=args.fp16):
for _ in range(args.image_log_amount):
if args.enablewandb:
images.append(
wandb.Image(pipeline(
prompt, num_inference_steps=args.image_log_inference_steps
).images[0],
caption=prompt)
)
else:
from datetime import datetime
images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0]
filenameImg = str(time.time_ns()) + ".png"
filenameTxt = str(time.time_ns()) + ".txt"
images.save(saveInferencePath + "/" + filenameImg)
with open(saveInferencePath + "/" + filenameTxt, 'a') as f:
f.write('Used prompt: ' + prompt + '\n')
f.write('Generated Image Filename: ' + filenameImg + '\n')
f.write('Generated at: ' + str(global_step) + ' steps' + '\n')
f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n')
# cleanup so we don't run out of memory
del pipeline
gc.collect()
torch.distributed.barrier()
# log images under single caption
if args.enablewandb:
run.log({'images': images}, step=global_step)
# cleanup so we don't run out of memory
del pipeline
gc.collect()
torch.distributed.barrier()
except Exception as e:
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
pass