2022-09-06 01:00:21 -06:00
import os
import numpy as np
import PIL
from PIL import Image
from torch . utils . data import Dataset
from torchvision import transforms
import random
# Caption templates for the subject (training) images; the "{}" slot is
# filled with the placeholder string built in PersonalizedBase.__getitem__.
training_templates_smallest = ["joepenna {}"]
# Caption templates for regularization images (selected when reg=True):
# the caption is just the placeholder string itself.
reg_templates_smallest = ["{}"]
# Single-slot template list; not used by PersonalizedBase.
imagenet_templates_small = ["{}"]
# Two-slot template list; not used by PersonalizedBase.
imagenet_dual_templates_small = ["{} with {}"]
# Candidate per-image placeholder tokens: the 22 Hebrew letters in
# alphabetical order, non-final forms only. PersonalizedBase checks the
# image count against the length of this list when per_image_tokens is set.
per_img_token_list = [
    "א", "ב", "ג", "ד", "ה", "ו", "ז", "ח", "ט", "י", "כ",
    "ל", "מ", "נ", "ס", "ע", "פ", "צ", "ק", "ר", "ש", "ת",
]
class PersonalizedBase(Dataset):
    """Dataset of caption/image pairs for personalization fine-tuning.

    Every regular file directly inside ``data_root`` is treated as one image.
    Each item is a dict with:

    * ``"caption"``: a random template from ``training_templates_smallest``
      (or ``reg_templates_smallest`` when ``reg`` is true) formatted with the
      placeholder token, optionally prefixed by ``coarse_class_text``.
    * ``"image"``: the RGB image as an ``(H, W, 3)`` ``float32`` array mapped
      from ``[0, 255]`` to ``[-1, 1]``.
    """

    def __init__(self,
                 data_root,
                 size=None,
                 repeats=100,
                 interpolation="bicubic",
                 flip_p=0.5,
                 set="train",
                 placeholder_token="dog",
                 per_image_tokens=False,
                 center_crop=False,
                 mixing_prob=0.25,
                 coarse_class_text=None,
                 reg=False
                 ):
        """Index the images under ``data_root`` and store preprocessing options.

        Args:
            data_root: Directory whose regular files are the images. Files are
                taken in sorted order so the index -> image mapping is
                reproducible across runs and machines.
            size: If not None, every image is resized to ``size x size``.
            repeats: For ``set == "train"`` the dataset length is
                ``num_images * repeats``; indices wrap modulo ``num_images``.
            interpolation: Resize filter name: "linear", "bilinear",
                "bicubic" or "lanczos" ("linear" is the same filter as
                "bilinear").
            flip_p: Probability of a random horizontal flip per item.
            set: Split name; only "train" applies ``repeats``. (Shadows the
                builtin; kept for keyword-argument compatibility.)
            placeholder_token: Text inserted into the caption template.
            per_image_tokens: Only validated against ``per_img_token_list``
                here; not otherwise used by this class.
            center_crop: Crop to the largest centered square before resizing.
            mixing_prob: Stored on the instance; not used by this class.
            coarse_class_text: Optional text prepended to the placeholder.
            reg: Use the regularization caption templates.

        Raises:
            AssertionError: ``per_image_tokens`` is set and there are more
                images than entries in ``per_img_token_list`` (asserts are
                skipped under ``python -O``).
            KeyError: ``interpolation`` is not one of the names above.
        """
        self.data_root = data_root

        # os.listdir order is arbitrary, so sort for a deterministic
        # index -> file mapping; skip sub-directories, which Image.open
        # cannot read and which would otherwise fail only at load time.
        candidate_paths = (os.path.join(self.data_root, file_name)
                           for file_name in sorted(os.listdir(self.data_root)))
        self.image_paths = [path for path in candidate_paths
                            if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token
        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob
        self.coarse_class_text = coarse_class_text

        if per_image_tokens:
            # One token per image, so exactly len(per_img_token_list) images
            # must still be accepted (the previous `<` rejected that case even
            # though the message only forbids "more than").
            assert self.num_images <= len(per_img_token_list), (
                f"Can't use per-image tokens when the training set contains "
                f"more than {len(per_img_token_list)} tokens. To enable larger "
                f"sets, add more tokens to 'per_img_token_list'.")

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was only an alias of BILINEAR and was removed in
        # Pillow 10, so "linear" maps to BILINEAR directly.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.reg = reg

    def __len__(self):
        """Return ``num_images * repeats`` for the "train" split, else ``num_images``."""
        return self._length

    def __getitem__(self, i):
        """Build the ``{"caption", "image"}`` example for index ``i``.

        ``i`` is wrapped modulo ``num_images``, which is how ``repeats``
        cycles over the same files.
        """
        example = {}

        image_path = self.image_paths[i % self.num_images]
        # The context manager releases the file handle even if decoding or
        # conversion raises; the pixel data is copied into `img` inside it.
        with Image.open(image_path) as raw_image:
            if raw_image.mode == "RGB":
                image = raw_image
            else:
                image = raw_image.convert("RGB")
            # default to score-sde preprocessing
            img = np.array(image).astype(np.uint8)

        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
        templates = reg_templates_smallest if self.reg else training_templates_smallest
        example["caption"] = random.choice(templates).format(placeholder_string)

        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            # (h - crop) and (h + crop) share parity, so each slice keeps
            # exactly `crop` rows/columns.
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size),
                                 resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example