Provide Tokens for Inference

cafeai 2022-12-01 04:46:01 +09:00
parent 981c6ca41a
commit fb75cbe029
1 changed file with 4 additions and 1 deletion


@@ -533,6 +533,8 @@ class AspectDataset(torch.utils.data.Dataset):
             max_chunks = args.extended_mode_chunks
             input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
+        tokens = input_ids
+
         if args.extended_mode_chunks < 2:
             for i, x in enumerate(input_ids):
                 for j, y in enumerate(x):
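After this point the collate function rebuilds input_ids into the form the text encoder consumes (special tokens, padding, chunking), which no longer decodes back into a clean caption; the new tokens binding keeps the raw tokenizer output so a readable prompt can be recovered later. A minimal sketch of the difference, not the trainer's code, assuming a Hugging Face CLIPTokenizer (model name and caption are illustrative, not from this repository):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
caption = "a watercolor painting of a lighthouse"

# Raw ids, as kept in `tokens`, round-trip cleanly through decode():
raw = tokenizer([caption], add_special_tokens=False).input_ids[0]
print(tokenizer.decode(raw))    # -> a watercolor painting of a lighthouse

# Encoder-ready ids carry special/pad tokens, so decoding them is noisy:
ready = tokenizer([caption], padding="max_length", max_length=77).input_ids[0]
print(tokenizer.decode(ready))  # -> <|startoftext|>a watercolor ... <|endoftext|><|endoftext|> ...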
@@ -578,6 +580,7 @@ class AspectDataset(torch.utils.data.Dataset):
         return {
             'pixel_values': pixel_values,
             'input_ids': input_ids,
+            'tokens': tokens
         }
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
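The second hunk simply adds the new key to the dict the collate function already returns, so tokens rides along with each batch. A self-contained sketch of the pattern with illustrative names (only the 'input_ids' and 'tokens' keys come from the diff):

def collate(examples, tokenizer, max_length=77):
    captions = [e['caption'] for e in examples]
    # Raw per-caption id lists, kept verbatim for later decoding:
    tokens = tokenizer(captions, add_special_tokens=False).input_ids
    # Padded tensor batch that the text encoder actually consumes:
    enc = tokenizer(captions, padding="max_length", truncation=True,
                    max_length=max_length, return_tensors="pt")
    return {
        'input_ids': enc.input_ids,
        'tokens': tokens,
    }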
@@ -929,7 +932,7 @@ def main():
             if global_step % args.image_log_steps == 0:
                 if rank == 0:
                     # get prompt from random batch
-                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                    prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)].tolist())
                     if args.image_log_scheduler == 'DDIMScheduler':
                         print('using DDIMScheduler scheduler')
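On the logging side, the third hunk now draws one caption's raw ids at random and decodes them back into the prompt used to render a sample image. A short sketch of the same step, with batch and tokenizer standing in for the trainer's objects; note that random.randint(a, b) is inclusive on both ends, hence the -1, and the .tolist() in the diff suggests the stored entries arrive as tensors:

import random

idx = random.randint(0, len(batch['tokens']) - 1)      # inclusive bounds; same as random.randrange(len(...))
ids = batch['tokens'][idx]
ids = ids.tolist() if hasattr(ids, 'tolist') else ids  # tensors need converting before decode()
prompt = tokenizer.decode(ids)
print(f'sampling image for prompt: {prompt}')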