EveryDream2trainer/data/every_dream.py

155 lines
6.5 KiB
Python
Raw Normal View History

2022-12-17 20:32:48 -07:00
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch.utils.data import Dataset
from data.data_loader import DataLoaderMultiAspect as dlma
import math
import data.dl_singleton as dls
from data.image_train_item import ImageTrainItem
import random
from torchvision import transforms
from transformers import CLIPTokenizer
import torch.nn.functional as F
class EveryDreamBatch(Dataset):
"""
data_root: root path of all your training images, will be recursively searched for images
repeats: how many times to repeat each image in the dataset
flip_p: probability of flipping the image horizontally
debug_level: 0=none, 1=print drops due to unfilled batches on aspect ratio buckets, 2=debug info per image, 3=save crops to disk for inspection
batch_size: how many images to return in a batch
conditional_dropout: probability of dropping the caption for a given image
resolution: max resolution (relative to square)
jitter: number of pixels to jitter the crop by, only for non-square images
"""
def __init__(self,
data_root,
flip_p=0.0,
debug_level=0,
batch_size=1,
conditional_dropout=0.02,
resolution=512,
crop_jitter=20,
seed=555,
tokenizer=None,
):
self.data_root = data_root
self.batch_size = batch_size
self.debug_level = debug_level
self.conditional_dropout = conditional_dropout
self.crop_jitter = crop_jitter
self.unloaded_to_idx = 0
self.tokenizer = tokenizer
#print(f"tokenizer: {tokenizer}")
self.max_token_length = self.tokenizer.model_max_length
if seed == -1:
seed = random.randint(0, 99999)
if not dls.shared_dataloader:
print(" * Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root, seed=seed, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p, resolution=resolution)
self.image_train_items = dls.shared_dataloader.get_all_images()
# for iti in self.image_train_items:
# print(f"iti caption:{iti.caption}")
# exit()
self.num_images = len(self.image_train_items)
self._length = self.num_images
self.image_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
print()
print(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
print()
def __len__(self):
return self._length
def __getitem__(self, i):
#print(" * Getting item", i)
# batch = dict()
# batch["images"] = list()
# batch["captions"] = list()
# first = True
# for j in range(i, i + self.batch_size - 1):
# if j < self.num_images:
# example = self.__get_image_for_trainer(self.image_train_items[j], self.debug_level)
# if first:
# print(f"first example {j}", example)
# batch["images"] = [torch.from_numpy(example["image"])]
# batch["captions"] = [example["caption"]]
# first = False
# else:
# print(f"subsiquent example {j}", example)
# batch["images"].extend(torch.from_numpy(example["image"]))
# batch["captions"].extend(example["caption"])
example = {}
train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
#example["image"] = torch.from_numpy(train_item["image"])
example["image"] = self.image_transforms(train_item["image"])
# if train_item["caption"] == " ":
# example["tokens"] = [0 for i in range(self.max_token_length-2)]
# else:
if random.random() > self.conditional_dropout:
example["tokens"] = self.tokenizer(train_item["caption"],
#padding="max_length",
truncation=True,
padding=False,
add_special_tokens=False,
max_length=self.max_token_length-2,
).input_ids
example["tokens"] = torch.tensor(example["tokens"])
else:
example["tokens"] = torch.zeros(75, dtype=torch.int)
#print(f"bos: {self.tokenizer.bos_token_id}{self.tokenizer.eos_token_id}")
#print(f"example['tokens']: {example['tokens']}")
pad_amt = self.max_token_length-2 - len(example["tokens"])
example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0)
example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id))
eos_int = int(self.tokenizer.eos_token_id)
#eos_int = int(0)
example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int)
#print(f"__getitem__ train_item['caption']: {train_item['caption']}")
#print(f"__getitem__ train_item['pathname']: {train_item['pathname']}")
#print(f"__getitem__ example['tokens'] pad: {example['tokens']}")
example["caption"] = train_item["caption"] # for sampling if needed
#print(f"len tokens: {len(example['tokens'])} cap: {example['caption']}")
return example
def __get_image_for_trainer(self, image_train_item: ImageTrainItem, debug_level=0):
example = {}
save = debug_level > 2
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
example["image"] = image_train_tmp.image
example["caption"] = image_train_tmp.caption
2022-12-18 19:16:14 -07:00
2022-12-17 20:32:48 -07:00
return example