"""
Copyright [2022] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import math
import random

import numpy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTokenizer

import data.dl_singleton as dls
from data.data_loader import DataLoaderMultiAspect as dlma
from data.image_train_item import ImageTrainItem


class EveryDreamBatch(Dataset):
    """
    data_root: root path of all your training images, will be recursively searched for images
    flip_p: probability of flipping the image horizontally
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    batch_size: how many images to return in a batch
    conditional_dropout: probability of dropping the caption for a given image
    resolution: base training resolution; aspect-ratio buckets are sized relative to this square
    crop_jitter: max number of pixels to jitter the crop by, only for non-square images
    seed: random seed, or -1 to choose one at random
    tokenizer: CLIP tokenizer used to convert captions to token ids (required)
    log_folder: destination folder for batch schedules and other logs
    retain_contrast: if True, keep images in [0, 1] instead of normalizing to [-1, 1]
    write_schedule: if True, write each epoch's batch schedule to log_folder
    shuffle_tags: if True, shuffle caption tags each epoch
    rated_dataset: if True, train on a shrinking subset of the dataset as epochs progress (see shuffle)
    rated_dataset_dropout_target: approximate fraction of the dataset dropped by the final epoch
    """
    def __init__(self,
                 data_root,
                 flip_p=0.0,
                 debug_level=0,
                 batch_size=1,
                 conditional_dropout=0.02,
                 resolution=512,
                 crop_jitter=20,
                 seed=555,
                 tokenizer=None,
                 log_folder=None,
                 retain_contrast=False,
                 write_schedule=False,
                 shuffle_tags=False,
                 rated_dataset=False,
                 rated_dataset_dropout_target=0.5
                 ):
        self.data_root = data_root
        self.batch_size = batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        self.unloaded_to_idx = 0
        self.tokenizer = tokenizer
        self.log_folder = log_folder
        if self.tokenizer is None:
            # fail fast with a clear message; tokenizer.model_max_length is required below
            raise ValueError("EveryDreamBatch requires a tokenizer")
        self.max_token_length = self.tokenizer.model_max_length
        self.retain_contrast = retain_contrast
        self.write_schedule = write_schedule
        self.shuffle_tags = shuffle_tags
        self.seed = seed
        self.rated_dataset = rated_dataset
        self.rated_dataset_dropout_target = rated_dataset_dropout_target

        if seed == -1:
            seed = random.randint(0, 99999)

        if not dls.shared_dataloader:
            logging.info(" * Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root,
                                         seed=seed,
                                         debug_level=debug_level,
                                         batch_size=self.batch_size,
                                         flip_p=flip_p,
                                         resolution=resolution,
                                         log_folder=self.log_folder,
                                         )

        self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(1.0)  # First epoch always trains on all images

        num_images = len(self.image_train_items)

        logging.info(f" ** Trainer Set: {num_images / batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")

        if self.write_schedule:
            self.__write_batch_schedule(0)

    def __write_batch_schedule(self, epoch_n):
        with open(f"{self.log_folder}/ep{epoch_n}_batch_schedule.txt", "w", encoding='utf-8') as f:
            for i in range(len(self.image_train_items)):
                try:
                    f.write(f"step:{int(i / self.batch_size):05}, wh:{self.image_train_items[i].target_wh}, r:{self.image_train_items[i].runt_size}, path:{self.image_train_items[i].pathname}\n")
                except Exception as e:
                    logging.error(f" * Error writing to batch schedule for file path: {self.image_train_items[i].pathname}: {e}")

    @staticmethod
    def get_runts():
        return dls.shared_dataloader.runts

    def shuffle(self, epoch_n: int, max_epochs: int):
        self.seed += 1
        if dls.shared_dataloader:
            if self.rated_dataset:
                # linear schedule: epoch 0 keeps the full set, then the kept fraction
                # shrinks toward (1 - rated_dataset_dropout_target); e.g. with
                # max_epochs=10 and a target of 0.5, the final epoch (epoch_n=9)
                # keeps (10 - 9 * 0.5) / 10 = 55% of the images
                dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
            else:
                dropout_fraction = 1.0

            self.image_train_items = dls.shared_dataloader.get_shuffled_image_buckets(dropout_fraction)
        else:
            raise Exception("No dataloader singleton to shuffle")

        if self.write_schedule:
            self.__write_batch_schedule(epoch_n + 1)

    def __len__(self):
        return len(self.image_train_items)

    def __getitem__(self, i):
        example = {}

        train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)

        if self.retain_contrast:
            # identity normalization: keep pixel values in [0, 1] as produced by ToTensor
            std_dev = 1.0
            mean = 0.0
        else:
            # map [0, 1] to [-1, 1], the input range Stable Diffusion's VAE expects
            std_dev = 0.5
            mean = 0.5

        image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([mean], [std_dev]),
            ]
        )

        if self.shuffle_tags:
            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
        else:
            example["caption"] = train_item["caption"].get_caption()

        example["image"] = image_transforms(train_item["image"])

        if random.random() > self.conditional_dropout:
            example["tokens"] = self.tokenizer(example["caption"],
                                               truncation=True,
                                               padding="max_length",
                                               max_length=self.tokenizer.model_max_length,
                                               ).input_ids
        else:
            # conditional dropout: substitute a blank caption so the model also learns
            # the unconditional distribution, which classifier-free guidance relies on
            example["tokens"] = self.tokenizer(" ",
                                               truncation=True,
                                               padding="max_length",
                                               max_length=self.tokenizer.model_max_length,
                                               ).input_ids

        example["tokens"] = torch.tensor(example["tokens"])

        example["runt_size"] = train_item["runt_size"]

        return example

    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
        example = {}
        save = debug_level > 2

        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image
        example["caption"] = image_train_tmp.caption
        example["runt_size"] = image_train_tmp.runt_size

        return example
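
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the training pipeline):
# shows how EveryDreamBatch is typically constructed and iterated. The data
# path, log folder, and tokenizer checkpoint below are assumptions for the
# example, not values taken from this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # assumption: any CLIP tokenizer exposing model_max_length works here
    example_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    dataset = EveryDreamBatch(
        data_root="./training_images",  # hypothetical path
        batch_size=4,
        conditional_dropout=0.04,
        resolution=512,
        tokenizer=example_tokenizer,
        log_folder="./logs",            # hypothetical path
    )

    max_epochs = 2
    for epoch in range(max_epochs):
        for i in range(len(dataset)):
            item = dataset[i]  # dict with "image", "caption", "tokens", "runt_size"
        # re-shuffle buckets between epochs; with rated_dataset=True this also
        # applies the dropout schedule described in shuffle()
        dataset.shuffle(epoch_n=epoch, max_epochs=max_epochs)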