Merge pull request #32 from Maw-Fox/main

Add resize argument, various fixes

commit 548ebea881
@@ -37,7 +37,7 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMSc
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-from PIL import Image
+from PIL import Image, ImageOps
 
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d
@@ -81,6 +81,7 @@ parser.add_argument('--image_log_inference_steps', type=int, default=50, help='N
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
+parser.add_argument('--resize', type=bool, default=False, help="Resizes dataset's images to the appropriate bucket dimensions.")
 parser.add_argument('--use_xformers', type=bool, default=False, help='Use memory efficient attention')
 args = parser.parse_args()
 
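One caveat for anyone reusing these flags: argparse's type=bool applies bool() to the raw string, so any non-empty value (including 'False' and '0') turns --resize on; only omitting the flag keeps the default. A minimal sketch of the usual workaround, with a hypothetical str2bool helper that is not part of this PR:

import argparse

def str2bool(v: str) -> bool:
    # Map common true/false spellings explicitly; plain bool(v) would
    # treat any non-empty string (even 'False') as True.
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--resize', type=str2bool, default=False,
                    help="Resizes dataset's images to the appropriate bucket dimensions.")
print(parser.parse_args(['--resize', 'false']).resize)  # False, as intended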
@@ -163,16 +164,16 @@ class ImageStore:
     # iterator returns images as PIL images and their index in the store
     def entries_iterator(self) -> Generator[Tuple[Image.Image, int], None, None]:
         for f in range(len(self)):
-            yield Image.open(self.image_files[f]), f
+            yield Image.open(self.image_files[f]).convert(mode='RGB'), f
 
     # get image by index
-    def get_image(self, index: int) -> Image.Image:
-        return Image.open(self.image_files[index])
+    def get_image(self, ref: Tuple[int, int, int]) -> Image.Image:
+        return Image.open(self.image_files[ref[0]]).convert(mode='RGB')
 
     # gets caption by removing the extension from the filename and replacing it with .txt
-    def get_caption(self, index: int) -> str:
-        filename = re.sub('\.[^/.]+$', '', self.image_files[index]) + '.txt'
-        with open(filename, 'r') as f:
+    def get_caption(self, ref: Tuple[int, int, int]) -> str:
+        filename = re.sub('\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
+        with open(filename, 'r', encoding='UTF-8') as f:
             return f.read()
 
 
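This signature change threads bucket dimensions through to the store: callers now pass the (index, w, h) tuple produced by AspectBucket.get_batch_iterator rather than a bare index, and only ref[0] is used for the file lookup. A minimal usage sketch, assuming store is the script's ImageStore and the 512x768 bucket is illustrative:

ref = (0, 512, 768)               # (dataset index, bucket width, bucket height)
image = store.get_image(ref)      # opens image_files[0], converted to RGB
caption = store.get_caption(ref)  # reads the matching .txt file as UTF-8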
@@ -270,12 +271,12 @@ class AspectBucket:
     def get_bucket_info(self):
         return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })
 
-    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int], List[int]], None, None]:
+    def get_batch_iterator(self) -> Generator[Tuple[Tuple[int, int, int]], None, None]:
         """
         Generator that provides batches where the images in a batch fall on the same bucket
 
         Each element generated will be:
-            ((w, h), [image1, image2, ..., image{batch_size}])
+            (index, w, h)
 
         where each image is an index into the dataset
         :return:
@@ -319,7 +320,7 @@ class AspectBucket:
 
             total_generated_by_bucket[b] += self.batch_size
             bucket_pos[b] = i
-            yield [idx for idx in batch]
+            yield [(idx, *b) for idx in batch]
 
     def fill_buckets(self):
         entries = self.store.entries_iterator()
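In this comprehension b is the bucket's (w, h) key, so splatting it after each image index yields exactly the (index, w, h) elements documented above (strictly, one list of them per batch). A self-contained illustration with made-up values:

batch, b = [3, 7], (640, 448)        # two image indices in the 640x448 bucket
refs = [(idx, *b) for idx in batch]  # same comprehension as the diff
assert refs == [(3, 640, 448), (7, 640, 448)]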
@@ -384,16 +385,26 @@ class AspectDataset(torch.utils.data.Dataset):
         self.transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize([0.5], [0.5]),
+            torchvision.transforms.Normalize([0.5], [0.5])
         ])
 
     def __len__(self):
         return len(self.store)
 
-    def __getitem__(self, item: int):
+    def __getitem__(self, item: Tuple[int, int, int]):
         return_dict = {'pixel_values': None, 'input_ids': None}
 
         image_file = self.store.get_image(item)
+
+        if args.resize:
+            image_file = ImageOps.fit(
+                image_file,
+                (item[1], item[2]),
+                bleed=0.0,
+                centering=(0.5, 0.5),
+                method=Image.Resampling.LANCZOS
+            )
+
         return_dict['pixel_values'] = self.transforms(image_file)
         if random.random() > self.ucg:
             caption_file = self.store.get_caption(item)
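ImageOps.fit center-crops to the target aspect ratio around the given centering point and then resizes, so each image lands exactly on its bucket's (w, h) with no letterboxing; Image.Resampling.LANCZOS requires Pillow 9.1 or newer. A standalone sketch of the same call, with the file name and 512x768 bucket as illustrative stand-ins:

from PIL import Image, ImageOps

img = Image.open('sample.jpg').convert(mode='RGB')   # any RGB image
fitted = ImageOps.fit(
    img,
    (512, 768),                       # target bucket dimensions (w, h)
    bleed=0.0,                        # trim no border before cropping
    centering=(0.5, 0.5),             # crop around the image center
    method=Image.Resampling.LANCZOS,  # high-quality resampling filter
)
assert fitted.size == (512, 768)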
@@ -580,7 +591,6 @@ def main():
     )
 
     # load dataset
-
     store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
     bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)