From 6e69c942771ae581649015d49e75ed556adc0765 Mon Sep 17 00:00:00 2001
From: Hayk Martiros
Date: Sat, 26 Nov 2022 06:48:52 +0000
Subject: [PATCH] Support masks

---
 README.md                          |  3 +-
 dev_requirements.txt               |  1 +
 riffusion/datatypes.py             |  7 ++--
 riffusion/riffusion_pipeline.py    | 56 ++++++++++++++++++++++++++---
 riffusion/server.py                | 11 +++++-
 seed_images/mask_beat_lines_80.png | Bin 0 -> 7245 bytes
 seed_images/{0.png => og_beat.png} | Bin
 7 files changed, 68 insertions(+), 10 deletions(-)
 create mode 100644 seed_images/mask_beat_lines_80.png
 rename seed_images/{0.png => og_beat.png} (100%)

diff --git a/README.md b/README.md
index 5303297..5b46291 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-infere
 {
   alpha: 0.75,
   num_inference_steps: 50,
-  seed_image_id: 0,
+  seed_image_id: "og_beat",
 
   start: {
     prompt: "church bells on sunday",
@@ -53,4 +53,3 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
   audio: "< base64 encoded MP3 clip >",
 }
 ```
-
diff --git a/dev_requirements.txt b/dev_requirements.txt
index c6fa812..b10ce3e 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,4 +1,5 @@
 black
+ipdb
 isort
 mypy
 pylint
diff --git a/riffusion/datatypes.py b/riffusion/datatypes.py
index bdf6672..6280c1c 100644
--- a/riffusion/datatypes.py
+++ b/riffusion/datatypes.py
@@ -3,6 +3,7 @@ Data model for the riffusion API.
 """
 
 from dataclasses import dataclass
+import typing as T
 
 
 @dataclass
@@ -46,8 +47,10 @@ class InferenceInput:
     num_inference_steps: int = 50
 
     # Which seed image to use
-    # TODO(hayk): Convert this to a string ID and add a seed image + mask API.
-    seed_image_id: int = 0
+    seed_image_id: str = "og_beat"
+
+    # ID of the mask image to use, if any (white = repaint, black = preserve)
+    mask_image_id: T.Optional[str] = None
 
 
 @dataclass
diff --git a/riffusion/riffusion_pipeline.py b/riffusion/riffusion_pipeline.py
index 9d439af..2d4b558 100644
--- a/riffusion/riffusion_pipeline.py
+++ b/riffusion/riffusion_pipeline.py
@@ -78,9 +78,17 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         inputs: InferenceInput,
         init_image: PIL.Image.Image,
+        mask_image: T.Optional[PIL.Image.Image] = None,
     ) -> PIL.Image.Image:
         """
         Runs inference using interpolation with both img2img and text conditioning.
+
+        Args:
+            inputs: Parameter dataclass
+            init_image: Image used for conditioning
+            mask_image: White pixels in the mask will be replaced by noise and therefore
+                repainted, while black pixels will be preserved. The mask is converted
+                to a single channel (luminance) before use.
         """
         alpha = inputs.alpha
         start = inputs.start
@@ -96,7 +104,7 @@ class RiffusionPipeline(DiffusionPipeline):
         text_embedding = torch.lerp(embed_start, embed_end, alpha)
 
         # Image latents
-        init_image = preprocess(init_image)
+        init_image = preprocess_image(init_image)
         init_image_torch = init_image.to(device=self.device, dtype=embed_start.dtype)
         init_latent_dist = self.vae.encode(init_image_torch).latent_dist
         # TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
@@ -105,9 +113,18 @@ class RiffusionPipeline(DiffusionPipeline):
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
+        # Prepare the mask latent to gate which regions are repainted
+        if mask_image is not None:
+            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+            mask_image = preprocess_mask(mask_image, scale_factor=vae_scale_factor)
+            mask = mask_image.to(device=self.device, dtype=embed_start.dtype)
+        else:
+            mask = None
+
         outputs = self.interpolate_img2img(
             text_embeddings=text_embedding,
             init_latents=init_latents,
+            mask=mask,
             generator_a=generator_start,
             generator_b=generator_end,
             interpolate_alpha=alpha,
@@ -124,9 +141,10 @@ class RiffusionPipeline(DiffusionPipeline):
         self,
         text_embeddings: torch.FloatTensor,
         init_latents: torch.FloatTensor,
-        generator_a: T.Optional[torch.Generator],
-        generator_b: T.Optional[torch.Generator],
+        generator_a: torch.Generator,
+        generator_b: torch.Generator,
         interpolate_alpha: float,
+        mask: T.Optional[torch.FloatTensor] = None,
         strength_a: float = 0.8,
         strength_b: float = 0.8,
         num_inference_steps: T.Optional[int] = 50,
@@ -137,6 +155,9 @@ class RiffusionPipeline(DiffusionPipeline):
         output_type: T.Optional[str] = "pil",
         **kwargs,
     ):
+        """
+        Run interpolated img2img; if a mask is given, preserve masked init latents each step.
+        """
         batch_size = text_embeddings.shape[0]
 
         # set timesteps
@@ -209,6 +230,7 @@ class RiffusionPipeline(DiffusionPipeline):
             init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
         )
         noise = slerp(interpolate_alpha, noise_a, noise_b)
+        init_latents_orig = init_latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -220,7 +242,7 @@ class RiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        latents = init_latents
+        latents = init_latents.clone()
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
 
@@ -250,6 +272,11 @@ class RiffusionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
+            if mask is not None:
+                # mask == 1 keeps the (re-noised) original latents; mask == 0 lets them be repainted
+                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
         latents = 1.0 / 0.18215 * latents
         image = self.vae.decode(latents).sample
@@ -262,7 +289,7 @@ class RiffusionPipeline(DiffusionPipeline):
         return dict(images=image, latents=latents, nsfw_content_detected=False)
 
 
-def preprocess(image):
+def preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
     """
     Preprocess an image for the model.
     """
@@ -275,6 +302,25 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+def preprocess_mask(mask: PIL.Image.Image, scale_factor: int = 8) -> torch.Tensor:
+    """
+    Preprocess a mask for the model: (1, 4, h, w) float tensor in [0, 1], 1 = keep, 0 = repaint.
+    """
+    mask = mask.convert("L")  # single channel (luminance)
+    w, h = mask.size
+    w, h = map(lambda x: x - x % 32, (w, h))  # round down to an integer multiple of 32
+    mask = mask.resize(
+        (w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST
+    )
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))  # broadcast across the 4 latent channels
+    mask = mask[None]  # add a batch dimension
+    mask = 1 - mask  # repaint white, keep black
+    mask = torch.from_numpy(mask)
+
+    return mask
+
+
 def slerp(t, v0, v1, dot_threshold=0.9995):
     """
     Helper function to spherically interpolate two arrays v0 and v1.
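Note for reviewers: the masked denoising step above is the standard latent-inpainting trick. Below is a minimal, self-contained sketch of that update rule, illustrative only and not part of the patch; the toy shapes and the keep/repaint convention follow `preprocess_mask`, and `self.scheduler.add_noise` is stood in by plain noise addition.

```python
import torch

# Toy latent shapes: batch=1, 4 channels, 8x8 grid (a stand-in for the
# 64x64 latents of a 512x512 spectrogram image).
init_latents_orig = torch.randn(1, 4, 8, 8)  # encoded seed image latents
latents = torch.randn(1, 4, 8, 8)            # partially denoised latents at step t
noise = torch.randn(1, 4, 8, 8)              # the slerp'd img2img noise

# After preprocess_mask: 1.0 = keep (black input pixels), 0.0 = repaint (white).
mask = torch.zeros(1, 4, 8, 8)
mask[..., :4] = 1.0  # keep the left half in this toy example

# Each denoising step re-noises the original latents to the current timestep
# (scheduler.add_noise in the patch; plain addition here for illustration)
# and blends them back into the kept regions:
init_latents_proper = init_latents_orig + noise
latents = init_latents_proper * mask + latents * (1 - mask)

assert latents.shape == (1, 4, 8, 8)
```

This is why only white regions of the mask evolve with the prompt interpolation, while black regions stay pinned to the seed image.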
diff --git a/riffusion/server.py b/riffusion/server.py
index ff7a65c..a768ad0 100644
--- a/riffusion/server.py
+++ b/riffusion/server.py
@@ -113,8 +113,17 @@ def run_inference():
         return f"Invalid seed image: {inputs.seed_image_id}", 400
     init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
 
+    # Load the mask image by ID, if given
+    if inputs.mask_image_id:
+        mask_image_path = Path(SEED_IMAGES_DIR, f"{inputs.mask_image_id}.png")
+        if not mask_image_path.is_file():
+            return f"Invalid mask image: {inputs.mask_image_id}", 400
+        mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
+    else:
+        mask_image = None
+
     # Execute the model to get the spectrogram image
-    image = MODEL.riffuse(inputs, init_image=init_image)
+    image = MODEL.riffuse(inputs, init_image=init_image, mask_image=mask_image)
 
     # Reconstruct audio from the image
     wav_bytes = wav_bytes_from_spectrogram_image(image)
diff --git a/seed_images/mask_beat_lines_80.png b/seed_images/mask_beat_lines_80.png
new file mode 100644
index 0000000000000000000000000000000000000000..c986bd74b282db93921ee8217f8d3b5f04e5cadc
GIT binary patch
(literal 7245 bytes of base85-encoded PNG data omitted)
diff --git a/seed_images/0.png b/seed_images/og_beat.png
similarity index 100%
rename from seed_images/0.png
rename to seed_images/og_beat.png
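Usage sketch (not part of the patch): with these changes, a request selects both the seed and the mask by string ID from `seed_images/`. A hypothetical client-side construction, assuming the start/end `PromptInput` fields shown in the README example:

```python
from riffusion.datatypes import InferenceInput, PromptInput

inputs = InferenceInput(
    alpha=0.75,
    num_inference_steps=50,
    seed_image_id="og_beat",             # seed_images/og_beat.png (renamed from 0.png)
    mask_image_id="mask_beat_lines_80",  # seed_images/mask_beat_lines_80.png (new)
    start=PromptInput(prompt="church bells on sunday", seed=42),
    end=PromptInput(prompt="church bells on sunday", seed=123),
)
```

Omitting `mask_image_id` (or leaving it `None`) preserves the previous unmasked behavior, since the server passes `mask_image=None` through to `riffuse`.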