use override of sd_vae
This commit is contained in:
parent
1f392517f8
commit
1e696b028a
|
@ -118,21 +118,16 @@ def apply_size(p, x: str, xs) -> None:
|
|||
|
||||
|
||||
def find_vae(name: str):
|
||||
if name.lower() in ['auto', 'automatic']:
|
||||
return modules.sd_vae.unspecified
|
||||
if name.lower() == 'none':
|
||||
return None
|
||||
else:
|
||||
choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
|
||||
if len(choices) == 0:
|
||||
print(f"No VAE found for {name}; using automatic")
|
||||
return modules.sd_vae.unspecified
|
||||
else:
|
||||
return modules.sd_vae.vae_dict[choices[0]]
|
||||
match name := name.lower().strip():
|
||||
case 'auto', 'automatic':
|
||||
return 'Automatic'
|
||||
case 'none':
|
||||
return 'None'
|
||||
return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
|
||||
|
||||
|
||||
def apply_vae(p, x, xs):
|
||||
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
|
||||
p.override_settings['sd_vae'] = find_vae(x)
|
||||
|
||||
|
||||
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
|
||||
|
@ -270,7 +265,7 @@ axis_options = [
|
|||
AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
|
||||
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
|
||||
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
|
||||
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
|
||||
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
|
||||
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
|
||||
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
|
||||
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
|
||||
|
@ -399,10 +394,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||
|
||||
class SharedSettingsStackHelper(object):
|
||||
def __enter__(self):
|
||||
self.vae = opts.sd_vae
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
opts.data["sd_vae"] = self.vae
|
||||
modules.sd_models.reload_model_weights()
|
||||
modules.sd_vae.reload_vae_weights()
|
||||
|
||||
|
|
Loading…
Reference in New Issue