Merge pull request #141 from tjennings/main
image preloading and data loading optimizations.
This commit is contained in:
commit
7ba9d43db5
|
@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
@ -61,18 +63,18 @@ class EveryDreamBatch(Dataset):
|
|||
self.image_train_items = []
|
||||
self.__update_image_train_items(1.0)
|
||||
self.name = name
|
||||
|
||||
|
||||
num_images = len(self.image_train_items)
|
||||
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||
self.seed += 1
|
||||
|
||||
|
||||
if self.rated_dataset:
|
||||
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
|
||||
else:
|
||||
dropout_fraction = 1.0
|
||||
|
||||
|
||||
self.__update_image_train_items(dropout_fraction)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -135,18 +137,18 @@ class EveryDreamBatch(Dataset):
|
|||
if image_train_tmp.cond_dropout is not None:
|
||||
example["cond_dropout"] = image_train_tmp.cond_dropout
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
|
||||
|
||||
return example
|
||||
|
||||
def __update_image_train_items(self, dropout_fraction: float):
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
batch_size= batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
num_workers=min(batch_size, os.cpu_count()),
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return dataloader
|
||||
|
|
|
@ -106,7 +106,7 @@ class ImageCaption:
|
|||
|
||||
weights_copy.pop(pos)
|
||||
tag = tags_copy.pop(pos)
|
||||
|
||||
|
||||
if caption:
|
||||
caption += ", "
|
||||
caption += tag
|
||||
|
@ -122,7 +122,7 @@ class ImageTrainItem:
|
|||
"""
|
||||
image: PIL.Image
|
||||
identifier: caption,
|
||||
target_aspect: (width, height),
|
||||
target_aspect: (width, height),
|
||||
pathname: path to image file
|
||||
flip_p: probability of flipping image (0.0 to 1.0)
|
||||
rating: the relative rating of the images. The rating is measured in comparison to the other images.
|
||||
|
@ -144,14 +144,14 @@ class ImageTrainItem:
|
|||
self.image = image
|
||||
self.image_size = image.size
|
||||
self.target_size = None
|
||||
|
||||
|
||||
self.is_undersized = False
|
||||
self.error = None
|
||||
self.__compute_target_width_height()
|
||||
|
||||
def load_image(self):
|
||||
image = PIL.Image.open(self.pathname).convert('RGB')
|
||||
try:
|
||||
image = PIL.Image.open(self.pathname).convert('RGB')
|
||||
image = ImageOps.exif_transpose(image)
|
||||
except SyntaxError as e:
|
||||
pass
|
||||
|
@ -232,15 +232,15 @@ class ImageTrainItem:
|
|||
# print(self.image.shape)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def __compute_target_width_height(self):
|
||||
self.target_wh = None
|
||||
try:
|
||||
with self.load_image() as image:
|
||||
with PIL.Image.open(self.pathname) as image:
|
||||
width, height = image.size
|
||||
image_aspect = width / height
|
||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
||||
|
||||
|
||||
self.is_undersized = (width * height) < (target_wh[0] * target_wh[1])
|
||||
self.target_wh = target_wh
|
||||
except Exception as e:
|
||||
|
|
Loading…
Reference in New Issue