"""
Data model for the riffusion API.
"""
|
||
|
|
||
|
from dataclasses import dataclass
|
||
|
|
||
|
|
||
|
@dataclass
class PromptInput:
    """
    One endpoint of an interpolation: the prompt plus its sampling settings.
    """

    prompt: str  # text prompt that is fed into a CLIP model

    seed: int  # random seed used for denoising

    denoising: float = 0.75  # denoising strength

    guidance: float = 7.0  # classifier-free guidance strength
|
||
|
|
||
|
|
||
|
@dataclass
class InferenceInput:
    """
    Request payload for the model server: one run of the riffusion model that
    interpolates between a start PromptInput and an end PromptInput.
    """

    start: PromptInput  # where the interpolation begins

    end: PromptInput  # where the interpolation finishes

    # Interpolation weight in [0, 1]: 0 means start is used fully, 1 means end
    # is used fully.
    alpha: float

    num_inference_steps: int = 50  # number of inner loops of the diffusion model

    # Selects which seed image to use.
    # TODO(hayk): Convert this to a string ID and add a seed image + mask API.
    seed_image_id: int = 0
|
||
|
|
||
|
|
||
|
@dataclass
class InferenceOutput:
    """
    What the model server sends back: a spectrogram image and an MP3 audio
    clip, each base64 encoded.
    """

    image: str  # base64 encoded spectrogram image

    audio: str  # base64 encoded MP3 audio clip
|