riffusion-inference/riffusion/datatypes.py

74 lines
1.5 KiB
Python
Raw Permalink Normal View History

2022-11-25 17:13:29 -07:00
"""
Data model for the riffusion API.
"""
from __future__ import annotations
2022-11-25 17:13:29 -07:00
2022-11-25 23:48:52 -07:00
import typing as T
from dataclasses import dataclass
2022-11-25 17:13:29 -07:00
2022-11-27 17:06:12 -07:00
@dataclass(frozen=True)
class PromptInput:
    """
    Parameters for one end of interpolation.

    Instances are immutable (``frozen=True``): assigning to a field after
    construction raises ``dataclasses.FrozenInstanceError``.
    """

    # Text prompt fed into a CLIP model
    prompt: str

    # Random seed for denoising
    seed: int

    # Negative prompt to avoid (optional; None means no negative prompt given)
    negative_prompt: T.Optional[str] = None

    # Denoising strength
    denoising: float = 0.75

    # Classifier-free guidance strength
    guidance: float = 7.0
2022-11-27 17:06:12 -07:00
@dataclass(frozen=True)
class InferenceInput:
    """
    Parameters for a single run of the riffusion model, interpolating between
    a start and end set of PromptInputs. This is the API required for a request
    to the model server.

    Instances are immutable (``frozen=True``).
    """

    # Start point of interpolation
    start: PromptInput

    # End point of interpolation
    end: PromptInput

    # Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
    # uses end fully. NOTE: the range is documented only; it is not validated
    # here.
    alpha: float

    # Number of inner loops of the diffusion model
    num_inference_steps: int = 50

    # Which seed image to use
    seed_image_id: str = "og_beat"

    # ID of mask image to use (optional; None means no mask ID was supplied)
    mask_image_id: T.Optional[str] = None
2022-11-25 17:13:29 -07:00
2022-11-27 17:06:12 -07:00
@dataclass(frozen=True)
class InferenceOutput:
    """
    Response from the model inference server.

    All three fields are required; instances are immutable (``frozen=True``).
    """

    # base64 encoded spectrogram image as a JPEG
    image: str

    # base64 encoded audio clip as an MP3
    audio: str

    # The duration of the audio clip (presumably seconds, going by the `_s`
    # suffix)
    duration_s: float