Merge pull request #11593 from akx/better-status-reporting-1
Better status reporting, part 1
This commit is contained in:
commit
d78377ea5d
|
@ -330,7 +330,7 @@ class Api:
|
|||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="scripts_txt2img")
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
|
@ -387,7 +387,7 @@ class Api:
|
|||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="scripts_img2img")
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
|
@ -396,7 +396,6 @@ class Api:
|
|||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
if not img2imgreq.include_init_images:
|
||||
|
@ -603,44 +602,42 @@ class Api:
|
|||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
filename = create_hypernetwork(**args) # create empty embedding
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="preprocess")
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info = 'preprocess complete')
|
||||
return models.PreprocessResponse(info='preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
except Exception as e:
|
||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||
except FileNotFoundError as e:
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="train_embedding")
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
|
@ -653,15 +650,15 @@ class Api:
|
|||
finally:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
except Exception as msg:
|
||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="train_hypernetwork")
|
||||
shared.loaded_hypernetworks = []
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
|
@ -679,9 +676,10 @@ class Api:
|
|||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError:
|
||||
except Exception as exc:
|
||||
return models.TrainResponse(info=f"train embedding error: {exc}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
|
|
|
@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||
id_task = None
|
||||
|
||||
with queue_lock:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job=id_task)
|
||||
progress.start_task(id_task)
|
||||
|
||||
try:
|
||||
|
|
|
@ -73,8 +73,7 @@ def to_half(tensor, enable):
|
|||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
shared.state.begin(job="model-merge")
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
|
|
|
@ -184,8 +184,7 @@ class InterrogateModels:
|
|||
|
||||
def interrogate(self, pil_image):
|
||||
res = ""
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
shared.state.begin(job="interrogate")
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
|
|
@ -9,8 +9,7 @@ from modules.shared import opts
|
|||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
shared.state.begin(job="extras")
|
||||
|
||||
image_data = []
|
||||
image_names = []
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
@ -18,6 +19,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
|
|||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
demo = None
|
||||
|
||||
parser = cmd_args.parser
|
||||
|
@ -144,12 +147,15 @@ class State:
|
|||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
log.info("Received restart request")
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
log.info("Received skip request")
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def nextjob(self):
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||
|
@ -173,7 +179,7 @@ class State:
|
|||
|
||||
return obj
|
||||
|
||||
def begin(self):
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
|
@ -187,10 +193,13 @@ class State:
|
|||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
||||
def end(self):
|
||||
duration = time.time() - self.time_start
|
||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
|
|
11
webui.py
11
webui.py
|
@ -18,6 +18,17 @@ from packaging import version
|
|||
|
||||
import logging
|
||||
|
||||
# We can't use cmd_opts for this because it will not have been initialized at this point.
|
||||
log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
if log_level:
|
||||
log_level = getattr(logging, log_level.upper(), None) or logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
from modules import paths, timer, import_hook, errors, devices # noqa: F401
|
||||
|
|
Loading…
Reference in New Issue