From e4872fdc0c2a903acfa6b9bad23bcd887e10afb6 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 14 Jun 2023 09:18:18 +0200
Subject: [PATCH 1/9] improve shuffle and runt handling for named buckets

---
 data/data_loader.py | 27 +++++++++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index 1ab9b67..12706e8 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -117,8 +117,19 @@ class DataLoaderMultiAspect():
                 buckets[bucket_key] = []
             buckets[bucket_key].append(image_caption_pair)
 
-        # handle runts by randomly duplicating items
+        # handle named batch runts by demoting them to the DEFAULT_BATCH_ID
+        for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
+            runt_count = len(bucket) % batch_size
+            if runt_count == 0:
+                continue
+            runts = bucket[-runt_count:]
+            del bucket[-runt_count:]
+            matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]]
+            buckets[matching_default_bucket_key].extend(runts)
+
+        # handle remaining runts by randomly duplicating items
         for bucket in buckets:
+            assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
             truncate_count = len(buckets[bucket]) % batch_size
             if truncate_count > 0:
                 runt_bucket = buckets[bucket][-truncate_count:]
@@ -132,9 +143,21 @@ class DataLoaderMultiAspect():
                 buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                 buckets[bucket].extend(runt_bucket)
 
+        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
+        # at this point items have a partially deterministic order
+        # (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
+        # so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
+        items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
+                             for k,v in items_by_batch_id.items()}
+        # paranoia: verify that this hasn't fucked up the aspect ratio batching
+        for items in items_by_batch_id.values():
+            batches = chunk(items, chunk_size=batch_size)
+            for batch in batches:
+                target_wh = batch[0].target_wh
+                assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
+
         # handle batch_id
         # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
-        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
         items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id, batch_size=batch_size, grad_accum=grad_accum)
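Note: `chunk` and `chunked_shuffle`, called by the new shuffling code above, are helpers defined further down in data/data_loader.py; PATCH 5/9 below quotes the `chunked_shuffle` docstring. For orientation, here is a minimal sketch of what they plausibly look like, assuming only the behaviour that docstring describes - the repository's actual implementation may differ:

    import random
    from typing import List

    def chunk(l: List, chunk_size: int) -> List[List]:
        # split l into consecutive runs of up to chunk_size items
        return [l[i:i + chunk_size] for i in range(0, len(l), chunk_size)]

    def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
        # shuffle whole chunks, preserving chunk boundaries and within-chunk order;
        # an incomplete final chunk stays pinned at the end
        # (note: like the code before PATCH 5/9, this sketch assumes a non-empty list)
        chunks = chunk(l, chunk_size)
        last_chunk = chunks.pop() if len(chunks[-1]) < chunk_size else None
        randomizer.shuffle(chunks)
        if last_chunk is not None:
            chunks.append(last_chunk)
        return [item for c in chunks for item in c]

Because each bucket holds items of a single target_wh and chunks are moved whole, every batch-sized chunk still contains one aspect ratio after the shuffle - exactly what the paranoia assert above verifies.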
From bd89ad96d27023c18967d9a0d6c383dcfb26ef49 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 14 Jun 2023 09:28:40 +0200
Subject: [PATCH 2/9] typo

---
 data/data_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index 12706e8..3d18049 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -124,7 +124,7 @@ class DataLoaderMultiAspect():
                 continue
             runts = bucket[-runt_count:]
             del bucket[-runt_count:]
-            matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]]
+            matching_default_bucket_key = (DEFAULT_BATCH_ID, key[1], key[2])
             buckets[matching_default_bucket_key].extend(runts)
 
         # handle remaining runts by randomly duplicating items

From 8885d58efd05d641a1d32a6648554c7c15639d12 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 14 Jun 2023 09:30:23 +0200
Subject: [PATCH 3/9] create bucket if it doesn't exist

---
 data/data_loader.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/data/data_loader.py b/data/data_loader.py
index 3d18049..fa31192 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -125,6 +125,8 @@ class DataLoaderMultiAspect():
             runts = bucket[-runt_count:]
             del bucket[-runt_count:]
             matching_default_bucket_key = (DEFAULT_BATCH_ID, key[1], key[2])
+            if matching_default_bucket_key not in buckets:
+                buckets[matching_default_bucket_key] = []
             buckets[matching_default_bucket_key].extend(runts)
 
         # handle remaining runts by randomly duplicating items
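Note on the two small fixes above: PATCH 1/9 built `matching_default_bucket_key` as a list, and lists are unhashable in Python, so the first demoted runt raises `TypeError: unhashable type: 'list'`; PATCH 2/9 makes the key a tuple. PATCH 3/9 then avoids a `KeyError` when no default-batch bucket exists yet for that aspect ratio. The get-or-create step could equivalently be written with `dict.setdefault` - a sketch of an alternative, not the committed code:

    # equivalent to the existence check added in PATCH 3/9
    bucket = buckets.setdefault((DEFAULT_BATCH_ID, key[1], key[2]), [])
    bucket.extend(runts)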
From ded73f088f41830539ffa0a3732ae063362997ea Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 14 Jun 2023 09:42:39 +0200
Subject: [PATCH 4/9] cleanup

---
 data/data_loader.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index fa31192..35fddc1 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -108,32 +108,35 @@ class DataLoaderMultiAspect():
         batch_size = self.batch_size
         grad_accum = self.grad_accum
 
-        for image_caption_pair in picked_images:
-            image_caption_pair.runt_size = 0
-            bucket_key = (image_caption_pair.batch_id,
-                          image_caption_pair.target_wh[0],
-                          image_caption_pair.target_wh[1])
+        def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
+            bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
+                          image.target_wh[0],
+                          image.target_wh[1])
             if bucket_key not in buckets:
                 buckets[bucket_key] = []
-            buckets[bucket_key].append(image_caption_pair)
+            buckets[bucket_key].append(image)
+
+        for image_caption_pair in picked_images:
+            image_caption_pair.runt_size = 0
+            add_image_to_appropriate_bucket(image_caption_pair)
 
         # handle named batch runts by demoting them to the DEFAULT_BATCH_ID
-        for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
-            runt_count = len(bucket) % batch_size
+        for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
+            runt_count = len(bucket_contents) % batch_size
             if runt_count == 0:
                 continue
-            runts = bucket[-runt_count:]
-            del bucket[-runt_count:]
-            matching_default_bucket_key = (DEFAULT_BATCH_ID, key[1], key[2])
-            if matching_default_bucket_key not in buckets:
-                buckets[matching_default_bucket_key] = []
-            buckets[matching_default_bucket_key].extend(runts)
+            runts = bucket_contents[-runt_count:]
+            del bucket_contents[-runt_count:]
+            for r in runts:
+                add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
+            if len(bucket_contents) == 0:
+                del buckets[key]
 
         # handle remaining runts by randomly duplicating items
         for bucket in buckets:
-            assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
             truncate_count = len(buckets[bucket]) % batch_size
             if truncate_count > 0:
+                assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
                 runt_bucket = buckets[bucket][-truncate_count:]
                 for item in runt_bucket:
                     item.runt_size = truncate_count

From 403f7ddf0714202151dcf04aebb749e56aed40f5 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 14 Jun 2023 09:44:37 +0200
Subject: [PATCH 5/9] fix chunked_shuffle crash with empty list

---
 data/data_loader.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/data/data_loader.py b/data/data_loader.py
index 35fddc1..78d8ffe 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -253,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
     Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
     If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
     """
+    if len(l) == 0:
+        return []
 
     # chunk by effective batch size
     chunks = chunk(l, chunk_size)

From a8c6a5911155e14b5b6a3c27af54f390e6f3e5e3 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 17 Jun 2023 09:27:05 +0200
Subject: [PATCH 6/9] fix bug with empty samples.txt

---
 utils/sample_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/sample_generator.py b/utils/sample_generator.py
index 6607b7b..dd13bf8 100644
--- a/utils/sample_generator.py
+++ b/utils/sample_generator.py
@@ -176,7 +176,7 @@ class SampleGenerator:
                                  wants_random_caption=p.get('random_caption', False)
                                  ) for p in sample_requests_config]
         if len(self.sample_requests) == 0:
-            self._make_random_caption_sample_requests()
+            self.sample_requests = self._make_random_caption_sample_requests()
 
     @torch.no_grad()
     def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
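Note on PATCH 5/9 above: `chunk` applied to an empty list yields no chunks at all, so any inspection of the final chunk - as in the sketch given after PATCH 1/9 - blows up; the early return makes the empty case explicit. An illustration against that sketch (where exactly the original implementation crashed is not shown here, so treat this as illustrative, not a trace):

    chunks = chunk([], chunk_size=4)   # -> []
    last = chunks[-1]                  # IndexError: list index out of range

PATCH 6/9 is a plain return-value bug: the freshly built list of random-caption sample requests was discarded instead of being assigned to `self.sample_requests`.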
From 6c60b76fb6da5562791a702f1702afeec0cc530c Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 17 Jun 2023 09:27:18 +0200
Subject: [PATCH 7/9] log combined loss if there are >1 val subsets

---
 data/every_dream_validation.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index 362c35b..f8a7887 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -167,15 +167,22 @@ class EveryDreamValidator:
     def do_validation(self, global_step: int,
                       get_model_prediction_and_target_callable: Callable[
                           [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+        mean_loss_accumulator = 0
         for i, dataset in enumerate(self.validation_datasets):
             mean_loss = self._calculate_validation_loss(dataset.name,
                                                         dataset.dataloader,
                                                         get_model_prediction_and_target_callable)
+            mean_loss_accumulator += mean_loss
             self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
                                        scalar_value=mean_loss,
                                        global_step=global_step)
             dataset.track_loss_trend(mean_loss)
-
+        # log combined loss to loss/_all_val_combined
+        if len(self.validation_datasets) > 1:
+            total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
+            self.log_writer.add_scalar(tag=f"loss/_all_val_combined",
+                                       scalar_value=total_mean_loss,
+                                       global_step=global_step)
 
     def _calculate_validation_loss(self, tag,
                                    dataloader,
                                    get_model_prediction_and_target: Callable[
                                        [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:

From dd6b37840d6e6d442205d8b70931407b7f30f0e2 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 17 Jun 2023 09:31:28 +0200
Subject: [PATCH 8/9] permit empty manual_data_root

---
 data/every_dream_validation.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index f8a7887..97a99a4 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -97,9 +97,12 @@ class EveryDreamValidator:
             self.config.update({'manual_data_root': self.config['val_data_root']})
 
         if self.config.get('val_split_mode') == 'manual':
-            if 'manual_data_root' not in self.config:
-                raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
-            self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            if 'manual_data_root' in self.config:
+                self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            else:
+                if len(self.config['extra_manual_datasets']) == 0:
+                    raise ValueError("Error in validation config .json: 'manual' validation requested but no "
+                                     "'manual_data_root' or 'extra_manual_datasets'")
 
         if 'val_split_proportion' in self.config:
             logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "

From 1ab27a90592d05ee690a51c915910774f9e810be Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sat, 17 Jun 2023 10:38:34 +0200
Subject: [PATCH 9/9] better check for null manual_data_root

---
 data/every_dream_validation.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index 97a99a4..e3fa770 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -97,7 +97,8 @@ class EveryDreamValidator:
             self.config.update({'manual_data_root': self.config['val_data_root']})
 
         if self.config.get('val_split_mode') == 'manual':
-            if 'manual_data_root' in self.config:
+            manual_data_root = self.config.get('manual_data_root')
+            if manual_data_root is not None:
                 self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
             else:
                 if len(self.config['extra_manual_datasets']) == 0:
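Note on PATCH 8/9 and 9/9: a config file containing {"manual_data_root": null} parses to a dict in which the key exists but the value is None, so the membership test from PATCH 8/9 still takes the 'val' branch and passes None along as a data root. PATCH 9/9 switches to a None check, which treats a missing key and an explicit null the same way:

    config = {"manual_data_root": None}   # as parsed from {"manual_data_root": null}
    "manual_data_root" in config          # True  - the key exists even though its value is null
    config.get("manual_data_root")        # None  - catches both missing key and explicit null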