diff --git a/modules/api/api.py b/modules/api/api.py index 6bb01603f..195e8b583 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -28,8 +28,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -77,6 +81,7 @@ class Api: self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) @@ -103,14 +108,9 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -130,12 +130,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -144,10 +138,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -266,6 +259,9 @@ class Api: return {} + def skip(self): + shared.state.skip() + def get_config(self): options = {} for key in shared.opts.data.keys(): @@ -277,14 +273,10 @@ class Api: return options - def set_config(self, req: OptionsModel): - # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will - # overwrite all options with default values. - raise RuntimeError('Setting options via API is not supported') - - reqDict = vars(req) - for o in reqDict: - setattr(shared.opts, o, reqDict[o]) + def set_config(self, req: Dict[str, Any]): + + for o in req: + setattr(shared.opts, o, req[o]) shared.opts.save(shared.config_filename) return @@ -293,7 +285,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/api/models.py b/modules/api/models.py index f9cd929e6..f77951fc2 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -176,9 +176,9 @@ class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") fields = {} -for key, value in opts.data.items(): - metadata = opts.data_labels.get(key) - optType = opts.typemap.get(type(value), type(value)) +for key, metadata in opts.data_labels.items(): + value = opts.data.get(key) + optType = opts.typemap.get(type(metadata.default), type(value)) if (metadata is not None): fields.update({key: (Optional[optType], Field( diff --git a/modules/extensions.py b/modules/extensions.py index 94ce479a2..db9c4200f 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -65,9 +65,12 @@ class Extension: self.can_update = False self.status = "latest" - def pull(self): + def fetch_and_reset_hard(self): repo = git.Repo(self.path) - repo.remotes.origin.pull() + # Fix: `error: Your local changes to the following files would be overwritten by merge`, + # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. + repo.git.fetch('--all') + repo.git.reset('--hard', 'origin') def list_extensions(): diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 985ec95eb..1408ea053 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict): 'sd_hypernetwork': 'Hypernet', 'sd_hypernetwork_strength': 'Hypernet strength', 'CLIP_stop_at_last_layers': 'Clip skip', + 'inpainting_mask_weight': 'Conditional mask weight', 'sd_model_checkpoint': 'Model hash', } settings_paste_fields = [ diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712b..fbb87dd14 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, processing, sd_models, shared, sd_samplers from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_index = preview_sampler_index + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index ae705cbd5..26d5b7a95 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width, 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), - 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), + 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime