Merge pull request #6620 from guaneec/varsize_batch
Enable batch_size>1 for mixed-sized training
This commit is contained in:
commit
486bda9b33
|
@ -3,8 +3,10 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
from collections import defaultdict
|
||||||
|
from random import shuffle, choices
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -45,12 +47,12 @@ class PersonalizedBase(Dataset):
|
||||||
assert data_root, 'dataset directory not specified'
|
assert data_root, 'dataset directory not specified'
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
|
|
||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
self.tag_drop_out = tag_drop_out
|
self.tag_drop_out = tag_drop_out
|
||||||
|
groups = defaultdict(list)
|
||||||
|
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
@ -103,13 +105,14 @@ class PersonalizedBase(Dataset):
|
||||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
groups[image.size].append(len(self.dataset))
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
del torchdata
|
del torchdata
|
||||||
del latent_dist
|
del latent_dist
|
||||||
del latent_sample
|
del latent_sample
|
||||||
|
|
||||||
self.length = len(self.dataset)
|
self.length = len(self.dataset)
|
||||||
|
self.groups = list(groups.values())
|
||||||
assert self.length > 0, "No images have been found in the dataset."
|
assert self.length > 0, "No images have been found in the dataset."
|
||||||
self.batch_size = min(batch_size, self.length)
|
self.batch_size = min(batch_size, self.length)
|
||||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||||
|
@ -137,9 +140,34 @@ class PersonalizedBase(Dataset):
|
||||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
class GroupedBatchSampler(Sampler):
|
||||||
|
def __init__(self, data_source: PersonalizedBase, batch_size: int):
|
||||||
|
n = len(data_source)
|
||||||
|
self.groups = data_source.groups
|
||||||
|
self.len = n_batch = n // batch_size
|
||||||
|
expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
|
||||||
|
self.base = [int(e) // batch_size for e in expected]
|
||||||
|
self.n_rand_batches = nrb = n_batch - sum(self.base)
|
||||||
|
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
|
||||||
|
self.batch_size = batch_size
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
def __iter__(self):
|
||||||
|
b = self.batch_size
|
||||||
|
for g in self.groups:
|
||||||
|
shuffle(g)
|
||||||
|
batches = []
|
||||||
|
for g in self.groups:
|
||||||
|
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
|
||||||
|
for _ in range(self.n_rand_batches):
|
||||||
|
rand_group = choices(self.groups, self.probs)[0]
|
||||||
|
batches.append(choices(rand_group, k=b))
|
||||||
|
shuffle(batches)
|
||||||
|
yield from batches
|
||||||
|
|
||||||
class PersonalizedDataLoader(DataLoader):
|
class PersonalizedDataLoader(DataLoader):
|
||||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||||
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
|
||||||
if latent_sampling_method == "random":
|
if latent_sampling_method == "random":
|
||||||
self.collate_fn = collate_wrapper_random
|
self.collate_fn = collate_wrapper_random
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue