Implement PR #3189 but for embeddings.
This commit is contained in:
parent
a524d137d0
commit
c2dc9bfa89
|
@ -10,7 +10,7 @@ import csv
|
|||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
ititial_step = embedding.step or 0
|
||||
|
@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
})
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
forced_filename = f'{embedding_name}-{embedding.step}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
|
@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
embedding_yet_to_be_embedded = False
|
||||
|
||||
image.save(last_saved_image)
|
||||
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
|
Loading…
Reference in New Issue