diff --git a/modules/images.py b/modules/images.py index 09d3523e8..c0ff8a630 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime - +import functools import pytz import io import math @@ -349,6 +349,32 @@ def sanitize_filename_part(text, replace_spaces=True): return text +@functools.cache +def get_scheduler_str(sampler_name, scheduler_name): + """Returns {Scheduler} if the scheduler is applicable to the sampler""" + if scheduler_name == 'Automatic': + config = sd_samplers.find_sampler_config(sampler_name) + scheduler_name = config.options.get('scheduler', 'Automatic') + return scheduler_name.capitalize() + + +@functools.cache +def get_sampler_scheduler_str(sampler_name, scheduler_name): + """Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler""" + return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}' + + +def get_sampler_scheduler(p, sampler): + """Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'""" + if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'): + if sampler: + sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler) + else: + sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler) + return sanitize_filename_part(sampler_scheduler, replace_spaces=False) + return NOTHING_AND_SKIP_PREVIOUS_TEXT + + class FilenameGenerator: replacements = { 'seed': lambda self: self.seed if self.seed is not None else '', @@ -360,6 +386,8 @@ class FilenameGenerator: 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), + 'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True), + 'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),