commit
c19f450005
|
@ -91,6 +91,7 @@ parser.add_argument('--inference', dest='enableinference', type=bool_t, default=
|
||||||
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
||||||
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
||||||
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
||||||
|
parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -490,9 +491,11 @@ class AspectBucketSampler(torch.utils.data.Sampler):
|
||||||
return self.bucket.get_batch_count() // self.num_replicas
|
return self.bucket.get_batch_count() // self.num_replicas
|
||||||
|
|
||||||
class AspectDataset(torch.utils.data.Dataset):
|
class AspectDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1):
|
||||||
self.store = store
|
self.store = store
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.device = device
|
||||||
self.ucg = ucg
|
self.ucg = ucg
|
||||||
|
|
||||||
self.transforms = torchvision.transforms.Compose([
|
self.transforms = torchvision.transforms.Compose([
|
||||||
|
@ -514,19 +517,70 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
caption_file = self.store.get_caption(item)
|
caption_file = self.store.get_caption(item)
|
||||||
else:
|
else:
|
||||||
caption_file = ''
|
caption_file = ''
|
||||||
return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
|
|
||||||
|
|
||||||
|
return_dict['input_ids'] = caption_file
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
def collate_fn(self, examples):
|
def collate_fn(self, examples):
|
||||||
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
|
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
|
||||||
pixel_values.to(memory_format=torch.contiguous_format).float()
|
pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||||
input_ids = [example['input_ids'] for example in examples if example is not None]
|
|
||||||
padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
|
if args.extended_mode_chunks < 2:
|
||||||
|
max_length = self.tokenizer.model_max_length - 2
|
||||||
|
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None]
|
||||||
|
else:
|
||||||
|
max_length = self.tokenizer.model_max_length
|
||||||
|
max_chunks = args.extended_mode_chunks
|
||||||
|
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
|
||||||
|
|
||||||
|
tokens = input_ids
|
||||||
|
|
||||||
|
if args.extended_mode_chunks < 2:
|
||||||
|
for i, x in enumerate(input_ids):
|
||||||
|
for j, y in enumerate(x):
|
||||||
|
input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)]
|
||||||
|
|
||||||
|
if args.clip_penultimate:
|
||||||
|
input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids]
|
||||||
|
else:
|
||||||
|
input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids]
|
||||||
|
else:
|
||||||
|
max_standard_tokens = max_length - 2
|
||||||
|
max_chunks = args.extended_mode_chunks
|
||||||
|
max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens
|
||||||
|
if max_len > max_standard_tokens:
|
||||||
|
z = None
|
||||||
|
for i, x in enumerate(input_ids):
|
||||||
|
if len(x) < max_len:
|
||||||
|
input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
|
||||||
|
batch_t = torch.tensor(input_ids)
|
||||||
|
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
|
||||||
|
if z is None:
|
||||||
|
if args.clip_penultimate:
|
||||||
|
z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
||||||
|
else:
|
||||||
|
z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state
|
||||||
|
else:
|
||||||
|
if args.clip_penultimate:
|
||||||
|
z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2)
|
||||||
|
else:
|
||||||
|
z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2)
|
||||||
|
input_ids = z
|
||||||
|
else:
|
||||||
|
for i, x in enumerate(input_ids):
|
||||||
|
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
||||||
|
if args.clip_penultimate:
|
||||||
|
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
||||||
|
else:
|
||||||
|
input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state
|
||||||
|
input_ids = torch.stack(tuple(input_ids))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'pixel_values': pixel_values,
|
'pixel_values': pixel_values,
|
||||||
'input_ids': padded_tokens.input_ids,
|
'input_ids': input_ids,
|
||||||
'attention_mask': padded_tokens.attention_mask,
|
'tokens': tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||||
|
@ -739,7 +793,7 @@ def main():
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
store = ImageStore(args.dataset)
|
store = ImageStore(args.dataset)
|
||||||
dataset = AspectDataset(store, tokenizer, ucg=args.ucg)
|
dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=args.ucg)
|
||||||
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||||
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
||||||
|
|
||||||
|
@ -832,12 +886,8 @@ def main():
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the embedding for conditioning
|
||||||
encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
|
encoder_hidden_states = batch['input_ids']
|
||||||
if args.clip_penultimate:
|
|
||||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
|
|
||||||
else:
|
|
||||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
|
||||||
|
|
||||||
if noise_scheduler.config.prediction_type == "epsilon":
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
target = noise
|
target = noise
|
||||||
|
@ -898,7 +948,7 @@ def main():
|
||||||
if global_step % args.image_log_steps == 0 and global_step > 0:
|
if global_step % args.image_log_steps == 0 and global_step > 0:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
# get prompt from random batch
|
# get prompt from random batch
|
||||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)])
|
||||||
|
|
||||||
if args.image_log_scheduler == 'DDIMScheduler':
|
if args.image_log_scheduler == 'DDIMScheduler':
|
||||||
print('using DDIMScheduler scheduler')
|
print('using DDIMScheduler scheduler')
|
||||||
|
|
Loading…
Reference in New Issue