Extended Mode Updates

This commit is contained in:
cafeai 2022-12-01 04:31:20 +09:00
parent b0cec788be
commit ee281badcd
1 changed files with 64 additions and 12 deletions

View File

@ -91,6 +91,7 @@ parser.add_argument('--inference', dest='enableinference', type=bool_t, default=
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
args = parser.parse_args()
@ -490,9 +491,11 @@ class AspectBucketSampler(torch.utils.data.Sampler):
return self.bucket.get_batch_count() // self.num_replicas
class AspectDataset(torch.utils.data.Dataset):
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1):
self.store = store
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.device = device
self.ucg = ucg
self.transforms = torchvision.transforms.Compose([
@ -514,19 +517,67 @@ class AspectDataset(torch.utils.data.Dataset):
caption_file = self.store.get_caption(item)
else:
caption_file = ''
return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
return_dict['input_ids'] = caption_file
return return_dict
def collate_fn(self, examples):
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = [example['input_ids'] for example in examples if example is not None]
padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
if args.extended_mode_chunks < 2:
max_length = self.tokenizer.model_max_length - 2
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None]
else:
max_length = self.tokenizer.model_max_length
max_chunks = args.extended_mode_chunks
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
if args.extended_mode_chunks < 2:
for i, x in enumerate(input_ids):
for j, y in enumerate(x):
input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)]
if args.clip_penultimate:
input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids]
else:
input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids]
else:
max_standard_tokens = max_length - 2
max_chunks = args.extended_mode_chunks
max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens
if max_len > max_standard_tokens:
z = None
for i, x in enumerate(input_ids):
if len(x) < max_len:
input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
batch_t = torch.tensor(input_ids)
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
for chunk in chunks:
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
if z is None:
if args.clip_penultimate:
z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])
else:
z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state
else:
if args.clip_penultimate:
z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2)
else:
z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2)
input_ids = z
else:
for i, x in enumerate(input_ids):
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
if args.clip_penultimate:
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
else:
input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state
input_ids = torch.stack(tuple(input_ids))
return {
'pixel_values': pixel_values,
'input_ids': padded_tokens.input_ids,
'attention_mask': padded_tokens.attention_mask,
'input_ids': input_ids,
}
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@ -719,7 +770,7 @@ def main():
# load dataset
store = ImageStore(args.dataset)
dataset = AspectDataset(store, tokenizer, ucg=args.ucg)
dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=args.ucg)
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
@ -817,11 +868,12 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
if args.clip_penultimate:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
#encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
encoder_hidden_states = batch['input_ids']
#if args.clip_penultimate:
# encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
#else:
# encoder_hidden_states = encoder_hidden_states.last_hidden_state
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):