commit c4978ac229

@@ -108,19 +108,35 @@ class DataLoaderMultiAspect():
         batch_size = self.batch_size
         grad_accum = self.grad_accum
 
-        for image_caption_pair in picked_images:
-            image_caption_pair.runt_size = 0
-            bucket_key = (image_caption_pair.batch_id,
-                          image_caption_pair.target_wh[0],
-                          image_caption_pair.target_wh[1])
+        def add_image_to_appropriate_bucket(image: ImageTrainItem, batch_id_override: str=None):
+            bucket_key = (image.batch_id if batch_id_override is None else batch_id_override,
+                          image.target_wh[0],
+                          image.target_wh[1])
             if bucket_key not in buckets:
                 buckets[bucket_key] = []
-            buckets[bucket_key].append(image_caption_pair)
+            buckets[bucket_key].append(image)
 
-        # handle runts by randomly duplicating items
+        for image_caption_pair in picked_images:
+            image_caption_pair.runt_size = 0
+            add_image_to_appropriate_bucket(image_caption_pair)
+
+        # handle named batch runts by demoting them to the DEFAULT_BATCH_ID
+        for key, bucket_contents in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
+            runt_count = len(bucket_contents) % batch_size
+            if runt_count == 0:
+                continue
+            runts = bucket_contents[-runt_count:]
+            del bucket_contents[-runt_count:]
+            for r in runts:
+                add_image_to_appropriate_bucket(r, batch_id_override=DEFAULT_BATCH_ID)
+            if len(bucket_contents) == 0:
+                del buckets[key]
+
+        # handle remaining runts by randomly duplicating items
         for bucket in buckets:
             truncate_count = len(buckets[bucket]) % batch_size
             if truncate_count > 0:
+                assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
                 runt_bucket = buckets[bucket][-truncate_count:]
                 for item in runt_bucket:
                     item.runt_size = truncate_count
@@ -132,9 +148,21 @@ class DataLoaderMultiAspect():
                 buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                 buckets[bucket].extend(runt_bucket)
 
+        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
+        # at this point items have a partially deterministic order
+        # (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
+        # so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
+        items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
+                             for k,v in items_by_batch_id.items()}
+        # paranoia: verify that this hasn't fucked up the aspect ratio batching
+        for items in items_by_batch_id.values():
+            batches = chunk(items, chunk_size=batch_size)
+            for batch in batches:
+                target_wh = batch[0].target_wh
+                assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
+
         # handle batch_id
         # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
-        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
         items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
                                                                  batch_size=batch_size,
                                                                  grad_accum=grad_accum)
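
The two hunks above restructure bucketing so that named batches never emit partial ("runt") batches: leftovers are first demoted into the default batch's bucket for the same resolution, and only runts that remain there are padded by duplication. Below is a minimal standalone sketch of the demotion step, assuming buckets keyed by (batch_id, width, height); demote_named_runts and the toy data are hypothetical, not the project's code.

# Not part of the commit: a standalone sketch of the runt-demotion step above,
# using plain strings instead of ImageTrainItem. All names here are hypothetical.
DEFAULT_BATCH_ID = "default"

def demote_named_runts(buckets: dict, batch_size: int) -> dict:
    """Move leftover items ('runts') of named batches into the default batch's
    bucket for the same resolution, so named batches only emit full batches."""
    for key in [k for k in buckets if k[0] != DEFAULT_BATCH_ID]:
        contents = buckets[key]
        runt_count = len(contents) % batch_size
        if runt_count == 0:
            continue
        runts = contents[-runt_count:]
        del contents[-runt_count:]
        for r in runts:
            # same (width, height) bucket, but under the default batch id
            buckets.setdefault((DEFAULT_BATCH_ID, key[1], key[2]), []).append(r)
        if not contents:
            del buckets[key]
    return buckets

buckets = {("night", 512, 512): ["a", "b", "c"], (DEFAULT_BATCH_ID, 512, 512): ["d"]}
print(demote_named_runts(buckets, batch_size=2))
# -> {('night', 512, 512): ['a', 'b'], ('default', 512, 512): ['d', 'c']}
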
@@ -225,6 +253,8 @@ def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List
     Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
     If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
     """
+    if len(l) == 0:
+        return []
 
     # chunk by effective batch size
     chunks = chunk(l, chunk_size)
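
The added guard simply returns early on an empty list. For intuition about the semantics the docstring describes — chunk order is shuffled, each chunk's internal order is kept, and an incomplete trailing chunk stays last — here is an assumed re-implementation of chunk/chunked_shuffle for illustration; the project's own versions may differ in details.

# An assumed re-implementation for illustration only.
import random
from typing import List

def chunk(l: List, chunk_size: int) -> List[List]:
    return [l[i:i + chunk_size] for i in range(0, len(l), chunk_size)]

def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
    if len(l) == 0:
        return []
    chunks = chunk(l, chunk_size)
    # keep an incomplete trailing chunk out of the shuffle so it stays last
    tail = chunks.pop() if len(chunks[-1]) < chunk_size else None
    randomizer.shuffle(chunks)
    if tail is not None:
        chunks.append(tail)
    return [item for c in chunks for item in c]

r = random.Random(42)
print(chunked_shuffle([1, 2, 3, 4, 5, 6, 7], chunk_size=3, randomizer=r))
# chunks [1,2,3] and [4,5,6] may swap places; [7] always stays at the end
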
@@ -97,9 +97,13 @@ class EveryDreamValidator:
             self.config.update({'manual_data_root': self.config['val_data_root']})
 
         if self.config.get('val_split_mode') == 'manual':
-            if 'manual_data_root' not in self.config:
-                raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
-            self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            manual_data_root = self.config.get('manual_data_root')
+            if manual_data_root is not None:
+                self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+            else:
+                if len(self.config['extra_manual_datasets']) == 0:
+                    raise ValueError("Error in validation config .json: 'manual' validation requested but no "
+                                     "'manual_data_root' or 'extra_manual_datasets'")
 
         if 'val_split_proportion' in self.config:
             logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
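
The rewritten branch makes 'manual_data_root' optional whenever 'extra_manual_datasets' is non-empty, instead of unconditionally requiring it. Hypothetical config fragments showing the two shapes that now pass validation (keys mirror the validation config .json; the paths are made up):

# Hypothetical examples of the two shapes 'manual' validation now accepts.
config_a = {
    "val_split_mode": "manual",
    "manual_data_root": "/data/val",   # becomes extra_manual_datasets['val']
    "extra_manual_datasets": {},
}
config_b = {
    "val_split_mode": "manual",
    # no manual_data_root: fine, because extra_manual_datasets is non-empty
    "extra_manual_datasets": {"style_val": "/data/style_val"},
}
for cfg in (config_a, config_b):
    if cfg.get("manual_data_root") is None and not cfg["extra_manual_datasets"]:
        raise ValueError("'manual' validation requested but no "
                         "'manual_data_root' or 'extra_manual_datasets'")
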
@@ -167,15 +171,22 @@ class EveryDreamValidator:
     def do_validation(self, global_step: int,
                       get_model_prediction_and_target_callable: Callable[
                           [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
+        mean_loss_accumulator = 0
         for i, dataset in enumerate(self.validation_datasets):
             mean_loss = self._calculate_validation_loss(dataset.name,
                                                         dataset.dataloader,
                                                         get_model_prediction_and_target_callable)
+            mean_loss_accumulator += mean_loss
             self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
                                        scalar_value=mean_loss,
                                        global_step=global_step)
             dataset.track_loss_trend(mean_loss)
 
+        # log combined loss to loss/_all_val_combined
+        if len(self.validation_datasets) > 1:
+            total_mean_loss = mean_loss_accumulator / len(self.validation_datasets)
+            self.log_writer.add_scalar(tag=f"loss/_all_val_combined",
+                                       scalar_value=total_mean_loss,
+                                       global_step=global_step)
 
     def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
         [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
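
With the accumulator in place, loss/_all_val_combined is the unweighted mean of the per-dataset mean losses, logged only when there is more than one validation dataset; note this weights each dataset equally regardless of its size. A toy check with made-up numbers:

# Made-up numbers: three validation datasets with per-dataset mean losses.
per_dataset_mean_losses = [0.12, 0.18, 0.15]
if len(per_dataset_mean_losses) > 1:
    combined = sum(per_dataset_mean_losses) / len(per_dataset_mean_losses)
    print(f"loss/_all_val_combined = {combined:.4f}")  # 0.1500
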
@@ -176,7 +176,7 @@ class SampleGenerator:
                               wants_random_caption=p.get('random_caption', False)
                               ) for p in sample_requests_config]
             if len(self.sample_requests) == 0:
-                self._make_random_caption_sample_requests()
+                self.sample_requests = self._make_random_caption_sample_requests()
 
     @torch.no_grad()
     def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
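
This one-line fix keeps the return value of _make_random_caption_sample_requests(), which the old call discarded, evidently leaving self.sample_requests empty. A hypothetical illustration of the bug class (all names made up):

# Hypothetical illustration only: calling a function for its return value
# and discarding the result leaves the attribute unchanged.
def make_default_requests() -> list:
    return ["sample request with a random caption"]

class Generator:
    def __init__(self, requests: list):
        self.requests = requests
        if len(self.requests) == 0:
            # buggy: make_default_requests()   (result thrown away)
            # fixed:
            self.requests = make_default_requests()

print(Generator([]).requests)  # ['sample request with a random caption']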