add embedding load and save from b64 json
This commit is contained in:
parent
fa0c5eb81b
commit
03694e1f99
|
@ -7,9 +7,11 @@ import tqdm
|
|||
import html
|
||||
import datetime
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL import Image,PngImagePlugin
|
||||
from ..images import captionImge
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
|
@ -87,9 +89,9 @@ class EmbeddingDatabase:
|
|||
|
||||
if filename.upper().endswith('.PNG'):
|
||||
embed_image = Image.open(path)
|
||||
if 'sd-embedding' in embed_image.text:
|
||||
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
|
||||
data = torch.load(BytesIO(embeddingData), map_location="cpu")
|
||||
if 'sd-ti-embedding' in embed_image.text:
|
||||
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name',name)
|
||||
else:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
|
@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||
|
||||
if save_image_with_stored_embedding:
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
|
||||
image.save(last_saved_image, "PNG", pnginfo=info)
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embeddingToB64(data))
|
||||
|
||||
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
|
||||
|
||||
caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
|
||||
caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
|
||||
caption_stepcount = data.get('step',0)
|
||||
caption_stepcount = caption_stepcount if caption_stepcount else 0
|
||||
|
||||
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
|
||||
caption_stepcount))]
|
||||
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
|
||||
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
|
||||
else:
|
||||
image.save(last_saved_image)
|
||||
|
||||
|
||||
|
||||
last_saved_image += f", prompt: {text}"
|
||||
|
||||
shared.state.job_no = embedding.step
|
||||
|
|
Loading…
Reference in New Issue