Merge pull request #141 from tjennings/main

image preloading and data loading optimizations.
This commit is contained in:
Victor Hall 2023-04-14 21:25:32 -04:00 committed by GitHub
commit 7ba9d43db5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 14 deletions

View File

@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os
import torch
from torch.utils.data import Dataset
from data.data_loader import DataLoaderMultiAspect
@ -144,9 +146,9 @@ class EveryDreamBatch(Dataset):
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
batch_size= batch_size,
shuffle=False,
num_workers=4,
num_workers=min(batch_size, os.cpu_count()),
collate_fn=collate_fn
)
return dataloader

View File

@ -150,8 +150,8 @@ class ImageTrainItem:
self.__compute_target_width_height()
def load_image(self):
image = PIL.Image.open(self.pathname).convert('RGB')
try:
image = PIL.Image.open(self.pathname).convert('RGB')
image = ImageOps.exif_transpose(image)
except SyntaxError as e:
pass
@ -236,7 +236,7 @@ class ImageTrainItem:
def __compute_target_width_height(self):
self.target_wh = None
try:
with self.load_image() as image:
with PIL.Image.open(self.pathname) as image:
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))