From 6ccb80396842f244a1d125718cd7caa32b57eabb Mon Sep 17 00:00:00 2001 From: Hayk Martiros Date: Mon, 12 Dec 2022 21:43:34 -0800 Subject: [PATCH] Save settings --- components/DebugView.tsx | 5 +- components/ModelInference.tsx | 14 +- components/PromptPanel.tsx | 8 +- components/Settings.tsx | 238 ++++++++++------- components/Share.tsx | 479 ++++++++++++++++------------------ pages/about.tsx | 2 +- pages/index.tsx | 18 ++ 7 files changed, 405 insertions(+), 359 deletions(-) diff --git a/components/DebugView.tsx b/components/DebugView.tsx index 4461c0b..5b4045e 100644 --- a/components/DebugView.tsx +++ b/components/DebugView.tsx @@ -9,7 +9,7 @@ interface DebugViewProps { promptInputs: PromptInput[]; inferenceResults: InferenceResult[]; nowPlayingResult: InferenceResult; - open: boolean ; + open: boolean; setOpen: (open: boolean) => void; } @@ -39,8 +39,9 @@ export default function DebugView({ onClose={() => setOpen(false)} as="div" className="fixed inset-0 z-30" + key="debug-dialog" > - +
diff --git a/components/ModelInference.tsx b/components/ModelInference.tsx index 007fdeb..297d79b 100644 --- a/components/ModelInference.tsx +++ b/components/ModelInference.tsx @@ -16,6 +16,8 @@ interface ModelInferenceProps { nowPlayingResult: InferenceResult; newResultCallback: (input: InferenceInput, result: InferenceResult) => void; useBaseten: boolean; + denoising: number; + seedImageId: string; } /** @@ -31,12 +33,12 @@ export default function ModelInference({ nowPlayingResult, newResultCallback, useBaseten, + denoising, + seedImageId, }: ModelInferenceProps) { // Create parameters for the inference request - const [denoising, setDenoising] = useState(0.75); const [guidance, setGuidance] = useState(7.0); const [numInferenceSteps, setNumInferenceSteps] = useState(50); - const [seedImageId, setSeedImageId] = useState("og_beat"); const [maskImageId, setMaskImageId] = useState(null); const [initializedUrlParams, setInitializedUrlParams] = useState(false); @@ -50,10 +52,6 @@ export default function ModelInference({ // Set initial params from URL query strings const router = useRouter(); useEffect(() => { - if (router.query.denoising) { - setDenoising(parseFloat(router.query.denoising as string)); - } - if (router.query.guidance) { setGuidance(parseFloat(router.query.guidance as string)); } @@ -62,10 +60,6 @@ export default function ModelInference({ setNumInferenceSteps(parseInt(router.query.numInferenceSteps as string)); } - if (router.query.seedImageId) { - setSeedImageId(router.query.seedImageId as string); - } - if (router.query.maskImageId) { if (router.query.maskImageId === "none") { setMaskImageId(""); diff --git a/components/PromptPanel.tsx b/components/PromptPanel.tsx index 849dc9e..9f24a41 100644 --- a/components/PromptPanel.tsx +++ b/components/PromptPanel.tsx @@ -59,7 +59,7 @@ export default function PromptPanel({ displayPrompts = [...promptsToAdd, ...displayPrompts]; } - // Add in the upNext and staged prompts + // Add in the upNext and staged prompts // select the last 2 prompts from prompts const lastPrompts = prompts.slice(-2); @@ -167,7 +167,7 @@ export default function PromptPanel({ }} > void; + seedImage: string; + setSeedImage: (seedImage: string) => void; } export default function Settings({ promptInputs, inferenceResults, nowPlayingResult, + denoising, + setDenoising, + seedImage, + setSeedImage, }: DebugViewProps) { const [open, setOpen] = useState(false); - var classNameCondition = "" + var classNameCondition = ""; if (open) { - classNameCondition = "fixed z-20 top-44 right-4 md:top-48 md:right-8 bg-sky-400 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl hover:bg-sky-500 hover:drop-shadow-2xl" + classNameCondition = + "fixed z-20 top-44 right-4 md:top-48 md:right-8 bg-sky-400 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-white text-2xl hover:bg-sky-500 hover:drop-shadow-2xl"; } else { - classNameCondition = "fixed z-20 top-44 right-4 md:top-48 md:right-8 bg-slate-100 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:text-white hover:bg-sky-600 hover:drop-shadow-2xl" + classNameCondition = + "fixed z-20 top-44 right-4 md:top-48 md:right-8 bg-slate-100 w-14 h-14 rounded-full drop-shadow-lg flex justify-center items-center text-sky-900 text-2xl hover:text-white hover:bg-sky-600 hover:drop-shadow-2xl"; } return ( @@ -82,47 +92,36 @@ export default function Settings({ leaveFrom="opacity-100 scale-100" leaveTo="opacity-0 scale-95" > - +
Settings
-

- - - {/* */} - - {SeedImageSelector()} - - {DenoisingSelector()} - - {DebugButton( - promptInputs, - inferenceResults, - nowPlayingResult - )} - +

+ Riffusion generates music from text prompts using a + diffusion model. Try typing in your favorite artist or + genre, and playing with the settings below to explore the + latent space of sound.

+ + {/* */} + + {SeedImageSelector(seedImage, setSeedImage)} + + {DenoisingSelector(denoising, setDenoising)} + + {DebugButton( + promptInputs, + inferenceResults, + nowPlayingResult + )}
- - -
@@ -152,57 +152,103 @@ export default function Settings({ ); -}; - -export function SeedImageSelector() { - return ( -
- - - -
- ) } -export function DenoisingSelector() { - return ( -
- - - -
- ) -} - -export function DebugButton( - promptInputs, - inferenceResults, - nowPlayingResult +export function SeedImageSelector( + seedImage: string, + setSeedImage: (seedImage: string) => void ) { + let selectOptions = [ + ["OG Beat", "og_beat"], + ["Soul", "chill_soul_1"], + // ["High Energy", 0.85], + // ["Spacy", 0.95], + ]; + + let matchedOption = selectOptions.find((x) => x[1] === seedImage); + if (matchedOption === undefined) { + matchedOption = [`Custom (${seedImage})`, seedImage]; + selectOptions.push(matchedOption); + } + + return ( +
+ + +

+ Used as the base for img2img diffusion. This keeps your riff on beat and + impacts melodic patterns. +

+
+ ); +} + +export function DenoisingSelector( + denoising: number, + setDenoising: (d: number) => void +) { + let selectOptions = [ + ["Keep it on beat (0.75)", 0.75], + ["Get a little crazy (0.8)", 0.8], + ["I'm feeling lucky (0.85)", 0.85], + ["What is tempo? (0.95)", 0.95], + ]; + + let matchedOption = selectOptions.find((x) => x[1] === denoising); + if (matchedOption === undefined) { + matchedOption = [`Custom (${denoising})`, denoising]; + selectOptions.push(matchedOption); + } + + return ( +
+ + +

+ The higher the denoising, the more creative the output, and the more + likely you are to get off beat. +

+
+ ); +} + +export function DebugButton(promptInputs, inferenceResults, nowPlayingResult) { const [debugOpen, debugSetOpen] = useState(false); let buttonClassName = ""; @@ -218,6 +264,7 @@ export function DebugButton( <> - } - - return ( - share image - ) - } - - // function to generate a link to a the moment in the song based on the played clips, input variable is how many seconds ago - function generateLink(secondsAgo: number) { - - var prompt - var seed - var denoising - var maskImageId - var seedImageId - var guidance - var numInferenceSteps - var alphaVelocity - - if (!nowPlayingResult) { - return window.location.href; - } - else { - var selectedInput: InferenceResult["input"] - if (secondsAgo == 0) { - selectedInput = nowPlayingResult.input - } - else { - var selectedCounter = nowPlayingResult.counter - (secondsAgo / 5) - selectedInput = inferenceResults.find((result) => result.counter == selectedCounter)?.input - - if (!selectedInput) { - // TODO: ideally don't show the button in this case... - return window.location.href; - } - } - - // TODO: Consider only including in the link the things that are different from the default values - prompt = selectedInput.start.prompt - seed = selectedInput.start.seed - denoising = selectedInput.start.denoising - maskImageId = selectedInput.mask_image_id - seedImageId = nowPlayingResult.input.seed_image_id - - // TODO, selectively add these based on whether we give user option to change them - // guidance = nowPlayingResult.input.guidance - // numInferenceSteps = nowPlayingResult.input.num_inference_steps - // alphaVelocity = nowPlayingResult.input.alpha_velocity - } - - var baseUrl = window.location.origin + "/?"; - - if (prompt != null) { var promptString = "&prompt=" + prompt } else { promptString = "" } - if (seed != null) { var seedString = "&seed=" + seed } else { seedString = "" } - if (denoising != null) { var denoisingString = "&denoising=" + denoising } else { denoisingString = "" } - if (maskImageId != null) { var maskImageIdString = "&maskImageId=" + maskImageId } else { maskImageIdString = "" } - if (seedImageId != null) { var seedImageIdString = "&seedImageId=" + seedImageId } else { seedImageIdString = "" } - if (guidance != null) { var guidanceString = "&guidance=" + guidance } else { guidanceString = "" } - if (numInferenceSteps != null) { var numInferenceStepsString = "&numInferenceSteps=" + numInferenceSteps } else { numInferenceStepsString = "" } - if (alphaVelocity != null) { var alphaVelocityString = "&alphaVelocity=" + alphaVelocity } else { alphaVelocityString = "" } - - // Format strings to have + in place of spaces for ease of sharing, note this is only necessary for prompts currently - promptString = promptString.replace(/ /g, "+"); - - // create url string with the variables above combined - var shareUrl = baseUrl + promptString + seedString + denoisingString + maskImageIdString + seedImageIdString + guidanceString + numInferenceStepsString + alphaVelocityString - - return shareUrl; - } - - return ( - <> - + + - - setOpen(false)} - > -
- - - + + + +
+ + Share your riff + +
+ share image +
- - - -
- - Share your riff - -
- {displayShareImage()} -
+
+ +
+
+ -
- - - - - -
-
- - -
-
-
- - ); -}; - -function dataURItoBlob(image: string) { - // convert base64/URLEncoded data component to raw binary data held in a string - var byteString; - if (image.split(',')[0].indexOf('base64') >= 0) - byteString = atob(image.split(',')[1]); - else - byteString = unescape(image.split(',')[1]); - - // separate out the mime component - var mimeString = image.split(',')[0].split(':')[1].split(';')[0]; - - // write the bytes of the string to a typed array - var ia = new Uint8Array(byteString.length); - for (var i = 0; i < byteString.length; i++) { - ia[i] = byteString.charCodeAt(i); - } - - return new Blob([ia], { type: mimeString }); + +
+
+
+ + + + + + ); } diff --git a/pages/about.tsx b/pages/about.tsx index 6be69c3..f01ec59 100644 --- a/pages/about.tsx +++ b/pages/about.tsx @@ -404,7 +404,7 @@ export default function Home() { results in under five seconds, you can run the experience locally.


- Code +

Code