import datetime
import logging
import threading
import time
from typing import Optional

from modules import errors, shared, devices

log = logging.getLogger(__name__)


class State:
    """Mutable progress and control state for jobs plus a server-command mailbox.

    The class-level attributes below act as defaults; `begin()` resets most of
    them on the instance at the start of every job.
    """

    # Control flags: set by skip() / interrupt() / stop_generating(), cleared by begin().
    skipped: bool = False
    interrupted: bool = False
    stopping_generation: bool = False

    # Job bookkeeping. begin() sets job_count to -1 and job_no to 0; end() resets
    # job to "" and job_count to 0; nextjob() increments job_no.
    job: str = ""
    job_no: int = 0
    job_count: int = 0
    processing_has_refined_job_count: bool = False
    # begin() stores the current local time here, formatted as "%Y%m%d%H%M%S".
    job_timestamp: str = '0'

    # Sampling progress. sampling_steps is never written in this class.
    sampling_step: int = 0
    sampling_steps: int = 0

    # Live preview state: do_set_current_image() turns current_latent into
    # current_image; id_live_preview is bumped each time a new image is assigned.
    current_latent = None
    current_image = None
    current_image_sampling_step: int = 0
    id_live_preview: int = 0

    textinfo = None
    # time.time() at the last begin(); None until begin() has been called.
    time_start: Optional[float] = None
    # time.time() at construction of the instance (see __init__).
    server_start: Optional[float] = None

    # NOTE: the Event is created once at class-definition time, so it is shared
    # by every instance of State that does not override it.
    _server_command_signal = threading.Event()
    _server_command: Optional[str] = None

    def __init__(self):
        self.server_start = time.time()

    @property
    def need_restart(self) -> bool:
        # Compatibility getter for need_restart.
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        # Compatibility setter for need_restart.
        # A falsy value is ignored: it does not clear a pending command.
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        """Return the currently stored server command (or None)."""
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Wait for server command to get set; return and clear the value and signal.

        Returns None if `timeout` (seconds; None waits indefinitely) elapses
        before the signal is set.
        """
        if self._server_command_signal.wait(timeout):
            self._server_command_signal.clear()
            req = self._server_command
            self._server_command = None
            return req
        return None

    def request_restart(self) -> None:
        """Interrupt the current job and post the "restart" server command."""
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        """Raise the `skipped` flag."""
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        """Raise the `interrupted` flag."""
        self.interrupted = True
        log.info("Received interrupt request")

    def stop_generating(self):
        """Raise the `stopping_generation` flag."""
        self.stopping_generation = True
        log.info("Received stop generating request")

    def nextjob(self):
        """Advance to the next job and reset the per-job sampling counters."""
        # With show_progress_every_n_steps == -1 the step-based path in
        # set_current_image() never fires, so the preview is refreshed here.
        if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
            self.do_set_current_image()

        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def dict(self):
        """Return a plain dict snapshot of the flags and job/sampling counters."""
        obj = {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "stopping_generation": self.stopping_generation,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }

        return obj

    def begin(self, job: str = "(unknown)"):
        """Reset per-job state, record the start time and name the job `job`."""
        self.sampling_step = 0
        self.time_start = time.time()
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.id_live_preview = 0
        self.skipped = False
        self.interrupted = False
        self.stopping_generation = False
        self.textinfo = None
        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def end(self):
        """Log the end of the current job, clear `job`/`job_count`, call devices.torch_gc().

        If begin() was never called, `time_start` is still None; in that case
        the duration is omitted from the log message instead of raising
        TypeError, so the cleanup below always runs.
        """
        if self.time_start is None:
            log.info("Ending job %s (start time unknown)", self.job)
        else:
            duration = time.time() - self.time_start
            log.info("Ending job %s (%.2f seconds)", self.job, duration)

        self.job = ""
        self.job_count = 0

        devices.torch_gc()

    def set_current_image(self):
        """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
        if not shared.parallel_processing_allowed:
            return

        if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
            self.do_set_current_image()

    def do_set_current_image(self):
        """Convert `current_latent` to an image (or grid) and store it as the preview.

        Does nothing when there is no latent. Any exception during conversion
        is recorded via errors.record_exception() and otherwise ignored.
        """
        if self.current_latent is None:
            return

        # Imported here rather than at module level, as in the original code.
        import modules.sd_samplers

        try:
            if shared.opts.show_progress_grid:
                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
            else:
                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

            self.current_image_sampling_step = self.sampling_step

        except Exception:
            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
            # we silently ignore this error
            errors.record_exception()

    def assign_current_image(self, image):
        """Store `image` as the current preview and bump `id_live_preview`.

        An RGBA image is converted to RGB first when the configured live
        preview format is 'jpeg'.
        """
        if shared.opts.live_previews_image_format == 'jpeg' and image.mode == 'RGBA':
            image = image.convert('RGB')
        self.current_image = image
        self.id_live_preview += 1