make server work

This commit is contained in:
Hayk Martiros 2022-12-13 06:43:46 +00:00
parent 6bee23d9b6
commit 79239a719d
3 changed files with 24 additions and 23 deletions

View File

@@ -28,9 +28,11 @@ python -m pip install -r requirements.txt
## Run
Start the Flask server:
```
python -m riffusion.server --port 3013 --host 127.0.0.1 --checkpoint /path/to/diffusers_checkpoint
python -m riffusion.server --port 3013 --host 127.0.0.1
```
You can specify `--checkpoint` with your own checkpoint directory or a Hugging Face model ID in diffusers format.
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via a POST request.
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
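For illustration, a request to this endpoint could be sent as in the sketch below; the payload fields are assumptions based on `InferenceInput` in `datatypes.py`, so treat that file as the source of truth.
```
# Hypothetical client call -- field names and values are assumptions, not part of this commit.
import requests

payload = {
    "alpha": 0.75,                # assumed: interpolation weight between start and end
    "num_inference_steps": 50,    # assumed: number of diffusion steps
    "seed_image_id": "og_beat",   # assumed: which seed spectrogram to condition on
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

response = requests.post("http://127.0.0.1:3013/run_inference", json=payload)
response.raise_for_status()
print(list(response.json().keys()))  # fields defined by InferenceOutput
```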

View File

@@ -23,6 +23,7 @@ from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
@@ -31,7 +32,9 @@ class Model:
self._vae = None
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download("riffusion/riffusion-model-v1", allow_patterns="*.png")
self._seed_images_dir = snapshot_download(
"riffusion/riffusion-model-v1", allow_patterns="*.png"
)
print(self._seed_images_dir)
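As context for the change above: `snapshot_download` with `allow_patterns` pulls only the matching files from the repo and returns the local cache directory. A small sketch of inspecting what it fetched:
```
# Sketch: list the PNGs cached by snapshot_download (it returns the local snapshot directory).
from pathlib import Path

from huggingface_hub import snapshot_download

seed_images_dir = Path(snapshot_download("riffusion/riffusion-model-v1", allow_patterns="*.png"))
seed_pngs = sorted(seed_images_dir.rglob("*.png"))
print(f"{len(seed_pngs)} seed images under {seed_images_dir}")
```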
@@ -63,7 +66,9 @@ class Model:
sample: torch.FloatTensor
# Use traced unet from hf hub
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
@@ -80,7 +85,6 @@ class Model:
self._model = pipe
def preprocess(self, request: Dict) -> Dict:
"""
Incorporate pre-processing required by the model if desired here.
@@ -113,7 +117,7 @@ class Model:
return str(exception), 400
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled = False):
with torch.inference_mode() and torch.cuda.amp.autocast(enabled=False):
response = self.compute(inputs)
return response
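One note on the `with` statement above: in Python, `A() and B()` evaluates to a single object, so only the right-hand context manager is entered; if both `inference_mode` and `autocast` are meant to be active, the comma form enters both. A sketch:
```
# Sketch: enter both context managers (the `and` form above only enters the autocast context).
import torch

with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
    ...  # run model.compute(inputs) or equivalent here
```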

View File

@@ -80,15 +80,7 @@ def load_model(checkpoint: str):
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="fp16",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
)
model = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
revision="fp16",
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
@@ -99,7 +91,9 @@ def load_model(checkpoint: str):
sample: torch.FloatTensor
# Using traced unet from hf hub
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
@@ -107,6 +101,7 @@ def load_model(checkpoint: str):
super().__init__()
self.in_channels = model.unet.in_channels
self.device = model.unet.device
self.dtype = torch.float16
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
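For context on the two `TracedUNet` hunks: the traced module returns a raw tensor, so the wrapper presumably re-packages it in an object with a `.sample` attribute to match what the diffusers pipeline reads. A sketch of that pattern under that assumption (not the verbatim file contents):
```
# Sketch of the traced-UNet wrapper pattern; constructor arguments here are assumptions.
import dataclasses

import torch

@dataclasses.dataclass
class UNet2DConditionOutput:
    sample: torch.FloatTensor

class TracedUNet(torch.nn.Module):
    def __init__(self, traced_module, in_channels, device, dtype=torch.float16):
        super().__init__()
        self.traced_module = traced_module
        self.in_channels = in_channels
        self.device = device
        self.dtype = dtype

    def forward(self, latent_model_input, t, encoder_hidden_states):
        # The traced module returns a tuple; index 0 is the denoised sample tensor.
        sample = self.traced_module(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)
```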
@@ -155,6 +150,7 @@ def run_inference():
return response
# TODO(hayk): Enable cache here.
# @functools.lru_cache()
def compute(inputs: InferenceInput) -> str:
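Regarding the `TODO` above: `functools.lru_cache` needs hashable arguments, so enabling it presumably depends on `InferenceInput` (and any nested dataclasses) being hashable, e.g. frozen. A sketch under that assumption, with hypothetical names:
```
# Sketch only -- CacheableInput and compute_cached are hypothetical, not part of this repo.
import dataclasses
import functools

@dataclasses.dataclass(frozen=True)
class CacheableInput:
    prompt: str
    seed: int
    num_inference_steps: int = 50

@functools.lru_cache(maxsize=128)
def compute_cached(inputs: CacheableInput) -> str:
    # Frozen dataclasses are hashable, so identical requests hit the cache.
    return f"result for {inputs.prompt!r} with seed {inputs.seed}"
```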
@@ -196,7 +192,6 @@ def compute(inputs: InferenceInput) -> str:
return flask.jsonify(dataclasses.asdict(output))
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.
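The body of this helper falls outside the hunk; a typical implementation of such a conversion looks like the sketch below (an assumption, not the file's actual code):
```
# Sketch of a PIL-image-to-bytes helper; the real implementation lives outside this hunk.
import io

import PIL.Image

def image_bytes_from_image(image: PIL.Image.Image, mode: str = "PNG") -> io.BytesIO:
    """
    Convert a PIL image into bytes of the given image format.
    """
    image_bytes = io.BytesIO()
    image.save(image_bytes, format=mode)  # e.g. "PNG" or "JPEG"
    image_bytes.seek(0)
    return image_bytes
```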