diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 2b044bd..b36b09c 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -91,6 +91,7 @@ parser.add_argument('--inference', dest='enableinference', type=bool_t, default= parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.') parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with _cropped.') parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') +parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.') args = parser.parse_args() @@ -490,9 +491,11 @@ class AspectBucketSampler(torch.utils.data.Sampler): return self.bucket.get_batch_count() // self.num_replicas class AspectDataset(torch.utils.data.Dataset): - def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1): + def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1): self.store = store self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.device = device self.ucg = ucg self.transforms = torchvision.transforms.Compose([ @@ -514,19 +517,70 @@ class AspectDataset(torch.utils.data.Dataset): caption_file = self.store.get_caption(item) else: caption_file = '' - return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids + return_dict['input_ids'] = caption_file return return_dict def collate_fn(self, examples): pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None]) pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = [example['input_ids'] for example in examples if example is not None] - padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True) + + if args.extended_mode_chunks < 2: + max_length = self.tokenizer.model_max_length - 2 + input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None] + else: + max_length = self.tokenizer.model_max_length + max_chunks = args.extended_mode_chunks + input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None] + + tokens = input_ids + + if args.extended_mode_chunks < 2: + for i, x in enumerate(input_ids): + for j, y in enumerate(x): + input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)] + + if args.clip_penultimate: + input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids] + else: + input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids] + else: + max_standard_tokens = max_length - 2 + max_chunks = args.extended_mode_chunks + max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens + if max_len > max_standard_tokens: + z = None + for i, x in enumerate(input_ids): + if len(x) < max_len: + input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)] + batch_t = torch.tensor(input_ids) + chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)] + for chunk in chunks: + chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1) + if z is None: + if args.clip_penultimate: + z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2]) + else: + z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state + else: + if args.clip_penultimate: + z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2) + else: + z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2) + input_ids = z + else: + for i, x in enumerate(input_ids): + input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)] + if args.clip_penultimate: + input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2]) + else: + input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state + input_ids = torch.stack(tuple(input_ids)) + return { 'pixel_values': pixel_values, - 'input_ids': padded_tokens.input_ids, - 'attention_mask': padded_tokens.attention_mask, + 'input_ids': input_ids, + 'tokens': tokens } # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 @@ -739,7 +793,7 @@ def main(): # load dataset store = ImageStore(args.dataset) - dataset = AspectDataset(store, tokenizer, ucg=args.ucg) + dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=args.ucg) bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0) sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank) @@ -832,12 +886,8 @@ def main(): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True) - if args.clip_penultimate: - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2]) - else: - encoder_hidden_states = encoder_hidden_states.last_hidden_state + # Get the embedding for conditioning + encoder_hidden_states = batch['input_ids'] if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -898,7 +948,7 @@ def main(): if global_step % args.image_log_steps == 0 and global_step > 0: if rank == 0: # get prompt from random batch - prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) + prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)]) if args.image_log_scheduler == 'DDIMScheduler': print('using DDIMScheduler scheduler')