diff --git a/data/every_dream.py b/data/every_dream.py index 465b6fc..0d0cd41 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -16,15 +16,12 @@ limitations under the License. import logging import torch from torch.utils.data import Dataset -from data.data_loader import DataLoaderMultiAspect as dlma -import math -import data.dl_singleton as dls +from data.data_loader import DataLoaderMultiAspect from data.image_train_item import ImageTrainItem import random from torchvision import transforms from transformers import CLIPTokenizer import torch.nn.functional as F -import numpy class EveryDreamBatch(Dataset): """ @@ -38,7 +35,7 @@ class EveryDreamBatch(Dataset): jitter: number of pixels to jitter the crop by, only for non-square images """ def __init__(self, - data_loader: dlma, + data_loader: DataLoaderMultiAspect, debug_level=0, conditional_dropout=0.02, crop_jitter=20, @@ -59,7 +56,6 @@ class EveryDreamBatch(Dataset): self.unloaded_to_idx = 0 self.tokenizer = tokenizer self.log_folder = log_folder - #print(f"tokenizer: {tokenizer}") self.max_token_length = self.tokenizer.model_max_length self.retain_contrast = retain_contrast self.write_schedule = write_schedule @@ -67,8 +63,9 @@ class EveryDreamBatch(Dataset): self.seed = seed self.rated_dataset = rated_dataset self.rated_dataset_dropout_target = rated_dataset_dropout_target - self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images - + # First epoch always trains on all images + self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) + num_images = len(self.image_train_items) logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}") @@ -83,20 +80,15 @@ class EveryDreamBatch(Dataset): except Exception as e: logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}") - def get_runts(): - return dls.shared_dataloader.runts - def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 - if dls.shared_dataloader: - if self.rated_dataset: - dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs - else: - dropout_fraction = 1.0 - - self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction) + + if self.rated_dataset: + dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs else: - raise Exception("No dataloader singleton to shuffle") + dropout_fraction = 1.0 + + self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction) if self.write_schedule: self.__write_batch_schedule(epoch_n + 1)