make server work

Hayk Martiros 2022-12-13 06:43:46 +00:00
parent 6bee23d9b6
commit 79239a719d
3 changed files with 24 additions and 23 deletions

View File

@@ -28,9 +28,11 @@ python -m pip install -r requirements.txt
## Run
Start the Flask server:
```
python -m riffusion.server --port 3013 --host 127.0.0.1 --checkpoint /path/to/diffusers_checkpoint
python -m riffusion.server --port 3013 --host 127.0.0.1
```
You can specify `--checkpoint` to point at your own checkpoint directory or Hugging Face model ID in diffusers format.
The model endpoint is then available at `http://127.0.0.1:3013/run_inference` via POST requests.
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
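Below is a minimal client sketch for this endpoint. It is illustrative only: the field names follow `InferenceInput`/`PromptInput` in `datatypes.py` at this commit (`start`, `end`, `alpha`, `num_inference_steps`, `seed_image_id`) and the response is assumed to carry base64-encoded audio as in `InferenceOutput`; verify both against your checkout.
```
# Hedged client sketch for the /run_inference endpoint; field names are taken
# from riffusion/datatypes.py at this commit and may differ in your version.
import base64

import requests

payload = {
    "alpha": 0.5,
    "num_inference_steps": 50,
    "seed_image_id": "og_beat",
    "start": {"prompt": "church bells on sunday", "seed": 42, "denoising": 0.75, "guidance": 7.0},
    "end": {"prompt": "jazz with piano", "seed": 123, "denoising": 0.75, "guidance": 7.0},
}

resp = requests.post("http://127.0.0.1:3013/run_inference", json=payload, timeout=600)
resp.raise_for_status()
output = resp.json()

# The "audio" field is assumed to be a base64 data URI (see InferenceOutput).
audio_b64 = output["audio"].split(",")[-1]
with open("output.mp3", "wb") as f:
    f.write(base64.b64decode(audio_b64))
```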

View File

@@ -1,6 +1,6 @@
"""
This file can be used to build a Truss for deployment with Baseten.
If used, it should be renamed to model.py and placed alongside the other
files from /riffusion in the standard /model directory of the Truss.
For more on the Truss file format, see https://truss.baseten.co/
@@ -23,6 +23,7 @@ from .audio import wav_bytes_from_spectrogram_image, mp3_bytes_from_wav_bytes
from .datatypes import InferenceInput, InferenceOutput
from .riffusion_pipeline import RiffusionPipeline
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
@@ -31,10 +32,12 @@ class Model:
self._vae = None
# Download entire seed image folder from huggingface hub
self._seed_images_dir = snapshot_download("riffusion/riffusion-model-v1", allow_patterns="*.png")
self._seed_images_dir = snapshot_download(
"riffusion/riffusion-model-v1", allow_patterns="*.png"
)
print(self._seed_images_dir)
def load(self):
# Load Riffusion model here and assign to self._model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -61,11 +64,13 @@ class Model:
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Use traced unet from hf hub
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -77,9 +82,8 @@ class Model:
return UNet2DConditionOutput(sample=sample)
pipe.unet = TracedUNet()
self._model = pipe
def preprocess(self, request: Dict) -> Dict:
"""
@@ -113,7 +117,7 @@ class Model:
return str(exception), 400
# NOTE: Autocast is disabled to speed up inference; the previous inference time was 10s on a T4
with torch.inference_mode() and torch.cuda.amp.autocast(enabled = False):
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
response = self.compute(inputs)
return response
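For orientation, the docstring at the top of this file describes where it goes inside a Truss. A rough sketch of that layout is shown below; the directory names assume standard Truss conventions (`config.yaml`, `model/`, `data/`) and are not taken from the repo.
```
# Assumed Truss layout per the docstring above (not generated from the repo):
riffusion-truss/               # hypothetical Truss root
├── config.yaml
├── data/
└── model/
    ├── model.py               # this file, renamed
    ├── audio.py
    ├── datatypes.py
    └── riffusion_pipeline.py
```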

View File

@@ -80,15 +80,7 @@ def load_model(checkpoint: str):
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="fp16",
torch_dtype=torch.float16,
# Disable the NSFW filter; it causes false positives
safety_checker=lambda images, **kwargs: (images, False),
)
model = RiffusionPipeline.from_pretrained(
"riffusion/riffusion-model-v1",
revision="fp16",
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter; it causes false positives
safety_checker=lambda images, **kwargs: (images, False),
@@ -97,16 +89,19 @@ def load_model(checkpoint: str):
@dataclasses.dataclass
class UNet2DConditionOutput:
sample: torch.FloatTensor
# Using traced unet from hf hub
unet_file = hf_hub_download("riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced")
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1", filename="unet_traced.pt", subfolder="unet_traced"
)
unet_traced = torch.jit.load(unet_file)
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = model.unet.in_channels
self.device = model.unet.device
self.dtype = torch.float16
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
@@ -155,6 +150,7 @@ def run_inference():
return response
# TODO(hayk): Enable cache here.
# @functools.lru_cache()
def compute(inputs: InferenceInput) -> str:
@@ -196,7 +192,6 @@ def compute(inputs: InferenceInput) -> str:
return flask.jsonify(dataclasses.asdict(output))
def image_bytes_from_image(image: PIL.Image, mode: str = "PNG") -> io.BytesIO:
"""
Convert a PIL image into bytes of the given image format.