Merge pull request #32 from Maw-Fox/main

Add resize argument, various fixes
This commit is contained in:
Anthony Mercurio 2022-11-10 09:22:48 -07:00 committed by GitHub
commit 548ebea881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 13 deletions

View File

@ -37,7 +37,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMSc
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image
from PIL import Image, ImageOps
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
@ -81,6 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention')
args = parser.parse_args()
@ -163,16 +164,16 @@ class ImageStore:
# iterator returns images as PIL images and their index in the store
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
for f in range(len(self)):
yield Image.open(self.image_files[f]), f
yield Image.open(self.image_files[f]).convert(mode='RGB'), f
# get image by index
def get_image(self, index: int) -> Image.Image:
return Image.open(self.image_files[index])
def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
# gets caption by removing the extension from the filename and replacing it with .txt
def get_caption(self, index: int) -> str:
filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
with open(filename, 'r') as f:
def get_caption(self, ref: Tuple[int, int, int]) -> str:
filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
with open(filename, 'r', encoding='UTF-8') as f:
return f.read()
@ -270,12 +271,12 @@ class AspectBucket:
def get_bucket_info(self):
return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
"""
Generator that provides batches where the images in a batch fall on the same bucket
Each element generated will be:
((w, h), [image1, image2, ..., image{batch_size}])
(index, w, h)
where each image is an index into the dataset
:return:
@ -319,7 +320,7 @@ class AspectBucket:
total_generated_by_bucket[b] += self.batch_size
bucket_pos[b] = i
yield [idx for idx in batch]
yield [(idx, *b) for idx in batch]
def fill_buckets(self):
entries = self.store.entries_iterator()
@ -384,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset):
self.transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5], [0.5]),
torchvision.transforms.Normalize([0.5], [0.5])
])
def __len__(self):
return len(self.store)
def __getitem__(self, item: int):
def __getitem__(self, item: Tuple[int, int, int]):
return_dict = {'pixel_values': None, 'input_ids': None}
image_file = self.store.get_image(item)
if args.resize:
image_file = ImageOps.fit(
image_file,
(item[1], item[2]),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
)
return_dict['pixel_values'] = self.transforms(image_file)
if random.random() > self.ucg:
caption_file = self.store.get_caption(item)
@ -580,7 +591,6 @@ def main():
)
# load dataset
store = ImageStore(args.dataset)
dataset = AspectDataset(store, tokenizer)
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)