Extended Mode Updates
This commit is contained in:
parent
b0cec788be
commit
ee281badcd
|
@ -91,6 +91,7 @@ parser.add_argument('--inference', dest='enableinference', type=bool_t, default=
|
|||
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
|
||||
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
|
||||
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
|
||||
parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -490,9 +491,11 @@ class AspectBucketSampler(torch.utils.data.Sampler):
|
|||
return self.bucket.get_batch_count() // self.num_replicas
|
||||
|
||||
class AspectDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
|
||||
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1):
|
||||
self.store = store
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.device = device
|
||||
self.ucg = ucg
|
||||
|
||||
self.transforms = torchvision.transforms.Compose([
|
||||
|
@ -514,19 +517,67 @@ class AspectDataset(torch.utils.data.Dataset):
|
|||
caption_file = self.store.get_caption(item)
|
||||
else:
|
||||
caption_file = ''
|
||||
return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
|
||||
|
||||
return_dict['input_ids'] = caption_file
|
||||
return return_dict
|
||||
|
||||
def collate_fn(self, examples):
|
||||
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
|
||||
pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
input_ids = [example['input_ids'] for example in examples if example is not None]
|
||||
padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)
|
||||
|
||||
if args.extended_mode_chunks < 2:
|
||||
max_length = self.tokenizer.model_max_length - 2
|
||||
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None]
|
||||
else:
|
||||
max_length = self.tokenizer.model_max_length
|
||||
max_chunks = args.extended_mode_chunks
|
||||
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
|
||||
|
||||
if args.extended_mode_chunks < 2:
|
||||
for i, x in enumerate(input_ids):
|
||||
for j, y in enumerate(x):
|
||||
input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)]
|
||||
|
||||
if args.clip_penultimate:
|
||||
input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids]
|
||||
else:
|
||||
input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids]
|
||||
else:
|
||||
max_standard_tokens = max_length - 2
|
||||
max_chunks = args.extended_mode_chunks
|
||||
max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens
|
||||
if max_len > max_standard_tokens:
|
||||
z = None
|
||||
for i, x in enumerate(input_ids):
|
||||
if len(x) < max_len:
|
||||
input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
|
||||
batch_t = torch.tensor(input_ids)
|
||||
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
|
||||
for chunk in chunks:
|
||||
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
|
||||
if z is None:
|
||||
if args.clip_penultimate:
|
||||
z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
||||
else:
|
||||
z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state
|
||||
else:
|
||||
if args.clip_penultimate:
|
||||
z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2)
|
||||
else:
|
||||
z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2)
|
||||
input_ids = z
|
||||
else:
|
||||
for i, x in enumerate(input_ids):
|
||||
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
||||
if args.clip_penultimate:
|
||||
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
||||
else:
|
||||
input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state
|
||||
input_ids = torch.stack(tuple(input_ids))
|
||||
|
||||
return {
|
||||
'pixel_values': pixel_values,
|
||||
'input_ids': padded_tokens.input_ids,
|
||||
'attention_mask': padded_tokens.attention_mask,
|
||||
'input_ids': input_ids,
|
||||
}
|
||||
|
||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||
|
@ -719,7 +770,7 @@ def main():
|
|||
|
||||
# load dataset
|
||||
store = ImageStore(args.dataset)
|
||||
dataset = AspectDataset(store, tokenizer, ucg=args.ucg)
|
||||
dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=args.ucg)
|
||||
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
|
||||
|
||||
|
@ -817,11 +868,12 @@ def main():
|
|||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
|
||||
if args.clip_penultimate:
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
#encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
|
||||
encoder_hidden_states = batch['input_ids']
|
||||
#if args.clip_penultimate:
|
||||
# encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
|
||||
#else:
|
||||
# encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
with torch.autocast('cuda', enabled=args.fp16):
|
||||
|
|
Loading…
Reference in New Issue