Merge pull request #2201 from alg-wiki/textual__inversion

Textual Inversion: Preprocess and Training will only pick-up image files instead
This commit is contained in:
AUTOMATIC1111 2022-10-11 17:25:36 +03:00 committed by GitHub
commit 4f96ffd0b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 10 deletions

View File

@ -35,9 +35,10 @@ class PersonalizedBase(Dataset):
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
image = Image.open(path) try:
image = image.convert('RGB') image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
image = image.resize((self.width, self.height), PIL.Image.BICUBIC) except Exception:
continue
filename = os.path.basename(path) filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0] filename_tokens = os.path.splitext(filename)[0]

View File

@ -46,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
for index, imagefile in enumerate(tqdm.tqdm(files)): for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0] subindex = [0]
filename = os.path.join(src, imagefile) filename = os.path.join(src, imagefile)
img = Image.open(filename).convert("RGB") try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
if shared.state.interrupted: if shared.state.interrupted:
break break

View File

@ -200,9 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps: if ititial_step > steps:
return embedding, filename return embedding, filename
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar: for i, (x, text) in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
@ -226,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
loss.backward() loss.backward()
optimizer.step() optimizer.step()
epoch_num = embedding.step // epoch_len epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * epoch_len) + 1 epoch_step = embedding.step - (epoch_num * len(ds)) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')