diff --git a/riffusion/server.py b/riffusion/server.py index a768ad0..d0fcb6c 100644 --- a/riffusion/server.py +++ b/riffusion/server.py @@ -9,6 +9,7 @@ import io import json from pathlib import Path import time +import typing as T import dacite import flask @@ -43,6 +44,8 @@ def run_app( host: str = "127.0.0.1", port: int = 3000, debug: bool = False, + ssl_certificate: T.Optional[str] = None, + ssl_key: T.Optional[str] = None ): """ Run a flask API that serves the given riffusion model checkpoint. @@ -51,13 +54,19 @@ def run_app( global MODEL MODEL = load_model(checkpoint=checkpoint) - app.run( + args = dict( debug=debug, threaded=False, host=host, port=port, ) + if ssl_certificate: + assert ssl_key is not None + args["ssl_context"] = (ssl_certificate, ssl_key) + + app.run(**args) + def load_model(checkpoint: str): """