make server work
This commit is contained in:
parent
6bee23d9b6
commit
79239a719d
|
@ -28,9 +28,11 @@ python -m pip install -r requirements.txt
|
|||
## Run
|
||||
Start the Flask server:
|
||||
```
|
||||
python -m riffusion.server --port 3013 --host 127.0.0.1 --checkpoint /path/to/diffusers_checkpoint
|
||||
python -m riffusion.server --port 3013 --host 127.0.0.1
|
||||
```
|
||||
|
||||
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
|
||||
|
||||
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
|
||||
|
||||
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
|
||||
|
|
|
@ -23,6 +23,7 @@ from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
|
|||
from .datatypes import InferenceInput, InferenceOutput
|
||||
from .riffusion_pipeline import RiffusionPipeline
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
|
@ -31,7 +32,9 @@ class Model:
|
|||
self._vae = None
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download("riffusion/riffusion-model-v1", allow_patterns="*.png")
|
||||
self._seed_images_dir = snapshot_download(
|
||||
"riffusion/riffusion-model-v1", allow_patterns="*.png"
|
||||
)
|
||||
|
||||
print(self._seed_images_dir)
|
||||
|
||||
|
@ -63,7 +66,9 @@ class Model:
|
|||
sample: torch.FloatTensor
|
||||
|
||||
# Use traced unet from hf hub
|
||||
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
|
||||
unet_file = hf_hub_download(
|
||||
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
|
@ -80,7 +85,6 @@ class Model:
|
|||
|
||||
self._model = pipe
|
||||
|
||||
|
||||
def preprocess(self, request: Dict) -> Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
@ -113,7 +117,7 @@ class Model:
|
|||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled = False):
|
||||
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
|
||||
response = self.compute(inputs)
|
||||
|
||||
return response
|
||||
|
|
|
@ -80,15 +80,7 @@ def load_model(checkpoint: str):
|
|||
|
||||
model = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
)
|
||||
|
||||
model = RiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
revision="fp16",
|
||||
revision="main",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
|
@ -99,7 +91,9 @@ def load_model(checkpoint: str):
|
|||
sample: torch.FloatTensor
|
||||
|
||||
# Using traced unet from hf hub
|
||||
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
|
||||
unet_file = hf_hub_download(
|
||||
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
class TracedUNet(torch.nn.Module):
|
||||
|
@ -107,6 +101,7 @@ def load_model(checkpoint: str):
|
|||
super().__init__()
|
||||
self.in_channels = model.unet.in_channels
|
||||
self.device = model.unet.device
|
||||
self.dtype = torch.float16
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
|
@ -155,6 +150,7 @@ def run_inference():
|
|||
|
||||
return response
|
||||
|
||||
|
||||
# TODO(hayk): Enable cache here.
|
||||
# @functools.lru_cache()
|
||||
def compute(inputs: InferenceInput) -> str:
|
||||
|
@ -196,7 +192,6 @@ def compute(inputs: InferenceInput) -> str:
|
|||
return flask.jsonify(dataclasses.asdict(output))
|
||||
|
||||
|
||||
|
||||
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
|
||||
"""
|
||||
Convert a PIL image into bytes of the given image format.
|
||||
|
|
Loading…
Reference in New Issue