change caption method

This commit is contained in:
DepFA 2022-10-10 00:07:52 +01:00 committed by GitHub
parent 0ac3a07eec
commit d6a599ef9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 9 deletions

View File

@ -8,7 +8,7 @@ import html
import datetime
from PIL import Image,PngImagePlugin
from ..images import captionImge
from ..images import captionImageOverlay
import numpy as np
import base64
import json
@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
else:
images_dir = None
if create_image_every > 0 and save_image_with_stored_embedding:
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.current_image = image
if save_image_with_stored_embedding:
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
title = "<{}>".format(data.get('name','???'))
checkpoint = sd_models.select_checkpoint()
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash,
embedding.step))]
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image)
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '[{}]'.format(embedding.step)
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"