Merge pull request #32 from Maw-Fox/main

Add resize argument, various fixes
Anthony Mercurio authored 2022-11-10 09:22:48 -07:00, committed by GitHub
commit 548ebea881
1 changed file with 23 additions and 13 deletions


@@ -37,7 +37,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMSc
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from PIL import Image
+from PIL import Image, ImageOps
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
@@ -81,6 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
+parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention')
 args = parser.parse_args()
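
Note on the new flag: argparse's type=bool converts the raw string, so any non-empty value counts as true. A small standalone sketch of that behaviour (the flag definition is copied from the hunk above; the example command-line values are illustrative only):

    import argparse

    parser = argparse.ArgumentParser()
    # Same definition as the new argument added above.
    parser.add_argument('--resize', type=bool, default=False,
                        help="Resizes dataset's images to the appropriate bucket dimensions.")

    print(parser.parse_args([]).resize)                      # False (default, flag omitted)
    print(parser.parse_args(['--resize', '1']).resize)       # True
    print(parser.parse_args(['--resize', 'False']).resize)   # True, because bool('False') is True

In practice this means the flag is effectively "present or absent": omit it to keep the default of False.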
@@ -163,16 +164,16 @@ class ImageStore:
     # iterator returns images as PIL images and their index in the store
     def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]), f
+            yield Image.open(self.image_files[f]).convert(mode='RGB'), f

     # get image by index
-    def get_image(self, index: int) -> Image.Image:
-        return Image.open(self.image_files[index])
+    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
+        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')

     # gets caption by removing the extension from the filename and replacing it with .txt
-    def get_caption(self, index: int) -> str:
-        filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
-        with open(filename, 'r') as f:
+    def get_caption(self, ref: Tuple[int, int, int]) -> str:
+        filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
+        with open(filename, 'r', encoding='UTF-8') as f:
             return f.read()
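
For reference, a minimal self-contained illustration of the caption lookup rule used by get_caption above (the dataset path is hypothetical): the regex strips the image extension, and the caption is expected alongside the image as a .txt file, now read as UTF-8.

    import re

    image_path = '/data/train/000123.webp'    # hypothetical dataset entry
    caption_path = re.sub(r'\.[^/.]+$', '', image_path) + '.txt'
    print(caption_path)                        # /data/train/000123.txt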
@@ -270,12 +271,12 @@ class AspectBucket:
     def get_bucket_info(self):
         return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })

-    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
+    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
         """
         Generator that provides batches where the images in a batch fall on the same bucket
         Each element generated will be:
-            ((w, h), [image1, image2, ..., image{batch_size}])
+            (index, w, h)

         where each image is an index into the dataset

         :return:
@@ -319,7 +320,7 @@ class AspectBucket:
                 total_generated_by_bucket[b] += self.batch_size
                 bucket_pos[b] = i
-                yield [idx for idx in batch]
+                yield [(idx, *b) for idx in batch]

     def fill_buckets(self):
         entries = self.store.entries_iterator()
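
The change above threads the bucket resolution into every yielded entry. A minimal sketch of the unpacking, with assumed values rather than anything taken from the script: each bucket b is a (w, h) tuple, so (idx, *b) expands to (index, width, height).

    batch = [12, 45, 7]          # hypothetical image indices in one batch
    b = (512, 768)               # hypothetical bucket resolution (w, h)
    entries = [(idx, *b) for idx in batch]
    print(entries)               # [(12, 512, 768), (45, 512, 768), (7, 512, 768)]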
@@ -384,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset):
         self.transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize([0.5], [0.5]),
+            torchvision.transforms.Normalize([0.5], [0.5])
         ])

     def __len__(self):
         return len(self.store)

-    def __getitem__(self, item: int):
+    def __getitem__(self, item: Tuple[int, int, int]):
         return_dict = {'pixel_values': None, 'input_ids': None}
         image_file = self.store.get_image(item)
+
+        if args.resize:
+            image_file = ImageOps.fit(
+                image_file,
+                (item[1], item[2]),
+                bleed=0.0,
+                centering=(0.5, 0.5),
+                method=Image.Resampling.LANCZOS
+            )
+
         return_dict['pixel_values'] = self.transforms(image_file)

         if random.random() > self.ucg:
             caption_file = self.store.get_caption(item)
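
A standalone sketch of the new resize branch, using a hypothetical input file and bucket size: ImageOps.fit scales and centre-crops the image to exactly the requested dimensions, which is what the --resize path above does per sample.

    from PIL import Image, ImageOps

    image = Image.open('example.png').convert(mode='RGB')    # hypothetical input image
    fitted = ImageOps.fit(
        image,
        (512, 768),                       # hypothetical bucket (w, h)
        bleed=0.0,
        centering=(0.5, 0.5),
        method=Image.Resampling.LANCZOS,  # Image.Resampling requires Pillow >= 9.1
    )
    print(fitted.size)                    # (512, 768)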
@@ -580,7 +591,6 @@ def main():
     )

     # load dataset
     store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
     bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
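
If the resolution * resolution argument passed to AspectBucket above is read as a maximum pixel area per bucket (an assumption, suggested by the squared value), the budget works out as follows for an illustrative 512 base resolution:

    resolution = 512                  # hypothetical --resolution value
    max_area = resolution * resolution
    print(max_area)                   # 262144
    print(512 * 512 <= max_area)      # True  -> the square bucket fits
    print(576 * 448 <= max_area)      # True  -> 258048 pixels, within budget
    print(640 * 448 <= max_area)      # False -> 286720 pixels, over budget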