change StableDiffusionProcessing to internally use sampler name instead of sampler index
This commit is contained in:
parent
d9fd4525a5
commit
cdc8020d13
|
@ -6,9 +6,9 @@ from threading import Lock
|
||||||
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
||||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules import sd_samplers
|
||||||
from modules.api.models import *
|
from modules.api.models import *
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.sd_samplers import all_samplers
|
|
||||||
from modules.extras import run_extras, run_pnginfo
|
from modules.extras import run_extras, run_pnginfo
|
||||||
from PIL import PngImagePlugin
|
from PIL import PngImagePlugin
|
||||||
from modules.sd_models import checkpoints_list
|
from modules.sd_models import checkpoints_list
|
||||||
|
@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||||
|
|
||||||
|
|
||||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
def validate_sampler_name(name):
|
||||||
|
config = sd_samplers.all_samplers_map.get(name, None)
|
||||||
|
if config is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
|
@ -82,14 +86,9 @@ class Api:
|
||||||
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
|
||||||
|
|
||||||
if sampler_index is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
|
||||||
|
|
||||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
"sd_model": shared.sd_model,
|
"sd_model": shared.sd_model,
|
||||||
"sampler_index": sampler_index[0],
|
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
|
||||||
"do_not_save_samples": True,
|
"do_not_save_samples": True,
|
||||||
"do_not_save_grid": True
|
"do_not_save_grid": True
|
||||||
}
|
}
|
||||||
|
@ -109,12 +108,6 @@ class Api:
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||||
sampler_index = sampler_to_index(img2imgreq.sampler_index)
|
|
||||||
|
|
||||||
if sampler_index is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
|
||||||
|
|
||||||
|
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
|
@ -123,10 +116,9 @@ class Api:
|
||||||
if mask:
|
if mask:
|
||||||
mask = decode_base64_to_image(mask)
|
mask = decode_base64_to_image(mask)
|
||||||
|
|
||||||
|
|
||||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||||
"sd_model": shared.sd_model,
|
"sd_model": shared.sd_model,
|
||||||
"sampler_index": sampler_index[0],
|
"sampler_name": validate_sampler_name(img2imgreq.sampler_index),
|
||||||
"do_not_save_samples": True,
|
"do_not_save_samples": True,
|
||||||
"do_not_save_grid": True,
|
"do_not_save_grid": True,
|
||||||
"mask": mask
|
"mask": mask
|
||||||
|
@ -272,7 +264,7 @@ class Api:
|
||||||
return vars(shared.cmd_opts)
|
return vars(shared.cmd_opts)
|
||||||
|
|
||||||
def get_samplers(self):
|
def get_samplers(self):
|
||||||
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
|
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||||
|
|
||||||
def get_upscalers(self):
|
def get_upscalers(self):
|
||||||
upscalers = []
|
upscalers = []
|
||||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared
|
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||||
from modules.textual_inversion import textual_inversion
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
p.prompt = preview_prompt
|
p.prompt = preview_prompt
|
||||||
p.negative_prompt = preview_negative_prompt
|
p.negative_prompt = preview_negative_prompt
|
||||||
p.steps = preview_steps
|
p.steps = preview_steps
|
||||||
p.sampler_index = preview_sampler_index
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
p.cfg_scale = preview_cfg_scale
|
p.cfg_scale = preview_cfg_scale
|
||||||
p.seed = preview_seed
|
p.seed = preview_seed
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
|
|
|
@ -303,7 +303,7 @@ class FilenameGenerator:
|
||||||
'width': lambda self: self.image.width,
|
'width': lambda self: self.image.width,
|
||||||
'height': lambda self: self.image.height,
|
'height': lambda self: self.image.height,
|
||||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||||
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
|
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
||||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||||
|
|
|
@ -6,7 +6,7 @@ import traceback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageChops
|
from PIL import Image, ImageOps, ImageChops
|
||||||
|
|
||||||
from modules import devices
|
from modules import devices, sd_samplers
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
seed_resize_from_h=seed_resize_from_h,
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
seed_resize_from_w=seed_resize_from_w,
|
||||||
seed_enable_extras=seed_enable_extras,
|
seed_enable_extras=seed_enable_extras,
|
||||||
sampler_index=sampler_index,
|
sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
|
|
|
@ -2,6 +2,7 @@ import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_correct_sampler(p):
|
|
||||||
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
|
|
||||||
return sd_samplers.samplers
|
|
||||||
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
|
||||||
return sd_samplers.samplers_for_img2img
|
|
||||||
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
|
||||||
return sd_samplers.samplers
|
|
||||||
|
|
||||||
class StableDiffusionProcessing():
|
class StableDiffusionProcessing():
|
||||||
"""
|
"""
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
"""
|
"""
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
|
||||||
|
if sampler_index is not None:
|
||||||
|
warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
|
||||||
|
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
|
@ -91,7 +88,7 @@ class StableDiffusionProcessing():
|
||||||
self.subseed_strength: float = subseed_strength
|
self.subseed_strength: float = subseed_strength
|
||||||
self.seed_resize_from_h: int = seed_resize_from_h
|
self.seed_resize_from_h: int = seed_resize_from_h
|
||||||
self.seed_resize_from_w: int = seed_resize_from_w
|
self.seed_resize_from_w: int = seed_resize_from_w
|
||||||
self.sampler_index: int = sampler_index
|
self.sampler_name: str = sampler_name
|
||||||
self.batch_size: int = batch_size
|
self.batch_size: int = batch_size
|
||||||
self.n_iter: int = n_iter
|
self.n_iter: int = n_iter
|
||||||
self.steps: int = steps
|
self.steps: int = steps
|
||||||
|
@ -210,8 +207,7 @@ class Processed:
|
||||||
self.info = info
|
self.info = info
|
||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_index = p.sampler_index
|
self.sampler_name = p.sampler_name
|
||||||
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
|
||||||
self.cfg_scale = p.cfg_scale
|
self.cfg_scale = p.cfg_scale
|
||||||
self.steps = p.steps
|
self.steps = p.steps
|
||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
|
@ -256,8 +252,7 @@ class Processed:
|
||||||
"subseed_strength": self.subseed_strength,
|
"subseed_strength": self.subseed_strength,
|
||||||
"width": self.width,
|
"width": self.width,
|
||||||
"height": self.height,
|
"height": self.height,
|
||||||
"sampler_index": self.sampler_index,
|
"sampler_name": self.sampler_name,
|
||||||
"sampler": self.sampler,
|
|
||||||
"cfg_scale": self.cfg_scale,
|
"cfg_scale": self.cfg_scale,
|
||||||
"steps": self.steps,
|
"steps": self.steps,
|
||||||
"batch_size": self.batch_size,
|
"batch_size": self.batch_size,
|
||||||
|
@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
|
"Sampler": p.sampler_name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
|
@ -645,7 +640,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
@ -706,7 +701,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
|
||||||
|
@ -743,7 +738,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
self.image_conditioning = None
|
self.image_conditioning = None
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
|
|
|
@ -46,16 +46,23 @@ all_samplers = [
|
||||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||||
]
|
]
|
||||||
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
samplers = []
|
samplers = []
|
||||||
samplers_for_img2img = []
|
samplers_for_img2img = []
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_with_index(list_of_configs, index, model):
|
def create_sampler(name, model):
|
||||||
config = list_of_configs[index]
|
if name is not None:
|
||||||
|
config = all_samplers_map.get(name, None)
|
||||||
|
else:
|
||||||
|
config = all_samplers[0]
|
||||||
|
|
||||||
|
assert config is not None, f'bad sampler name: {name}'
|
||||||
|
|
||||||
sampler = config.constructor(model)
|
sampler = config.constructor(model)
|
||||||
sampler.config = config
|
sampler.config = config
|
||||||
|
|
||||||
return sampler
|
return sampler
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import csv
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models, images
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
|
@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||||
p.prompt = preview_prompt
|
p.prompt = preview_prompt
|
||||||
p.negative_prompt = preview_negative_prompt
|
p.negative_prompt = preview_negative_prompt
|
||||||
p.steps = preview_steps
|
p.steps = preview_steps
|
||||||
p.sampler_index = preview_sampler_index
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
p.cfg_scale = preview_cfg_scale
|
p.cfg_scale = preview_cfg_scale
|
||||||
p.seed = preview_seed
|
p.seed = preview_seed
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
from modules import sd_samplers
|
||||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||||
StableDiffusionProcessingImg2Img, process_images
|
StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
seed_resize_from_h=seed_resize_from_h,
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
seed_resize_from_w=seed_resize_from_w,
|
||||||
seed_enable_extras=seed_enable_extras,
|
seed_enable_extras=seed_enable_extras,
|
||||||
sampler_index=sampler_index,
|
sampler_name=sd_samplers.samplers[sampler_index].name,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
|
|
|
@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
|
||||||
filenames.append(os.path.basename(txt_fullfn))
|
filenames.append(os.path.basename(txt_fullfn))
|
||||||
fullfns.append(txt_fullfn)
|
fullfns.append(txt_fullfn)
|
||||||
|
|
||||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||||
|
|
||||||
# Make Zip
|
# Make Zip
|
||||||
if do_make_zip:
|
if do_make_zip:
|
||||||
|
|
|
@ -157,7 +157,7 @@ class Script(scripts.Script):
|
||||||
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
|
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
|
||||||
# Override
|
# Override
|
||||||
if override_sampler:
|
if override_sampler:
|
||||||
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
|
p.sampler_name = "Euler"
|
||||||
if override_prompt:
|
if override_prompt:
|
||||||
p.prompt = original_prompt
|
p.prompt = original_prompt
|
||||||
p.negative_prompt = original_negative_prompt
|
p.negative_prompt = original_negative_prompt
|
||||||
|
@ -191,7 +191,7 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
|
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
|
||||||
|
|
||||||
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
|
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
|
||||||
|
|
||||||
sigmas = sampler.model_wrap.get_sigmas(p.steps)
|
sigmas = sampler.model_wrap.get_sigmas(p.steps)
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,9 @@ import numpy as np
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images
|
from modules import images, sd_samplers
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
|
@ -60,9 +60,9 @@ def apply_order(p, x, xs):
|
||||||
p.prompt = prompt_tmp + p.prompt
|
p.prompt = prompt_tmp + p.prompt
|
||||||
|
|
||||||
|
|
||||||
def build_samplers_dict(p):
|
def build_samplers_dict():
|
||||||
samplers_dict = {}
|
samplers_dict = {}
|
||||||
for i, sampler in enumerate(get_correct_sampler(p)):
|
for i, sampler in enumerate(sd_samplers.all_samplers):
|
||||||
samplers_dict[sampler.name.lower()] = i
|
samplers_dict[sampler.name.lower()] = i
|
||||||
for alias in sampler.aliases:
|
for alias in sampler.aliases:
|
||||||
samplers_dict[alias.lower()] = i
|
samplers_dict[alias.lower()] = i
|
||||||
|
@ -70,7 +70,7 @@ def build_samplers_dict(p):
|
||||||
|
|
||||||
|
|
||||||
def apply_sampler(p, x, xs):
|
def apply_sampler(p, x, xs):
|
||||||
sampler_index = build_samplers_dict(p).get(x.lower(), None)
|
sampler_index = build_samplers_dict().get(x.lower(), None)
|
||||||
if sampler_index is None:
|
if sampler_index is None:
|
||||||
raise RuntimeError(f"Unknown sampler: {x}")
|
raise RuntimeError(f"Unknown sampler: {x}")
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def apply_sampler(p, x, xs):
|
||||||
|
|
||||||
|
|
||||||
def confirm_samplers(p, xs):
|
def confirm_samplers(p, xs):
|
||||||
samplers_dict = build_samplers_dict(p)
|
samplers_dict = build_samplers_dict()
|
||||||
for x in xs:
|
for x in xs:
|
||||||
if x.lower() not in samplers_dict.keys():
|
if x.lower() not in samplers_dict.keys():
|
||||||
raise RuntimeError(f"Unknown sampler: {x}")
|
raise RuntimeError(f"Unknown sampler: {x}")
|
||||||
|
|
Loading…
Reference in New Issue