2022-12-17 20:32:48 -07:00
|
|
|
"""
|
|
|
|
Copyright [2022] Victor C Hall
|
|
|
|
|
|
|
|
Licensed under the GNU Affero General Public License;
|
|
|
|
You may not use this code except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
https://www.gnu.org/licenses/agpl-3.0.en.html
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
limitations under the License.
|
|
|
|
"""
|
2022-12-27 12:25:32 -07:00
|
|
|
import logging
|
2022-12-17 20:32:48 -07:00
|
|
|
import torch
|
|
|
|
from torch.utils.data import Dataset
|
2023-01-29 18:31:57 -07:00
|
|
|
from data.data_loader import DataLoaderMultiAspect
|
2022-12-17 20:32:48 -07:00
|
|
|
from data.image_train_item import ImageTrainItem
|
|
|
|
import random
|
|
|
|
from torchvision import transforms
|
|
|
|
from transformers import CLIPTokenizer
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
class EveryDreamBatch(Dataset):
    """
    Map-style torch ``Dataset`` that turns the image train items supplied by a
    ``DataLoaderMultiAspect`` into training examples.

    Each example is a dict with keys ``"caption"`` (caption text), ``"image"``
    (normalized image tensor), ``"tokens"`` (tokenizer input ids as a tensor)
    and ``"runt_size"`` (passed through from the hydrated item).

    data_loader: `DataLoaderMultiAspect` object; supplies ``batch_size`` and
        ``get_shuffled_image_buckets(fraction)``
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    conditional_dropout: probability of dropping the caption for a given image
        (the image is then trained against a blank " " prompt)
    crop_jitter: number of pixels to jitter the crop by, only for non-square images
    seed: random seed; incremented on every ``shuffle`` and passed to
        ``get_shuffled_caption`` when ``shuffle_tags`` is set
    tokenizer: tokenizer used to encode captions (e.g. a ``CLIPTokenizer``).
        Required despite the ``None`` default.
    retain_contrast: if True, skip the ``(x - 0.5) / 0.5`` normalization and keep
        the value range produced by ``ToTensor``
    shuffle_tags: if True, captions come from ``get_shuffled_caption(seed)``
        instead of ``get_caption()``
    rated_dataset: if True, ``shuffle`` passes a per-epoch shrinking fraction to
        the data loader instead of 1.0
    rated_dataset_dropout_target: how far that fraction falls by the final epoch
        (it ends at ``1 - rated_dataset_dropout_target``)
    name: label used in log output
    """

    def __init__(self,
                 data_loader: DataLoaderMultiAspect,
                 debug_level=0,
                 conditional_dropout=0.02,
                 crop_jitter=20,
                 seed=555,
                 tokenizer=None,
                 retain_contrast=False,
                 shuffle_tags=False,
                 rated_dataset=False,
                 rated_dataset_dropout_target=0.5,
                 name='train'
                 ):
        # The tokenizer is mandatory (its model_max_length is read just below);
        # fail with a clear message instead of an AttributeError on None.
        if tokenizer is None:
            raise ValueError("EveryDreamBatch requires a tokenizer (e.g. a CLIPTokenizer), got None")

        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        self.unloaded_to_idx = 0
        self.tokenizer = tokenizer
        self.max_token_length = self.tokenizer.model_max_length
        self.retain_contrast = retain_contrast
        self.shuffle_tags = shuffle_tags
        self.seed = seed
        self.rated_dataset = rated_dataset
        self.rated_dataset_dropout_target = rated_dataset_dropout_target
        # First epoch always trains on all images
        self.image_train_items = []
        self.__update_image_train_items(1.0)
        self.name = name

        num_images = len(self.image_train_items)
        logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")

    def shuffle(self, epoch_n: int, max_epochs: int):
        """
        Prepare the dataset for a new epoch.

        Increments ``self.seed`` (used for tag shuffling) and rebuilds
        ``self.image_train_items`` from the data loader. For a rated dataset the
        fraction passed to the data loader is
        ``1 - rated_dataset_dropout_target * epoch_n / max_epochs``, i.e. it falls
        linearly from 1.0 at epoch 0 to ``1 - rated_dataset_dropout_target`` at
        ``epoch_n == max_epochs``. Otherwise (or when ``max_epochs <= 0``, which
        would divide by zero) 1.0 is used, the same value ``__init__`` uses to
        train on all images.
        """
        self.seed += 1

        if self.rated_dataset and max_epochs > 0:
            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
        else:
            dropout_fraction = 1.0

        self.__update_image_train_items(dropout_fraction)

    def __len__(self):
        """Number of items in the current epoch's item list."""
        return len(self.image_train_items)

    def __getitem__(self, i):
        """
        Build the training example for item ``i`` of the current epoch.

        Returns a dict with ``"caption"``, ``"image"``, ``"tokens"`` and ``"runt_size"``.
        """
        example = {}

        train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)

        if self.retain_contrast:
            # Normalize(mean=0, std=1) is an identity: keep ToTensor's value range.
            std_dev = 1.0
            mean = 0.0
        else:
            # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range (for PIL / uint8 input) to [-1, 1].
            std_dev = 0.5
            mean = 0.5

        image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([mean], [std_dev]),
            ]
        )

        if self.shuffle_tags:
            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
        else:
            example["caption"] = train_item["caption"].get_caption()

        example["image"] = image_transforms(train_item["image"])

        # Conditional dropout: with probability `conditional_dropout` the model is
        # trained on a blank prompt. The real caption text is still returned in
        # example["caption"]; only the tokens change.
        if random.random() > self.conditional_dropout:
            text_to_encode = example["caption"]
        else:
            text_to_encode = " "

        token_ids = self.tokenizer(text_to_encode,
                                   truncation=True,
                                   padding="max_length",
                                   max_length=self.max_token_length,
                                   ).input_ids
        example["tokens"] = torch.tensor(token_ids)

        example["runt_size"] = train_item["runt_size"]

        return example

    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
        """
        Hydrate one image train item and copy out what ``__getitem__`` needs.

        Returns a dict with ``"image"`` (a copy of the hydrated image), ``"caption"``
        (the caption object) and ``"runt_size"``. A ``debug_level`` above 2 passes
        ``save=True`` to ``hydrate``.
        """
        example = {}

        save = debug_level > 2

        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image.copy()  # hack for now to avoid memory leak
        image_train_tmp.image = None  # hack for now to avoid memory leak
        example["caption"] = image_train_tmp.caption
        example["runt_size"] = image_train_tmp.runt_size

        return example

    def __update_image_train_items(self, dropout_fraction: float):
        """Replace the item list with the data loader's shuffled buckets for ``dropout_fraction``."""
        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
|
|
|
|
2023-02-07 09:32:54 -07:00
|
|
|
def build_torch_dataloader(dataset, batch_size, num_workers=4) -> torch.utils.data.DataLoader:
    """
    Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` that batches with ``collate_fn``.

    Args:
        dataset: map-style dataset yielding example dicts (presumably an
            ``EveryDreamBatch``; confirm against callers).
        batch_size: number of examples per batch.
        num_workers: worker processes used for loading (default 4, the previous
            hard-coded value). Use 0 to load in the main process, e.g. for
            debugging or memory-constrained machines.

    Returns:
        The configured ``DataLoader``.

    Shuffling is disabled here on purpose; ordering is left to the dataset
    (``EveryDreamBatch.shuffle`` rebuilds its item list each epoch).
    """
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return dataloader
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """
    Collate a list of example dicts into a single batch dict.

    Each example must provide ``"image"`` (tensor), ``"caption"``, ``"tokens"``
    (tensor) and ``"runt_size"``. The result holds:

    - ``"tokens"``: the token tensors stacked along a new leading dimension;
    - ``"image"``: the image tensors stacked, made contiguous and cast to float32;
    - ``"captions"``: the list of captions, in batch order;
    - ``"runt_size"``: taken from the first example only.

    An empty ``batch`` raises ``IndexError``.
    """
    image_list, caption_list, token_list = [], [], []
    for sample in batch:
        image_list.append(sample["image"])
        caption_list.append(sample["caption"])
        token_list.append(sample["tokens"])

    # Read before any stacking so an empty batch fails with IndexError here.
    first_runt_size = batch[0]["runt_size"]

    stacked_images = torch.stack(image_list).to(memory_format=torch.contiguous_format).float()

    return {
        "tokens": torch.stack(token_list),
        "image": stacked_images,
        "captions": caption_list,
        "runt_size": first_runt_size,
    }
|