From fb75cbe0291c411ae4a1254c300fb41f247667dc Mon Sep 17 00:00:00 2001
From: cafeai <116491182+cafeai@users.noreply.github.com>
Date: Thu, 1 Dec 2022 04:46:01 +0900
Subject: [PATCH] Provide Tokens for Inference

---
 trainer/diffusers_trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index 7463044..f65f3cb 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -533,6 +533,8 @@ class AspectDataset(torch.utils.data.Dataset):
             max_chunks = args.extended_mode_chunks
             input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
 
+        tokens = input_ids
+
         if args.extended_mode_chunks < 2:
             for i, x in enumerate(input_ids):
                 for j, y in enumerate(x):
@@ -578,6 +580,7 @@ class AspectDataset(torch.utils.data.Dataset):
         return {
             'pixel_values': pixel_values,
             'input_ids': input_ids,
+            'tokens': tokens
         }
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@@ -929,7 +932,7 @@ def main():
             if global_step % args.image_log_steps == 0:
                 if rank == 0:
                     # get prompt from random batch
-                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                    prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)].tolist())
 
                     if args.image_log_scheduler == 'DDIMScheduler':
                         print('using DDIMScheduler scheduler')
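
A note on the motivation, as the hunks above suggest it: the collate step reshapes `input_ids` on its way to the text encoder (padding in the normal path, chunking into multiple max-length sequences when `extended_mode_chunks >= 2`), so decoding `batch['input_ids']` for the image-log prompt can pick up padding noise or stop corresponding to a single caption at all. The patch therefore snapshots the raw per-caption ids as `tokens` before that reshaping, returns them from the dataset, and decodes those instead. Below is a minimal sketch of the effect, not part of the patch, assuming the stock Hugging Face `CLIPTokenizer`; the caption string and variable names are illustrative, not taken from the trainer.

# Editor's sketch, not part of the patch: raw token ids decode cleanly,
# while the padded ids the model consumes pick up trailing special tokens.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

caption = "a watercolor painting of a lighthouse at dusk"  # made-up example

# What the dataset returns under 'tokens': plain, unpadded token ids.
tokens = tokenizer([caption], truncation=True, padding=False,
                   add_special_tokens=False,
                   max_length=tokenizer.model_max_length - 2).input_ids[0]

# What 'input_ids' looks like after collation in the normal path:
# every sequence padded out to the model's fixed context length.
padded = tokenizer.pad({"input_ids": [tokens]}, padding="max_length",
                       max_length=tokenizer.model_max_length).input_ids[0]

print(tokenizer.decode(tokens))  # the caption text
print(tokenizer.decode(padded))  # the caption plus a run of <|endoftext|> padding

In extended mode the gap is wider still: each entry of `input_ids` becomes several stacked max-length chunks, so there is no single sequence to hand to `tokenizer.decode`, while `batch['tokens']` keeps one flat id list per caption.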