Don't need to set data loader singleton; formatting tweaks

Joel Holdbrooks 2023-01-29 17:31:57 -08:00
parent 09d95fac58
commit c0ec46c030
1 changed file with 11 additions and 19 deletions
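The change replaces the `data.dl_singleton` module global with a loader that the caller constructs and passes in. A minimal sketch of the resulting call pattern (not part of this commit; the `DataLoaderMultiAspect` constructor arguments are elided, and the `data.every_dream` module path is assumed):

    from data.data_loader import DataLoaderMultiAspect
    from data.every_dream import EveryDreamBatch  # module path assumed

    # The trainer now owns the loader and injects it explicitly, instead of
    # both sides reaching for data.dl_singleton.shared_dataloader.
    data_loader = DataLoaderMultiAspect()  # real constructor arguments elided
    train_batch = EveryDreamBatch(data_loader=data_loader)  # remaining arguments elided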


@@ -16,15 +16,12 @@ limitations under the License.
 import logging
 import torch
 from torch.utils.data import Dataset
-from data.data_loader import DataLoaderMultiAspect as dlma
 import math
-import data.dl_singleton as dls
+from data.data_loader import DataLoaderMultiAspect
 from data.image_train_item import ImageTrainItem
 import random
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import torch.nn.functional as F
-import numpy
 
 class EveryDreamBatch(Dataset):
     """
@@ -38,7 +35,7 @@ class EveryDreamBatch(Dataset):
         jitter: number of pixels to jitter the crop by, only for non-square images
     """
     def __init__(self,
-                 data_loader: dlma,
+                 data_loader: DataLoaderMultiAspect,
                  debug_level=0,
                  conditional_dropout=0.02,
                  crop_jitter=20,
@@ -59,7 +56,6 @@ class EveryDreamBatch(Dataset):
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
         self.log_folder = log_folder
-        #print(f"tokenizer: {tokenizer}")
         self.max_token_length = self.tokenizer.model_max_length
         self.retain_contrast = retain_contrast
         self.write_schedule = write_schedule
@@ -67,8 +63,9 @@ class EveryDreamBatch(Dataset):
         self.seed = seed
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
-        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
+        # First epoch always trains on all images
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
 
         num_images = len(self.image_train_items)
         logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
@@ -83,20 +80,15 @@ class EveryDreamBatch(Dataset):
             except Exception as e:
                 logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
 
-    def get_runts():
-        return dls.shared_dataloader.runts
-
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
-        if dls.shared_dataloader:
-            if self.rated_dataset:
-                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
-            else:
-                dropout_fraction = 1.0
-            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
+        if self.rated_dataset:
+            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
         else:
-            raise Exception("No dataloader singleton to shuffle")
+            dropout_fraction = 1.0
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
         if self.write_schedule:
             self.__write_batch_schedule(epoch_n + 1)
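Side note on the retained rating logic (an illustration, not part of the commit): with `rated_dataset` enabled, `shuffle()` keeps a linearly shrinking fraction of images each epoch, per the formula in the diff. For example, continuing the hypothetical `train_batch` from the sketch above:

    # dropout_fraction = (max_epochs - (epoch_n * rated_dataset_dropout_target)) / max_epochs
    # with max_epochs=10 and rated_dataset_dropout_target=0.5:
    #   epoch 0 -> (10 - 0.0) / 10 = 1.00   (all images)
    #   epoch 4 -> (10 - 2.0) / 10 = 0.80
    #   epoch 9 -> (10 - 4.5) / 10 = 0.55
    train_batch.shuffle(epoch_n=4, max_epochs=10)  # re-buckets with dropout_fraction = 0.8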