stable-diffusion-webui/modules/api/models.py

108 lines
3.8 KiB
Python
Raw Normal View History

from array import array
2022-10-17 01:02:08 -06:00
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
2022-10-17 01:02:08 -06:00
import inspect
2022-10-17 13:10:36 -06:00
# Constructor parameter names that are filtered out when the processing
# classes' __init__ signatures are turned into API model fields (any key found
# in this list is skipped). "sampler_index" is excluded here and re-declared
# explicitly, typed as str, through the generator's additional fields.
API_NOT_ALLOWED = [
    "self",
    "kwargs",
    "sd_model",
    "outpath_samples",
    "outpath_grids",
    "sampler_index",
    "do_not_save_samples",
    "do_not_save_grid",
    "extra_generation_params",
    "overlay_images",
    "do_not_reload_embeddings",
    "seed_enable_extras",
    "prompt_for_display",
    "sampler_noise_scheduler_override",
    "ddim_discretize",
]
2022-10-17 01:02:08 -06:00
class ModelDef(BaseModel):
    """Assistance Class for Pydantic Dynamic Model Generation

    One instance describes a single field of a dynamically generated model;
    PydanticModelGenerator.generate_model() turns each instance into a
    ``name: (type, Field(...))`` entry passed to pydantic's ``create_model``.
    """
    # Attribute name of the field on the generated model.
    field: str
    # Alias passed to pydantic's Field(alias=...).
    field_alias: str
    # Type annotation used for the generated field.
    field_type: Any
    # Default value passed to Field(default=...).
    field_value: Any
    # Forwarded to Field(exclude=...); defaults to False.
    field_exclude: bool = False
2022-10-17 01:02:08 -06:00
2022-10-17 13:10:36 -06:00
class PydanticModelGenerator:
    """
    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
    the __init__ signatures of the class and its bases supply the field names,
    annotations and default values; extra fields can be declared explicitly.
    """

    def __init__(
        self,
        model_name: str = None,
        class_instance = None,
        additional_fields = None,
    ):
        """Collect the field definitions for the model to be generated.

        Args:
            model_name: Name given to the generated pydantic model.
            class_instance: Class whose ``__init__`` parameters (together with
                those of every base class except ``object``) become fields.
                Parameters named in ``API_NOT_ALLOWED`` are skipped.
            additional_fields: Optional iterable of dicts with the keys
                ``"key"``, ``"type"`` and ``"default"`` and, optionally,
                ``"exclude"``. Each one is appended as an extra field.
                ``None`` (the default) means no extra fields.
        """
        def field_type_generator(k, v):
            # Wrap the declared annotation in Optional so None is accepted.
            return Optional[v.annotation]

        def merge_class_params(class_):
            # Walk the MRO from the class itself towards its bases. Later
            # entries overwrite earlier ones, so for a parameter name shared
            # by several classes the definition from the class furthest down
            # the MRO (the most basic one) is the one that is kept.
            parameters = {}
            for klass in inspect.getmro(class_):
                if klass is object:
                    continue
                parameters = {**parameters, **inspect.signature(klass.__init__).parameters}
            return parameters

        self._model_name = model_name
        self._class_data = merge_class_params(class_instance)
        self._model_def = [
            ModelDef(
                field=underscore(k),
                field_alias=k,
                field_type=field_type_generator(k, v),
                field_value=v.default,
            )
            for (k, v) in self._class_data.items()
            if k not in API_NOT_ALLOWED
        ]

        # additional_fields defaults to None; treat that as "no extra fields"
        # instead of failing with a TypeError when iterating over None.
        for spec in additional_fields or []:
            self._model_def.append(ModelDef(
                field=underscore(spec["key"]),
                field_alias=spec["key"],
                field_type=spec["type"],
                field_value=spec["default"],
                field_exclude=spec.get("exclude", False),
            ))

    def generate_model(self):
        """
        Creates a pydantic BaseModel
        from the field definitions collected at initialization

        Returns:
            The dynamically created model class, configured to allow
            population by field name and to allow mutation.
        """
        fields = {
            d.field: (
                d.field_type,
                Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude),
            )
            for d in self._model_def
        }
        DynamicModel = create_model(self._model_name, **fields)
        # Accept both the alias and the underscored field name as input keys.
        DynamicModel.__config__.allow_population_by_field_name = True
        DynamicModel.__config__.allow_mutation = True
        return DynamicModel
# Request model for txt2img, built from StableDiffusionProcessingTxt2Img's
# constructor parameters plus an explicit string-typed "sampler_index" field.
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
    model_name="StableDiffusionProcessingTxt2Img",
    class_instance=StableDiffusionProcessingTxt2Img,
    additional_fields=[
        {"key": "sampler_index", "type": str, "default": "Euler"},
    ],
).generate_model()

# Request model for img2img, built from StableDiffusionProcessingImg2Img's
# constructor parameters plus the explicitly declared fields below.
# "include_init_images" is flagged with exclude=True.
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
    model_name="StableDiffusionProcessingImg2Img",
    class_instance=StableDiffusionProcessingImg2Img,
    additional_fields=[
        {"key": "sampler_index", "type": str, "default": "Euler"},
        {"key": "init_images", "type": list, "default": None},
        {"key": "denoising_strength", "type": float, "default": 0.75},
        {"key": "mask", "type": str, "default": None},
        {"key": "include_init_images", "type": bool, "default": False, "exclude": True},
    ],
).generate_model()