diff --git a/modules/api/api.py b/modules/api/api.py index cabccb4c0..946cfe4a9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -312,8 +312,13 @@ class Api: script_args[script.args_from:script.args_to] = ui_default_values return script_args - def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): + def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None): script_args = default_script_args.copy() + + if input_script_args is not None: + for index, value in input_script_args.items(): + script_args[index] = value + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() if selectable_scripts: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args @@ -335,41 +340,58 @@ class Api: script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args - def apply_infotext(self, request, tabname): + def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): if not request.infotext: return {} + possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) - handled_fields = {} - for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: + def get_field_value(field, params): + value = field.function(params) if field.function else params.get(field.label) + if value is None: + return None + + if field.api in request.__fields__: + target_type = request.__fields__[field.api].type_ + else: + target_type = type(field.component.value) + + if target_type == type(None): + return None + + if not isinstance(value, target_type): + value = target_type(value) + + return value + + for field in possible_fields: if not field.api: continue if field.api in set_fields: continue - value = field.function(params) if field.function else params.get(field.label) - target_type = request.__fields__[field.api].type_ - - if value is None: - continue - - if not isinstance(value, target_type): - value = target_type(value) - - setattr(request, field.api, value) - handled_fields[field.label] = 1 + value = get_field_value(field, params) + if value is not None: + setattr(request, field.api, value) if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) - for infotext_text, setting_name, value in overriden_settings: + overriden_settings = generation_parameters_copypaste.get_override_settings(params) + for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value + if script_runner is not None and mentioned_script_args is not None: + indexes = {v: i for i, v in enumerate(script_runner.inputs)} + script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes) + + for field, index in script_fields: + mentioned_script_args[index] = get_field_value(field, params) + return params def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): @@ -380,7 +402,8 @@ class Api: script_runner.initialize_scripts(False) ui.create_ui() - self.apply_infotext(txt2imgreq, "txt2img") + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) @@ -400,7 +423,7 @@ class Api: args.pop('alwayson_scripts', None) args.pop('infotext', None) - script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None)