EveryDream2trainer/data/every_dream.py

"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os
import random

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from data.data_loader import DataLoaderMultiAspect
from data.image_train_item import ImageTrainItem
from plugins.plugins import PluginRunner


class EveryDreamBatch(Dataset):
    """
    data_loader: `DataLoaderMultiAspect` object
    debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
    conditional_dropout: probability of dropping the caption for a given image
    crop_jitter: percent of maximum cropping for crop jitter (ex 0.02 is two percent)
    seed: random seed
    """

    def __init__(self,
                 data_loader: DataLoaderMultiAspect,
                 debug_level=0,
                 conditional_dropout=0.02,
                 crop_jitter=0.02,
                 seed=555,
                 tokenizer=None,
                 shuffle_tags=False,
                 keep_tags=0,
                 plugin_runner: PluginRunner = None,
                 rated_dataset=False,
                 rated_dataset_dropout_target=0.5,
                 name='train'
                 ):
        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size
        self.debug_level = debug_level
        self.conditional_dropout = conditional_dropout
        self.crop_jitter = crop_jitter
        self.unloaded_to_idx = 0
        self.tokenizer = tokenizer
        self.max_token_length = self.tokenizer.model_max_length
        self.shuffle_tags = shuffle_tags
        self.keep_tags = keep_tags
        self.plugin_runner = plugin_runner
        self.seed = seed
        self.rated_dataset = rated_dataset
        self.rated_dataset_dropout_target = rated_dataset_dropout_target
        # First epoch always trains on all images
        self.image_train_items = []
        self.__update_image_train_items(1.0)
        self.name = name

        num_images = len(self.image_train_items)
        logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")

    def shuffle(self, epoch_n: int, max_epochs: int):
        self.seed += 1

        if self.rated_dataset:
            # Shrink the kept fraction of the rated dataset linearly each
            # epoch, reaching (1 - rated_dataset_dropout_target) at the end.
            dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
        else:
            dropout_fraction = 1.0

        self.__update_image_train_items(dropout_fraction)
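
    # Worked example for the rated-dataset schedule above: with max_epochs=10
    # and rated_dataset_dropout_target=0.5, epoch 1 keeps (10 - 0.5) / 10 = 0.95
    # of the images and epoch 10 keeps (10 - 5.0) / 10 = 0.50, so the kept
    # fraction falls linearly from near 1.0 to 1 - rated_dataset_dropout_target.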

    def __len__(self):
        return len(self.image_train_items)

    def __getitem__(self, i):
        example = {}

        train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)

        std_dev = 0.5
        mean = 0.5
        image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([mean], [std_dev]),
            ]
        )

        if self.shuffle_tags or train_item["shuffle_tags"]:
            example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
        else:
            example["caption"] = train_item["caption"].get_caption()

        example["image"] = self.plugin_runner.run_transform_pil_image(train_item["image"])
        example["image"] = image_transforms(example["image"])
        example["caption"] = self.plugin_runner.run_transform_caption(example["caption"])

        # Conditional dropout: with probability `cond_dropout`, replace the
        # caption with a near-empty string so the model also learns an
        # unconditioned prediction (used for classifier-free guidance).
        if random.random() > (train_item.get("cond_dropout", self.conditional_dropout)):
            example["tokens"] = self.tokenizer(example["caption"],
                                               truncation=True,
                                               padding="max_length",
                                               max_length=self.tokenizer.model_max_length,
                                               ).input_ids
        else:
            example["tokens"] = self.tokenizer(" ",
                                               truncation=True,
                                               padding="max_length",
                                               max_length=self.tokenizer.model_max_length,
                                               ).input_ids

        example["tokens"] = torch.tensor(example["tokens"])
        example["runt_size"] = train_item["runt_size"]
        example["loss_scale"] = train_item["loss_scale"]

        return example

    def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
        example = {}
        save = debug_level > 2

        image_train_tmp = image_train_item.hydrate(save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image.copy()  # hack for now to avoid memory leak
        image_train_tmp.image = None  # hack for now to avoid memory leak
        example["caption"] = image_train_tmp.caption
        if image_train_tmp.cond_dropout is not None:
            example["cond_dropout"] = image_train_tmp.cond_dropout
        example["runt_size"] = image_train_tmp.runt_size
        example["shuffle_tags"] = image_train_tmp.shuffle_tags
        example["loss_scale"] = image_train_tmp.loss_scale

        return example

    def __update_image_train_items(self, dropout_fraction: float):
        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)


class DataLoaderWithFixedBuffer(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size: int, max_pixels: int, buffer_dtype: torch.dtype = torch.float32):
        color_channels = 3
        buffer_size = batch_size * color_channels * max_pixels
        self.buffer_size = buffer_size

        # The buffer must live in page-locked (pinned) host memory; pinning a
        # CUDA tensor is invalid, so allocate on the CPU and pin it there.
        self.buffer_tensor = torch.empty(buffer_size, dtype=buffer_dtype).pin_memory()
        logging.info(f"buffer_tensor created with shape: {self.buffer_tensor.shape}")

        super().__init__(dataset, batch_size=batch_size, shuffle=False, num_workers=min(batch_size, os.cpu_count()), collate_fn=self.fixed_collate_fn)

    def fixed_collate_fn(self, batch):
        """
        Collates images into the pinned buffer and returns them as a view of
        it, shaped to the actual resolution of the batch.
        """
        images = [example["image"] for example in batch]

        # All items in a batch come from the same aspect bucket, so they share
        # one resolution; ToTensor produces C x H x W tensors.
        _, h, w = images[0].shape
        item_numel = 3 * h * w
        for i, image in enumerate(images):
            self.buffer_tensor[i * item_numel:(i + 1) * item_numel] = image.reshape(-1)
        images = self.buffer_tensor[:len(batch) * item_numel].view(len(batch), 3, h, w)
        images = images.to(memory_format=torch.contiguous_format).float()

        captions = [example["caption"] for example in batch]
        tokens = [example["tokens"] for example in batch]
        runt_size = batch[0]["runt_size"]

        ret = {
            "tokens": torch.stack(tuple(tokens)),
            "image": images,
            "captions": captions,
            "runt_size": runt_size,
        }
        del batch
        return ret
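
# Design note: the pinned (page-locked) buffer above exists so the eventual
# host-to-device copy can be asynchronous (e.g. `t.to("cuda", non_blocking=True)`).
# Caveat: with num_workers > 0 the collate function runs in worker processes,
# each holding its own copy of the buffer, so the zero-copy intent only fully
# applies when collation happens in the main process.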


def build_torch_dataloader2(dataset, batch_size, max_pixels) -> torch.utils.data.DataLoader:
    # DataLoaderWithFixedBuffer sets shuffle, num_workers, and collate_fn itself.
    dataloader = DataLoaderWithFixedBuffer(
        dataset,
        max_pixels=max_pixels,
        batch_size=batch_size,
    )
    return dataloader


def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(batch_size, os.cpu_count()),
        collate_fn=collate_fn
    )
    return dataloader


def collate_fn(batch):
    """
    Collates batches
    """
    images = [example["image"] for example in batch]
    captions = [example["caption"] for example in batch]
    tokens = [example["tokens"] for example in batch]
    runt_size = batch[0]["runt_size"]

    images = torch.stack(images)
    images = images.to(memory_format=torch.contiguous_format).float()

    loss_scale = torch.tensor([example.get("loss_scale", 1) for example in batch])

    ret = {
        "tokens": torch.stack(tuple(tokens)),
        "image": images,
        "captions": captions,
        "runt_size": runt_size,
        "loss_scale": loss_scale,
    }
    del batch
    return ret
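
# A minimal end-to-end usage sketch. Assumptions (all hypothetical names, not
# defined in this module): `dlma` is a configured DataLoaderMultiAspect, `tok`
# a CLIPTokenizer, and `plugins` a PluginRunner.
#
#   dataset = EveryDreamBatch(data_loader=dlma, tokenizer=tok, plugin_runner=plugins, seed=42)
#   loader = build_torch_dataloader(dataset, batch_size=dlma.batch_size)
#   for epoch in range(1, max_epochs + 1):
#       dataset.shuffle(epoch_n=epoch, max_epochs=max_epochs)
#       for batch in loader:
#           tokens = batch["tokens"]   # (B, model_max_length), e.g. 77 for CLIP
#           pixels = batch["image"]    # (B, 3, H, W), normalized to [-1, 1]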