Don't need to set data loader singleton; formatting tweaks
parent 09d95fac58
commit c0ec46c030
@@ -16,15 +16,12 @@ limitations under the License.
 import logging
 import torch
 from torch.utils.data import Dataset
-from data.data_loader import DataLoaderMultiAspect as dlma
-import math
-import data.dl_singleton as dls
+from data.data_loader import DataLoaderMultiAspect
 from data.image_train_item import ImageTrainItem
 import random
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import torch.nn.functional as F
 import numpy
 
 class EveryDreamBatch(Dataset):
     """
@@ -38,7 +35,7 @@ class EveryDreamBatch(Dataset):
     jitter: number of pixels to jitter the crop by, only for non-square images
     """
     def __init__(self,
-                 data_loader: dlma,
+                 data_loader: DataLoaderMultiAspect,
                  debug_level=0,
                  conditional_dropout=0.02,
                  crop_jitter=20,
@@ -59,7 +56,6 @@ class EveryDreamBatch(Dataset):
         self.unloaded_to_idx = 0
         self.tokenizer = tokenizer
         self.log_folder = log_folder
-        #print(f"tokenizer: {tokenizer}")
         self.max_token_length = self.tokenizer.model_max_length
         self.retain_contrast = retain_contrast
         self.write_schedule = write_schedule
@@ -67,8 +63,9 @@ class EveryDreamBatch(Dataset):
         self.seed = seed
         self.rated_dataset = rated_dataset
         self.rated_dataset_dropout_target = rated_dataset_dropout_target
-        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0) # First epoch always trains on all images
-
+        # First epoch always trains on all images
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
+
         num_images = len(self.image_train_items)
 
         logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
@@ -83,20 +80,15 @@ class EveryDreamBatch(Dataset):
             except Exception as e:
                 logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}")
 
-    def get_runts():
-        return dls.shared_dataloader.runts
-
     def shuffle(self, epoch_n: int, max_epochs: int):
         self.seed += 1
-        if dls.shared_dataloader:
-            if self.rated_dataset:
-                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
-            else:
-                dropout_fraction = 1.0
-
-            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
+
+        if self.rated_dataset:
+            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
         else:
-            raise Exception("No dataloader singleton to shuffle")
+            dropout_fraction = 1.0
+
+        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
 
         if self.write_schedule:
             self.__write_batch_schedule(epoch_n + 1)
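
The rated-dataset branch of shuffle() above keeps a linearly shrinking fraction of the images as epochs advance. A small worked example of that formula, using illustrative values rather than defaults taken from this repo:

# Illustrative numbers only; max_epochs and rated_dataset_dropout_target here
# are hypothetical, not values from this codebase.
max_epochs = 10
rated_dataset_dropout_target = 0.5

for epoch_n in (0, 5, 10):
    dropout_fraction = (max_epochs - (epoch_n * rated_dataset_dropout_target)) / max_epochs
    print(epoch_n, dropout_fraction)  # 0 -> 1.0, 5 -> 0.75, 10 -> 0.5

With these numbers the kept fraction anneals from 1.0 at the first shuffle toward 1 - rated_dataset_dropout_target by the final epoch.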
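
For context on the singleton removal, a hypothetical before/after wiring sketch; it uses only names visible in this diff, and constructor arguments that the diff does not show are elided as "...".

# Hypothetical sketch, not code from this repo.
#
# Before: shuffle() read dls.shared_dataloader, so the trainer had to set the
# singleton before using the dataset:
#   import data.dl_singleton as dls
#   dls.shared_dataloader = DataLoaderMultiAspect(...)
#   train_batch = EveryDreamBatch(data_loader=dls.shared_dataloader, ...)
#
# After: the DataLoaderMultiAspect instance passed to the constructor is all
# that EveryDreamBatch uses, so no singleton assignment is needed:
#   data_loader = DataLoaderMultiAspect(...)
#   train_batch = EveryDreamBatch(data_loader=data_loader, ...)
#   train_batch.shuffle(epoch_n=1, max_epochs=10)  # uses self.data_loader internally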