From bdc90837987ed8919dd611fd01553b0c170ded5c Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Thu, 27 Oct 2022 15:20:15 -0400 Subject: [PATCH 1/3] Add a barebones interrogate API --- launch.py | 2 +- modules/api/api.py | 25 ++++++++++++++++++++++++- modules/api/models.py | 13 ++++++++++++- webui.py | 12 ++++++++---- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/launch.py b/launch.py index 8affd4109..ae79b4a42 100644 --- a/launch.py +++ b/launch.py @@ -198,7 +198,7 @@ def prepare_enviroment(): def start_webui(): print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") import webui - webui.webui() + webui.webui_or_api() if __name__ == "__main__": diff --git a/modules/api/api.py b/modules/api/api.py index 6e9d6097b..eabdb7b8a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,4 +1,4 @@ -from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo @@ -25,6 +25,11 @@ class ImageToImageResponse(BaseModel): parameters: Json info: Json +class InterrogateResponse(BaseModel): + caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + parameters: Json + info: Json + class Api: def __init__(self, app, queue_lock): @@ -33,6 +38,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) def __base64_to_image(self, base64_string): # if has a comma, deal with prefix @@ -118,6 +124,23 @@ class Api: return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) + def interrogateapi(self, interrogatereq: InterrogateAPI): + image_b64 = interrogatereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + populate = interrogatereq.copy(update={ # Override __init__ params + } + ) + + img = self.__base64_to_image(image_b64) + + # Override object param + with self.queue_lock: + processed = shared.interrogator.interrogate(img) + + return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None) + def extrasapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 079e33d9b..8be64749a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -63,7 +63,12 @@ class PydanticModelGenerator: self._model_name = model_name - self._class_data = merge_class_params(class_instance) + + if class_instance is not None: + self._class_data = merge_class_params(class_instance) + else: + self._class_data = {} + self._model_def = [ ModelDef( field=underscore(k), @@ -105,4 +110,10 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] +).generate_model() + +InterrogateAPI = PydanticModelGenerator( + "Interrogate", + None, + [{"key": "image", "type": str, "default": None}] ).generate_model() \ No newline at end of file diff --git a/webui.py b/webui.py index ade7334bf..7a4bb2a23 100644 --- a/webui.py +++ b/webui.py @@ -146,7 +146,9 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) if (launch_api): - create_api(app) + print('launching API') + api = create_api(app) + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) wait_on_server(demo) @@ -161,10 +163,12 @@ def webui(): print('Restarting Gradio') - -task = [] -if __name__ == "__main__": +def webui_or_api(): if cmd_opts.nowebui: api_only() else: webui() + +task = [] +if __name__ == "__main__": + webui_or_api() \ No newline at end of file From df6a7ebfe8cc4da23861e3e2583693bb7808d573 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 31 Oct 2022 11:50:33 -0400 Subject: [PATCH 2/3] revert things to master --- launch.py | 2 +- modules/api/api.py | 2 -- modules/api/models.py | 6 +----- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/launch.py b/launch.py index fe9cef3cc..958336f21 100644 --- a/launch.py +++ b/launch.py @@ -220,7 +220,7 @@ def tests(argv): def start_webui(): print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") import webui - webui.webui_or_api() + webui.webui() if __name__ == "__main__": diff --git a/modules/api/api.py b/modules/api/api.py index c510a8334..6a903e4c0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -117,8 +117,6 @@ class Api: return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) - def extrasapi(self): - raise NotImplementedError def extras_single_image_api(self, req: ExtrasSingleImageRequest): reqDict = setUpscalers(req) diff --git a/modules/api/models.py b/modules/api/models.py index 035a7179c..82ab29b84 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -64,11 +64,7 @@ class PydanticModelGenerator: self._model_name = model_name - - if class_instance is not None: - self._class_data = merge_class_params(class_instance) - else: - self._class_data = {} + self._class_data = merge_class_params(class_instance) self._model_def = [ ModelDef( From 3f3d14afd5abd07d3843370dc1c28be299dbdbab Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 31 Oct 2022 11:51:21 -0400 Subject: [PATCH 3/3] nix unused thing --- modules/api/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 6a903e4c0..536e3f161 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -182,10 +182,6 @@ class Api: if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") - populate = interrogatereq.copy(update={ # Override __init__ params - } - ) - img = self.__base64_to_image(image_b64) # Override object param