Merge pull request #15319 from catboxanon/feat/ssmd_cover_images

Support cover images embedded in safetensors metadata
This commit is contained in:
AUTOMATIC1111 2024-03-24 13:43:37 +03:00 committed by GitHub
commit b0b90dc0d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 43 additions and 2 deletions

View File

@ -29,7 +29,6 @@ class NetworkOnDisk:
def read_metadata(): def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename) metadata = sd_models.read_metadata_from_safetensors(filename)
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
return metadata return metadata

View File

@ -31,7 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"name": name, "name": name,
"filename": lora_on_disk.filename, "filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash, "shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path), "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
"description": self.find_description(path), "description": self.find_description(path),
"search_terms": search_terms, "search_terms": search_terms,
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",

View File

@ -1,6 +1,8 @@
import functools import functools
import os.path import os.path
import urllib.parse import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
from dataclasses import dataclass from dataclasses import dataclass
@ -11,6 +13,7 @@ import gradio as gr
import json import json
import html import html
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from PIL import Image
from modules.infotext_utils import image_from_url_text from modules.infotext_utils import image_from_url_text
@ -108,6 +111,31 @@ def fetch_file(filename: str = ""):
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
def fetch_cover_images(page: str = "", item: str = "", index: int = 0):
from starlette.responses import Response
page = next(iter([x for x in extra_pages if x.name == page]), None)
if page is None:
raise HTTPException(status_code=404, detail="File not found")
metadata = page.metadata.get(item)
if metadata is None:
raise HTTPException(status_code=404, detail="File not found")
cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
image = cover_images[index] if index < len(cover_images) else None
if not image:
raise HTTPException(status_code=404, detail="File not found")
try:
image = Image.open(BytesIO(b64decode(image)))
buffer = BytesIO()
image.save(buffer, format=image.format)
return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
except Exception as err:
raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
def get_metadata(page: str = "", item: str = ""): def get_metadata(page: str = "", item: str = ""):
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
@ -119,6 +147,8 @@ def get_metadata(page: str = "", item: str = ""):
if metadata is None: if metadata is None:
return JSONResponse({}) return JSONResponse({})
metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'} # those are cover images, and they are too big to display in UI as text
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
@ -142,6 +172,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
def add_pages_to_demo(app): def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
@ -626,6 +657,17 @@ class ExtraNetworksPage:
return None return None
def find_embedded_preview(self, path, name, metadata):
"""
Find if embedded preview exists in safetensors metadata and return endpoint for it.
"""
file = f"{path}.safetensors"
if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}"
return None
def find_description(self, path): def find_description(self, path):
""" """
Find and read a description file for a given path (without extension). Find and read a description file for a given path (without extension).