Merge pull request #2201 from alg-wiki/textual__inversion
Textual Inversion: Preprocess and Training will only pick-up image files instead
This commit is contained in:
commit
4f96ffd0b5
|
@ -35,9 +35,10 @@ class PersonalizedBase(Dataset):
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
image = Image.open(path)
|
try:
|
||||||
image = image.convert('RGB')
|
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
filename_tokens = os.path.splitext(filename)[0]
|
filename_tokens = os.path.splitext(filename)[0]
|
||||||
|
|
|
@ -46,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||||
subindex = [0]
|
subindex = [0]
|
||||||
filename = os.path.join(src, imagefile)
|
filename = os.path.join(src, imagefile)
|
||||||
img = Image.open(filename).convert("RGB")
|
try:
|
||||||
|
img = Image.open(filename).convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
|
@ -200,9 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
|
||||||
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
for i, (x, text) in pbar:
|
for i, (x, text) in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
|
@ -226,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
epoch_num = embedding.step // epoch_len
|
epoch_num = embedding.step // len(ds)
|
||||||
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
|
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||||
|
|
||||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
|
|
Loading…
Reference in New Issue