print bucket sizes for training without resizing images #6620

fix an error when generating a picture with embedding in it
This commit is contained in:
AUTOMATIC 2023-01-13 14:32:15 +03:00
parent 486bda9b33
commit a176d89487
3 changed files with 19 additions and 3 deletions

View File

@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
self.gradient_step = min(gradient_step, self.length // self.batch_size) self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method self.latent_sampling_method = latent_sampling_method
if len(groups) > 1:
print("Buckets:")
for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
print(f" {w}x{h}: {len(ids)}")
print()
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
tags = filename_text.split(',') tags = filename_text.split(',')
@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry return entry
class GroupedBatchSampler(Sampler): class GroupedBatchSampler(Sampler):
def __init__(self, data_source: PersonalizedBase, batch_size: int): def __init__(self, data_source: PersonalizedBase, batch_size: int):
super().__init__(data_source)
n = len(data_source) n = len(data_source)
self.groups = data_source.groups self.groups = data_source.groups
self.len = n_batch = n // batch_size self.len = n_batch = n // batch_size
@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
self.n_rand_batches = nrb = n_batch - sum(self.base) self.n_rand_batches = nrb = n_batch - sum(self.base)
self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
self.batch_size = batch_size self.batch_size = batch_size
def __len__(self): def __len__(self):
return self.len return self.len
def __iter__(self): def __iter__(self):
b = self.batch_size b = self.batch_size
for g in self.groups: for g in self.groups:
shuffle(g) shuffle(g)
batches = [] batches = []
for g in self.groups: for g in self.groups:
batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
for _ in range(self.n_rand_batches): for _ in range(self.n_rand_batches):
rand_group = choices(self.groups, self.probs)[0] rand_group = choices(self.groups, self.probs)[0]
batches.append(choices(rand_group, k=b)) batches.append(choices(rand_group, k=b))
shuffle(batches) shuffle(batches)
yield from batches yield from batches
class PersonalizedDataLoader(DataLoader): class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)

View File

@ -76,10 +76,10 @@ def insert_image_data_embed(image, data):
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d))) next_size = next_size + ((h*d)-(next_size % (h*d)))
data_np_low.resize(next_size) data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d)) data_np_low = data_np_low.reshape((h, -1, d))
data_np_high.resize(next_size) data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d)) data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]

View File

@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
epoch_num = embedding.step // steps_per_epoch epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch epoch_step = embedding.step % steps_per_epoch
description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description) pbar.set_description(description)
shared.state.textinfo = description shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0: if embedding_dir is not None and steps_done % save_embedding_every == 0: