71 lines
1.4 KiB
Python
71 lines
1.4 KiB
Python
"""
|
|
Data model for the riffusion API.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import typing as T
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class PromptInput:
|
|
"""
|
|
Parameters for one end of interpolation.
|
|
"""
|
|
|
|
# Text prompt fed into a CLIP model
|
|
prompt: str
|
|
|
|
# Random seed for denoising
|
|
seed: int
|
|
|
|
# Denoising strength
|
|
denoising: float = 0.75
|
|
|
|
# Classifier-free guidance strength
|
|
guidance: float = 7.0
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class InferenceInput:
|
|
"""
|
|
Parameters for a single run of the riffusion model, interpolating between
|
|
a start and end set of PromptInputs. This is the API required for a request
|
|
to the model server.
|
|
"""
|
|
|
|
# Start point of interpolation
|
|
start: PromptInput
|
|
|
|
# End point of interpolation
|
|
end: PromptInput
|
|
|
|
# Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
|
|
# uses end fully.
|
|
alpha: float
|
|
|
|
# Number of inner loops of the diffusion model
|
|
num_inference_steps: int = 50
|
|
|
|
# Which seed image to use
|
|
seed_image_id: str = "og_beat"
|
|
|
|
# ID of mask image to use
|
|
mask_image_id: T.Optional[str] = None
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class InferenceOutput:
|
|
"""
|
|
Response from the model inference server.
|
|
"""
|
|
|
|
# base64 encoded spectrogram image as a JPEG
|
|
image: str
|
|
|
|
# base64 encoded audio clip as an MP3
|
|
audio: str
|
|
|
|
# The duration of the audio clip
|
|
duration_s: float
|