167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
import datetime
|
|
import logging
|
|
import threading
|
|
import time
|
|
|
|
from modules import errors, shared, devices
|
|
from typing import Optional
|
|
|
|
# Module-level logger; State reports job start/end and skip/interrupt/stop/restart requests through it.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class State:
    """Mutable status of the current generation job plus server-control requests.

    Holds the flags and counters that the rest of the program reads and updates
    while a job runs (skip / interrupt / stop requests, job and sampling-step
    counters, the live-preview latent and image). It also holds a one-slot
    "server command" mailbox backed by a ``threading.Event``, which can be waited
    on with :meth:`wait_for_server_command`.
    """

    # Control flags: set by skip() / interrupt() / stop_generating(), cleared by begin().
    skipped = False
    interrupted = False
    stopping_generation = False
    # Job bookkeeping. begin() sets job_count to -1 (presumably "not known yet";
    # TODO confirm against callers) and end() resets it to 0.
    job = ""
    job_no = 0
    job_count = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'  # begin() stores a "%Y%m%d%H%M%S" timestamp string here
    # Sampling progress and live preview.
    sampling_step = 0
    sampling_steps = 0
    current_latent = None  # decoded into a preview by do_set_current_image()
    current_image = None  # latest preview, stored by assign_current_image()
    current_image_sampling_step = 0  # sampling_step at which current_image was produced
    id_live_preview = 0  # incremented every time a new preview image is assigned
    textinfo = None
    time_start = None  # time.time() recorded by begin()
    server_start = None  # time.time() recorded by __init__()
    # Created per instance in __init__ (see the note there).
    _server_command_signal: threading.Event
    _server_command: Optional[str] = None

    def __init__(self):
        self.server_start = time.time()
        # One Event per instance. As a class attribute the Event was shared by
        # every State object, while _server_command is per instance. A command
        # set on one instance could then wake, and be consumed as None by, a
        # waiter on another instance.
        self._server_command_signal = threading.Event()

    @property
    def need_restart(self) -> bool:
        """Compatibility getter: True while the pending server command is "restart"."""
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        """Compatibility setter: a truthy value requests a "restart" server command.

        A falsy value is ignored; it does not clear an already-pending command.
        """
        if value:
            self.server_command = "restart"

    @property
    def server_command(self) -> Optional[str]:
        """The pending server command, or None if none is pending."""
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """Set the server command to `value` and signal that it's been set."""
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """Wait for a server command to be set, then consume it.

        Args:
            timeout: Seconds to wait, or None to wait indefinitely.

        Returns:
            The command that was set (the stored command and the signal are both
            cleared), or None if the wait timed out or the command set was None.
        """
        if self._server_command_signal.wait(timeout):
            self._server_command_signal.clear()
            req = self._server_command
            self._server_command = None
            return req
        return None

    def request_restart(self) -> None:
        """Interrupt the current job and post a "restart" server command."""
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        """Flag the current job as skipped."""
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        """Flag the current job as interrupted."""
        self.interrupted = True
        log.info("Received interrupt request")

    def stop_generating(self):
        """Flag that generation should stop."""
        self.stopping_generation = True
        log.info("Received stop generating request")

    def nextjob(self):
        """Advance to the next sub-job and reset the per-job sampling counters.

        When live previews are enabled and show_progress_every_n_steps is -1,
        a preview is decoded from current_latent before advancing.
        """
        if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
            self.do_set_current_image()

        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def dict(self):
        """Return a plain-dict snapshot of the job progress fields."""
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "stopping_generation": self.stopping_generation,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }

        return obj

    def begin(self, job: str = "(unknown)"):
        """Reset all per-job state and start timing a new job named `job`."""
        self.sampling_step = 0
        self.time_start = time.time()
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.id_live_preview = 0
        self.skipped = False
        self.interrupted = False
        self.stopping_generation = False
        self.textinfo = None
        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def end(self):
        """Log the job's duration, mark the state idle and call devices.torch_gc()."""
        # end() without a prior begin() leaves time_start as None; report 0s
        # instead of raising TypeError on the subtraction.
        duration = time.time() - self.time_start if self.time_start is not None else 0.0
        log.info("Ending job %s (%.2f seconds)", self.job, duration)
        self.job = ""
        self.job_count = 0

        devices.torch_gc()

    def set_current_image(self):
        """Refresh the live preview if enough sampling steps have passed.

        If at least ``shared.opts.show_progress_every_n_steps`` steps have been
        made since the last preview, decode self.current_latent into
        self.current_image and bump self.id_live_preview. Does nothing when
        parallel processing is not allowed, when live previews are disabled, or
        when the interval is -1 (nextjob() handles that case).
        """
        if not shared.parallel_processing_allowed:
            return

        every_n = shared.opts.show_progress_every_n_steps
        if every_n == -1 or not shared.opts.live_previews_enable:
            return

        if self.sampling_step - self.current_image_sampling_step >= every_n:
            self.do_set_current_image()

    def do_set_current_image(self):
        """Decode self.current_latent into the live preview image, if there is one."""
        if self.current_latent is None:
            return

        # Imported lazily (presumably to avoid an import cycle; TODO confirm).
        import modules.sd_samplers

        try:
            if shared.opts.show_progress_grid:
                image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
            else:
                image = modules.sd_samplers.sample_to_image(self.current_latent)
            self.assign_current_image(image)
            self.current_image_sampling_step = self.sampling_step

        except Exception:
            # When switching models during generation the VAE may be on the CPU,
            # so creating the image can fail. Don't propagate: record the
            # exception and keep the previous preview.
            errors.record_exception()

    def assign_current_image(self, image):
        """Store `image` as the current preview and bump the preview id."""
        self.current_image = image
        self.id_live_preview += 1