diff --git a/modules/processing.py b/modules/processing.py index c983d001e..25a19a778 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -746,7 +746,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.tiling is None: p.tiling = opts.tiling - if p.refiner_checkpoint not in (None, "", "None"): + if p.refiner_checkpoint not in (None, "", "None", "none"): p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint) if p.refiner_checkpoint_info is None: raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}') diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index b389c4ef9..29ccb78f9 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -42,7 +42,7 @@ class ScriptRefiner(scripts.ScriptBuiltinUI): # the actual implementation is in sd_samplers_common.py, apply_refiner if not enable_refiner or refiner_checkpoint in (None, "", "None"): - p.refiner_checkpoint_info = None + p.refiner_checkpoint = None p.refiner_switch_at = None else: p.refiner_checkpoint = refiner_checkpoint diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 017a470f5..2217cc69f 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -85,20 +85,12 @@ def confirm_checkpoints(p, xs): if modules.sd_models.get_closet_checkpoint_match(x) is None: raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_refiner_checkpoint(p, x, xs): - if x == 'None': - p.override_settings['sd_refiner_checkpoint'] = 'None' - return - info = modules.sd_models.get_closet_checkpoint_match(x) - if info is None: - raise RuntimeError(f"Unknown checkpoint: {x}") - p.override_settings['sd_refiner_checkpoint'] = info.name - -def confirm_refiner_checkpoints(p, xs): +def confirm_checkpoints_or_none(p, xs): for x in xs: - if x == 'None': + if x in (None, "", "None", "none"): continue + if modules.sd_models.get_closet_checkpoint_match(x) is None: raise RuntimeError(f"Unknown checkpoint: {x}") @@ -267,8 +259,8 @@ axis_options = [ AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')), AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')), AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)), - AxisOption("Refiner checkpoint", str, apply_refiner_checkpoint, format_value=format_remove_path, confirm=confirm_refiner_checkpoints, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), - AxisOption("Refiner switch at", float, apply_override('sd_refiner_switch_at')) + AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), + AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), ]