riffusion-inference/riffusion/datatypes.py

74 lines
1.5 KiB
Python

"""
Data model for the riffusion API.
"""
from __future__ import annotations
import typing as T
from dataclasses import dataclass
@dataclass(frozen=True)
class PromptInput:
"""
Parameters for one end of interpolation.
"""
# Text prompt fed into a CLIP model
prompt: str
# Random seed for denoising
seed: int
# Negative prompt to avoid (optional)
negative_prompt: T.Optional[str] = None
# Denoising strength
denoising: float = 0.75
# Classifier-free guidance strength
guidance: float = 7.0
@dataclass(frozen=True)
class InferenceInput:
"""
Parameters for a single run of the riffusion model, interpolating between
a start and end set of PromptInputs. This is the API required for a request
to the model server.
"""
# Start point of interpolation
start: PromptInput
# End point of interpolation
end: PromptInput
# Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
# uses end fully.
alpha: float
# Number of inner loops of the diffusion model
num_inference_steps: int = 50
# Which seed image to use
seed_image_id: str = "og_beat"
# ID of mask image to use
mask_image_id: T.Optional[str] = None
@dataclass(frozen=True)
class InferenceOutput:
"""
Response from the model inference server.
"""
# base64 encoded spectrogram image as a JPEG
image: str
# base64 encoded audio clip as an MP3
audio: str
# The duration of the audio clip
duration_s: float