Fix some deprecated types

This commit is contained in:
a666 2023-08-25 01:58:19 -06:00
parent 84d41e49b3
commit b6c1a1bbbf
7 changed files with 34 additions and 38 deletions

View File

@ -29,7 +29,7 @@ from modules.sd_models import unload_model_weights, reload_model_weights, checkp
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import Dict, List, Any from typing import Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing from contextlib import closing
@ -221,15 +221,15 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@ -242,8 +242,8 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem]) self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop: if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@ -563,7 +563,7 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None) checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found") raise RuntimeError(f"model {checkpoint_name!r} not found")

View File

@ -1,12 +1,10 @@
import inspect import inspect
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from typing import Any, Optional from typing import Any, Optional, Literal
from typing_extensions import Literal
from inflection import underscore from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
API_NOT_ALLOWED = [ API_NOT_ALLOWED = [
"self", "self",
@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model() ).generate_model()
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
@ -168,10 +166,10 @@ class FileData(BaseModel):
name: str = Field(title="File name") name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest): class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse): class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: List[str] = Field(title="Images", description="The generated images in base64 format.") images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel): class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image") image: str = Field(title="Image", description="The base64 encoded PNG image")
@ -233,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel): class SamplerItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases") aliases: list[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options") options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel): class UpscalerItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
@ -285,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel): class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel): class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats") ram: dict = Field(title="RAM", description="System memory stats")
@ -304,14 +302,14 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument") choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel): class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name") name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
class ExtensionItem(BaseModel): class ExtensionItem(BaseModel):
name: str = Field(title="Name", description="Extension name") name: str = Field(title="Name", description="Extension name")

View File

@ -23,7 +23,7 @@ class Git(git.Git):
) )
return self._parse_object_header(ret) return self._parse_object_header(ret)
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]: def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
# Not really streaming, per se; this buffers the entire object in memory. # Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for # Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects). # object headers (commit objects).

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import re import re
from collections import namedtuple from collections import namedtuple
from typing import List
import lark import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
class ComposableScheduledPromptConditioning: class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0): def __init__(self, schedules, weight=1.0):
self.schedules: List[ScheduledPromptConditioning] = schedules self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight self.weight: float = weight
class MulticondLearnedConditioning: class MulticondLearnedConditioning:
def __init__(self, shape, batch): def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning: def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@ -278,7 +277,7 @@ class DictWithShape(dict):
return self["crossattn"].shape return self["crossattn"].shape
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond param = c[0][0].cond
is_dict = isinstance(param, dict) is_dict = isinstance(param, dict)

View File

@ -1,7 +1,7 @@
import inspect import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Optional, Dict, Any from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
@ -255,7 +255,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid') report_exception(c, 'image_grid')
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']: for c in callback_map['callbacks_infotext_pasted']:
try: try:
c.callback(infotext, params) c.callback(infotext, params)
@ -446,7 +446,7 @@ def on_infotext_pasted(callback):
"""register a function to be called before applying an infotext. """register a function to be called before applying an infotext.
The callback is called with two arguments: The callback is called with two arguments:
- infotext: str - raw infotext. - infotext: str - raw infotext.
- result: Dict[str, any] - parsed infotext parameters. - result: dict[str, any] - parsed infotext parameters.
""" """
add_callback(callback_map['callbacks_infotext_pasted'], callback) add_callback(callback_map['callbacks_infotext_pasted'], callback)

View File

@ -15,7 +15,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
from typing import Optional, NamedTuple, List from typing import Optional, NamedTuple
def narrow_trunc( def narrow_trunc(
@ -97,7 +97,7 @@ def _query_chunk_attention(
) )
return summarize_chunk(query, key_chunk, value_chunk) return summarize_chunk(query, key_chunk, value_chunk)
chunks: List[AttnChunk] = [ chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
] ]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))

View File

@ -1338,7 +1338,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
def setup_ui_api(app): def setup_ui_api(app):
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List
class QuicksettingsHint(BaseModel): class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field") name: str = Field(title="Name of the quicksettings field")
@ -1347,7 +1346,7 @@ def setup_ui_api(app):
def quicksettings_hint(): def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()] return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])