Merge pull request #32 from Maw-Fox/main
Add resize argument, various fixes
This commit is contained in:
commit
548ebea881
|
@ -37,7 +37,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMSc
|
|||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.optimization import get_scheduler
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from typing import Dict, List, Generator, Tuple
|
||||
from scipy.interpolate import interp1d
|
||||
|
@ -81,6 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N
|
|||
def _str_to_bool(value):
    """Parse a command-line boolean such as true/false, yes/no, on/off or 1/0.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so ``--resize False`` would
    silently enable the option.

    :param value: raw argument text (a real ``bool`` is passed through unchanged).
    :return: the parsed boolean.
    :raises ValueError: if the text is not a recognised boolean spelling;
        argparse turns this into a normal "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, matching what bool('') returned before.
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')


# Help text previously duplicated the one for --image_log_inference_steps.
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Name of the scheduler to use to log images.')
parser.add_argument('--clip_penultimate', type=_str_to_bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=_str_to_bool, default=False, help='Outputs bucket information and exits')
parser.add_argument('--resize', type=_str_to_bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=_str_to_bool, default=False, help='Use memory efficient attention')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -163,16 +164,16 @@ class ImageStore:
|
|||
# iterator returns images as PIL images and their index in the store
|
||||
def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
|
||||
for f in range(len(self)):
|
||||
yield Image.open(self.image_files[f]), f
|
||||
yield Image.open(self.image_files[f]).convert(mode='RGB'), f
|
||||
|
||||
# get image by index
|
||||
def get_image(self, index: int) -> Image.Image:
|
||||
return Image.open(self.image_files[index])
|
||||
def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
|
||||
return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
|
||||
|
||||
# gets caption by removing the extension from the filename and replacing it with .txt
|
||||
def get_caption(self, index: int) -> str:
|
||||
filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
|
||||
with open(filename, 'r') as f:
|
||||
def get_caption(self, ref: Tuple[int, int, int]) -> str:
|
||||
filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
|
||||
with open(filename, 'r', encoding='UTF-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
|
@ -270,12 +271,12 @@ class AspectBucket:
|
|||
def get_bucket_info(self):
    """Return the bucket layout as a JSON string.

    The JSON object has two keys: ``buckets`` (from ``self.buckets``) and
    ``bucket_ratios`` (from ``self._bucket_ratios``).
    """
    info = {"buckets": self.buckets, "bucket_ratios": self._bucket_ratios}
    return json.dumps(info)
|
||||
|
||||
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
|
||||
def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
|
||||
"""
|
||||
Generator that provides batches where the images in a batch fall on the same bucket
|
||||
|
||||
Each element generated will be:
|
||||
((w, h), [image1, image2, ..., image{batch_size}])
|
||||
[(index, w, h), ...]  (a list with one tuple per image in the batch)
|
||||
|
||||
where each index is an image's index into the dataset
|
||||
:return:
|
||||
|
@ -319,7 +320,7 @@ class AspectBucket:
|
|||
|
||||
total_generated_by_bucket[b] += self.batch_size
|
||||
bucket_pos[b] = i
|
||||
yield [idx for idx in batch]
|
||||
yield [(idx, *b) for idx in batch]
|
||||
|
||||
def fill_buckets(self):
|
||||
entries = self.store.entries_iterator()
|
||||
|
@ -384,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset):
|
|||
self.transforms = torchvision.transforms.Compose([
|
||||
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize([0.5], [0.5]),
|
||||
torchvision.transforms.Normalize([0.5], [0.5])
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.store)
|
||||
|
||||
def __getitem__(self, item: int):
|
||||
def __getitem__(self, item: Tuple[int, int, int]):
|
||||
return_dict = {'pixel_values': None, 'input_ids': None}
|
||||
|
||||
image_file = self.store.get_image(item)
|
||||
|
||||
if args.resize:
|
||||
image_file = ImageOps.fit(
|
||||
image_file,
|
||||
(item[1], item[2]),
|
||||
bleed=0.0,
|
||||
centering=(0.5, 0.5),
|
||||
method=Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
return_dict['pixel_values'] = self.transforms(image_file)
|
||||
if random.random() > self.ucg:
|
||||
caption_file = self.store.get_caption(item)
|
||||
|
@ -580,7 +591,6 @@ def main():
|
|||
)
|
||||
|
||||
# load dataset
|
||||
|
||||
store = ImageStore(args.dataset)
|
||||
dataset = AspectDataset(store, tokenizer)
|
||||
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
|
||||
|
|
Loading…
Reference in New Issue