Provide Tokens for Inference
This commit is contained in:
parent
981c6ca41a
commit
fb75cbe029
|
@ -533,6 +533,8 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
max_chunks = args.extended_mode_chunks
|
max_chunks = args.extended_mode_chunks
|
||||||
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
|
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
|
||||||
|
|
||||||
|
tokens = input_ids
|
||||||
|
|
||||||
if args.extended_mode_chunks < 2:
|
if args.extended_mode_chunks < 2:
|
||||||
for i, x in enumerate(input_ids):
|
for i, x in enumerate(input_ids):
|
||||||
for j, y in enumerate(x):
|
for j, y in enumerate(x):
|
||||||
|
@ -578,6 +580,7 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
return {
|
return {
|
||||||
'pixel_values': pixel_values,
|
'pixel_values': pixel_values,
|
||||||
'input_ids': input_ids,
|
'input_ids': input_ids,
|
||||||
|
'tokens': tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||||
|
@ -929,7 +932,7 @@ def main():
|
||||||
if global_step % args.image_log_steps == 0:
|
if global_step % args.image_log_steps == 0:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
# get prompt from random batch
|
# get prompt from random batch
|
||||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)].tolist())
|
||||||
|
|
||||||
if args.image_log_scheduler == 'DDIMScheduler':
|
if args.image_log_scheduler == 'DDIMScheduler':
|
||||||
print('using DDIMScheduler scheduler')
|
print('using DDIMScheduler scheduler')
|
||||||
|
|
Loading…
Reference in New Issue