Merge pull request #196 from damian0815/fixes_misc

misc small fixes
Victor Hall 2023-06-18 00:48:50 -04:00 committed by GitHub
commit c4978ac229
3 changed files with 54 additions and 13 deletions


@@ -108,19 +108,35 @@ class DataLoaderMultiAspect():
         batch_size = self.batch_size
         grad_accum = self.grad_accum
 
-        for image_caption_pair in picked_images:
-            image_caption_pair.runt_size = 0
-            bucket_key = (image_caption_pair.batch_id,
-                          image_caption_pair.target_wh[0],
-                          image_caption_pair.target_wh[1])
+        def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
+            bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
+                          image.target_wh[0],
+                          image.target_wh[1])
             if bucket_key not in buckets:
                 buckets[bucket_key] = []
-            buckets[bucket_key].append(image_caption_pair)
+            buckets[bucket_key].append(image)
 
-        # handle runts by randomly duplicating items
+        for image_caption_pair in picked_images:
+            image_caption_pair.runt_size = 0
+            add_image_to_appropriate_bucket(image_caption_pair)
+
+        # handle named batch runts by demoting them to the DEFAULT_BATCH_ID
+        for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
+            runt_count = len(bucket_contents) % batch_size
+            if runt_count == 0:
+                continue
+            runts = bucket_contents[-runt_count:]
+            del bucket_contents[-runt_count:]
+            for r in runts:
+                add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
+            if len(bucket_contents) == 0:
+                del buckets[key]
+
+        # handle remaining runts by randomly duplicating items
         for bucket in buckets:
             truncate_count = len(buckets[bucket]) % batch_size
             if truncate_count > 0:
+                assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
                 runt_bucket = buckets[bucket][-truncate_count:]
                 for item in runt_bucket:
                     item.runt_size = truncate_count
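
For orientation, the runt handling above is now two passes: named batches first shed their runts (items beyond a multiple of batch_size) into the default batch, and only then are leftover runts in the default batch padded by duplication. A standalone sketch of the demotion pass; the DEFAULT_BATCH_ID value and the toy items here are hypothetical:

DEFAULT_BATCH_ID = "default"  # hypothetical value; the real constant lives in the data loader

def demote_named_runts(buckets: dict, batch_size: int) -> dict:
    # buckets maps (batch_id, width, height) -> list of items, as above
    for key in [k for k in buckets if k[0] != DEFAULT_BATCH_ID]:
        contents = buckets[key]
        runt_count = len(contents) % batch_size
        if runt_count == 0:
            continue
        runts = contents[-runt_count:]
        del contents[-runt_count:]
        for r in runts:
            # re-bucket under the default batch id, keeping the aspect ratio key
            buckets.setdefault((DEFAULT_BATCH_ID, key[1], key[2]), []).append(r)
        if len(contents) == 0:
            del buckets[key]
    return buckets

buckets = {("style_a", 512, 512): ["a", "b", "c", "d", "e"],
           ("default", 512, 512): ["x", "y"]}
print(demote_named_runts(buckets, batch_size=2))
# ("style_a", 512, 512) keeps 4 items; "e" joins ("default", 512, 512)
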
@@ -132,9 +148,21 @@ class DataLoaderMultiAspect():
                 buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                 buckets[bucket].extend(runt_bucket)
 
+        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
+        # at this point items have a partially deterministic order
+        # (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
+        # so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
+        items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
+                             for k, v in items_by_batch_id.items()}
+        # paranoia: verify that this hasn't fucked up the aspect ratio batching
+        for items in items_by_batch_id.values():
+            batches = chunk(items, chunk_size=batch_size)
+            for batch in batches:
+                target_wh = batch[0].target_wh
+                assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
+
         # handle batch_id
         # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
-        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
         items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
                                                                  batch_size=batch_size,
                                                                  grad_accum=grad_accum)
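
flatten_buckets_preserving_named_batch_adjacency itself is not shown in this diff; going only by its name and parameters, it presumably concatenates the per-batch_id lists while keeping each named batch's items adjacent in units of batch_size * grad_accum, so that a single gradient-accumulation step never straddles two named batches. A speculative sketch under that assumption (DEFAULT_BATCH_ID value again hypothetical):

from typing import Any, Dict, List

DEFAULT_BATCH_ID = "default"  # hypothetical value

def flatten_preserving_adjacency_sketch(items_by_batch_id: Dict[str, List[Any]],
                                        batch_size: int,
                                        grad_accum: int) -> List[Any]:
    # emit each named batch whole so its items stay adjacent, padding up to a
    # multiple of batch_size * grad_accum with unlabelled items
    step = batch_size * grad_accum
    default_items = list(items_by_batch_id.get(DEFAULT_BATCH_ID, []))
    flat: List[Any] = []
    for batch_id, items in items_by_batch_id.items():
        if batch_id == DEFAULT_BATCH_ID:
            continue
        flat.extend(items)
        shortfall = -len(flat) % step  # distance to the next accumulation boundary
        flat.extend(default_items[:shortfall])
        del default_items[:shortfall]
    flat.extend(default_items)
    return flat
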
@@ -225,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
     Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
     If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
     """
+    if len(l) == 0:
+        return []
 
     # chunk by effective batch size
     chunks = chunk(l, chunk_size)
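
The docstring pins down the behaviour the new empty-list guard protects: shuffle the order of whole chunks, keep intra-chunk order, and leave an incomplete final chunk at the end. A minimal sketch consistent with that contract (the real implementation may differ; chunk() here is an assumed helper matching its use above):

import random
from typing import List

def chunk(l: List, chunk_size: int) -> List[List]:
    # split l into consecutive chunks of up to chunk_size items
    return [l[i:i + chunk_size] for i in range(0, len(l), chunk_size)]

def chunked_shuffle_sketch(l: List, chunk_size: int, randomizer: random.Random) -> List:
    if len(l) == 0:
        return []
    chunks = chunk(l, chunk_size)
    # pin an incomplete final chunk to the end, shuffle the complete ones
    tail = chunks.pop() if len(chunks[-1]) < chunk_size else None
    randomizer.shuffle(chunks)
    if tail is not None:
        chunks.append(tail)
    return [item for c in chunks for item in c]

print(chunked_shuffle_sketch(list(range(7)), 3, random.Random(0)))
# e.g. [3, 4, 5, 0, 1, 2, 6] - chunk order shuffled, [6] stays last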


@@ -97,9 +97,13 @@ class EveryDreamValidator:
             self.config.update({'manual_data_root': self.config['val_data_root']})
 
         if self.config.get('val_split_mode') == 'manual':
-            if 'manual_data_root' not in self.config:
-                raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
-            self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            manual_data_root = self.config.get('manual_data_root')
+            if manual_data_root is not None:
+                self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            else:
+                if len(self.config['extra_manual_datasets']) == 0:
+                    raise ValueError("Error in validation config .json: 'manual' validation requested but no "
+                                     "'manual_data_root' or 'extra_manual_datasets'")
 
         if 'val_split_proportion' in self.config:
             logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
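
In practice this means a 'manual'-mode validation config no longer requires 'manual_data_root', as long as it names at least one dataset in 'extra_manual_datasets'. Something along these lines (dataset names and paths are placeholders) should now be accepted:

{
    "val_split_mode": "manual",
    "extra_manual_datasets": {
        "portraits": "/workspace/val/portraits",
        "landscapes": "/workspace/val/landscapes"
    }
}
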
@@ -167,15 +171,22 @@ class EveryDreamValidator:
     def do_validation(self, global_step: int,
                       get_model_prediction_and_target_callable: Callable[
                           [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+        mean_loss_accumulator = 0
         for i, dataset in enumerate(self.validation_datasets):
             mean_loss = self._calculate_validation_loss(dataset.name,
                                                         dataset.dataloader,
                                                         get_model_prediction_and_target_callable)
+            mean_loss_accumulator += mean_loss
             self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
                                        scalar_value=mean_loss,
                                        global_step=global_step)
             dataset.track_loss_trend(mean_loss)
+        # log combined loss to loss/_all_val_combined
+        if len(self.validation_datasets) > 1:
+            total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
+            self.log_writer.add_scalar(tag=f"loss/_all_val_combined",
+                                       scalar_value=total_mean_loss,
+                                       global_step=global_step)
 
     def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
             [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
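
One thing worth noting about the combined metric: it is a mean of per-dataset means, so every dataset counts equally regardless of how many batches it contains. A toy check with hypothetical numbers:

# mean of per-dataset means: each dataset weighs the same, however large it is
per_dataset_mean_loss = {"val": 0.125, "extra_set": 0.375}  # hypothetical values
combined = sum(per_dataset_mean_loss.values()) / len(per_dataset_mean_loss)
print(combined)  # 0.25 - not weighted by dataset length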


@@ -176,7 +176,7 @@ class SampleGenerator:
                                  wants_random_caption=p.get('random_caption', False)
                                  ) for p in sample_requests_config]
         if len(self.sample_requests) == 0:
-            self._make_random_caption_sample_requests()
+            self.sample_requests = self._make_random_caption_sample_requests()
 
     @torch.no_grad()
     def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
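
The one-line fix above addresses a silent no-op: assuming _make_random_caption_sample_requests returns a fresh list rather than mutating self.sample_requests in place (which the fix implies), the old call discarded its result and the request list stayed empty. A stripped-down illustration with a hypothetical helper:

def make_default_requests():
    # stands in for _make_random_caption_sample_requests: returns a new
    # list instead of mutating caller state
    return ["<random caption request>"]

sample_requests = []
if len(sample_requests) == 0:
    make_default_requests()                    # before: result discarded, list stays empty
    sample_requests = make_default_requests()  # after: assignment applies the fallback

print(sample_requests)  # ['<random caption request>']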