Merge pull request #54 from harubaru/extended-mode

Adding Extended Mode Functionality
This commit is contained in:
Anthony Mercurio 2022-11-30 22:55:30 -07:00 committed by GitHub
commit 138cb7bbed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 64 additions and 14 deletions

View File

@ -91,6 +91,7 @@ parser.add_argument('--inference', dest='enableinference', type=bool_t, default=
parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.') parser.add_argument('--extended_validation', type=bool_t, default='False', help='Perform extended validation of images to catch truncated or corrupt images.')
parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.') parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.') parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
args = parser.parse_args() args = parser.parse_args()
@ -490,9 +491,11 @@ class AspectBucketSampler(torch.utils.data.Sampler):
return self.bucket.get_batch_count() // self.num_replicas return self.bucket.get_batch_count() // self.num_replicas
class AspectDataset(torch.utils.data.Dataset): class AspectDataset(torch.utils.data.Dataset):
def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1): def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, ucg: float = 0.1):
self.store = store self.store = store
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.device = device
self.ucg = ucg self.ucg = ucg
self.transforms = torchvision.transforms.Compose([ self.transforms = torchvision.transforms.Compose([
@ -514,19 +517,70 @@ class AspectDataset(torch.utils.data.Dataset):
caption_file = self.store.get_caption(item) caption_file = self.store.get_caption(item)
else: else:
caption_file = '' caption_file = ''
return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
return_dict['input_ids'] = caption_file
return return_dict return return_dict
def collate_fn(self, examples): def collate_fn(self, examples):
pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None]) pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
pixel_values.to(memory_format=torch.contiguous_format).float() pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = [example['input_ids'] for example in examples if example is not None]
padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True) if args.extended_mode_chunks < 2:
max_length = self.tokenizer.model_max_length - 2
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=max_length).input_ids for example in examples if example is not None]
else:
max_length = self.tokenizer.model_max_length
max_chunks = args.extended_mode_chunks
input_ids = [self.tokenizer([example['input_ids']], truncation=True, return_length=True, return_overflowing_tokens=False, padding=False, add_special_tokens=False, max_length=(max_length * max_chunks) - (max_chunks * 2)).input_ids[0] for example in examples if example is not None]
tokens = input_ids
if args.extended_mode_chunks < 2:
for i, x in enumerate(input_ids):
for j, y in enumerate(x):
input_ids[i][j] = [self.tokenizer.bos_token_id, *y, *np.full((self.tokenizer.model_max_length - len(y) - 1), self.tokenizer.eos_token_id)]
if args.clip_penultimate:
input_ids = [self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True)['hidden_states'][-2])[0] for input_id in input_ids]
else:
input_ids = [self.text_encoder(torch.asarray(input_id).to(self.device), output_hidden_states=True).last_hidden_state[0] for input_id in input_ids]
else:
max_standard_tokens = max_length - 2
max_chunks = args.extended_mode_chunks
max_len = np.ceil(max(len(x) for x in input_ids) / max_standard_tokens).astype(int).item() * max_standard_tokens
if max_len > max_standard_tokens:
z = None
for i, x in enumerate(input_ids):
if len(x) < max_len:
input_ids[i] = [*x, *np.full((max_len - len(x)), self.tokenizer.eos_token_id)]
batch_t = torch.tensor(input_ids)
chunks = [batch_t[:, i:i + max_standard_tokens] for i in range(0, max_len, max_standard_tokens)]
for chunk in chunks:
chunk = torch.cat((torch.full((chunk.shape[0], 1), self.tokenizer.bos_token_id), chunk, torch.full((chunk.shape[0], 1), self.tokenizer.eos_token_id)), 1)
if z is None:
if args.clip_penultimate:
z = self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])
else:
z = self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state
else:
if args.clip_penultimate:
z = torch.cat((z, self.text_encoder.text_model.final_layer_norm(self.text_encoder(chunk.to(self.device), output_hidden_states=True)['hidden_states'][-2])), dim=-2)
else:
z = torch.cat((z, self.text_encoder(chunk.to(self.device), output_hidden_states=True).last_hidden_state), dim=-2)
input_ids = z
else:
for i, x in enumerate(input_ids):
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
if args.clip_penultimate:
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
else:
input_ids = self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True).last_hidden_state
input_ids = torch.stack(tuple(input_ids))
return { return {
'pixel_values': pixel_values, 'pixel_values': pixel_values,
'input_ids': padded_tokens.input_ids, 'input_ids': input_ids,
'attention_mask': padded_tokens.attention_mask, 'tokens': tokens
} }
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
@ -739,7 +793,7 @@ def main():
# load dataset # load dataset
store = ImageStore(args.dataset) store = ImageStore(args.dataset)
dataset = AspectDataset(store, tokenizer, ucg=args.ucg) dataset = AspectDataset(store, tokenizer, text_encoder, device, ucg=args.ucg)
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0) bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank) sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
@ -832,12 +886,8 @@ def main():
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning # Get the embedding for conditioning
encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True) encoder_hidden_states = batch['input_ids']
if args.clip_penultimate:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
if noise_scheduler.config.prediction_type == "epsilon": if noise_scheduler.config.prediction_type == "epsilon":
target = noise target = noise
@ -898,7 +948,7 @@ def main():
if global_step % args.image_log_steps == 0 and global_step > 0: if global_step % args.image_log_steps == 0 and global_step > 0:
if rank == 0: if rank == 0:
# get prompt from random batch # get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) prompt = tokenizer.decode(batch['tokens'][random.randint(0, len(batch['tokens'])-1)])
if args.image_log_scheduler == 'DDIMScheduler': if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler') print('using DDIMScheduler scheduler')