modified the dataloader for PIL errors and added progress bar
This commit is contained in:
parent
317af6ace1
commit
3b0a7bb34a
|
@ -1,8 +1,12 @@
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import PIL
|
||||||
import random
|
import random
|
||||||
from ldm.data.image_train_item import ImageTrainItem
|
from ldm.data.image_train_item import ImageTrainItem
|
||||||
import ldm.data.aspects as aspects
|
import ldm.data.aspects as aspects
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
PIL.Image.MAX_IMAGE_PIXELS = 933120000
|
||||||
|
|
||||||
class DataLoaderMultiAspect():
|
class DataLoaderMultiAspect():
|
||||||
"""
|
"""
|
||||||
|
@ -50,7 +54,7 @@ class DataLoaderMultiAspect():
|
||||||
"""
|
"""
|
||||||
decorated_image_train_items = []
|
decorated_image_train_items = []
|
||||||
|
|
||||||
for pathname in image_paths:
|
for pathname in tqdm(image_paths):
|
||||||
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
|
||||||
|
|
||||||
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
txt_file_path = os.path.splitext(pathname)[0] + ".txt"
|
||||||
|
|
Loading…
Reference in New Issue