Cog: use the common model-loading function

This commit is contained in:
daanelson 2022-12-31 04:42:04 +00:00 committed by Hayk Martiros
parent 19fea1294b
commit 2695af9706
4 changed files with 20 additions and 35 deletions

View File

@ -22,7 +22,7 @@ build:
- "flask_cors==3.0.10" - "flask_cors==3.0.10"
- "flask==1.1.2" - "flask==1.1.2"
- "numpy==1.19.4" - "numpy==1.19.4"
- "pillow==8.2.0" - "pillow==9.1.0"
- "pydub==0.25.1" - "pydub==0.25.1"
- "scipy==1.6.3" - "scipy==1.6.3"
- "torch==1.13.0" - "torch==1.13.0"

View File

@ -11,7 +11,7 @@ This package contains integrations of Riffusion into third party apps and deploy
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
download the model weights: download the model weights:
cog run python -m riffusion.cog_riffusion --download_weights cog run python -m integrations.cog_riffusion --download_weights
Then you can run predictions: Then you can run predictions:

View File

@ -12,7 +12,6 @@ import numpy as np
import PIL import PIL
import torch import torch
from cog import BaseModel, BasePredictor, Input, Path from cog import BaseModel, BasePredictor, Input, Path
from huggingface_hub import hf_hub_download
from riffusion.datatypes import InferenceInput, PromptInput from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline from riffusion.riffusion_pipeline import RiffusionPipeline
@ -21,7 +20,6 @@ from riffusion.spectrogram_params import SpectrogramParams
MODEL_ID = "riffusion/riffusion-model-v1" MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_CACHE = "riffusion-cache" MODEL_CACHE = "riffusion-cache"
UNET_CACHE = "unet-cache"
# Where built-in seed images are stored # Where built-in seed images are stored
SEED_IMAGES_DIR = Path("./seed_images") SEED_IMAGES_DIR = Path("./seed_images")
@ -46,7 +44,7 @@ class RiffusionPredictor(BasePredictor):
See README & https://github.com/replicate/cog for details See README & https://github.com/replicate/cog for details
""" """
def setup(self): def setup(self, local_files_only=True):
""" """
Loads the model onto GPU from local cache. Loads the model onto GPU from local cache.
""" """
@ -56,7 +54,8 @@ class RiffusionPredictor(BasePredictor):
checkpoint=MODEL_ID, checkpoint=MODEL_ID,
use_traced_unet=True, use_traced_unet=True,
device=self.device, device=self.device,
local_files_only=True, local_files_only=local_files_only,
cache_dir=MODEL_CACHE,
) )
def predict( def predict(
@ -137,38 +136,16 @@ class RiffusionPredictor(BasePredictor):
# RiffusionPipeline.load_checkpoint? # RiffusionPipeline.load_checkpoint?
def download_weights(checkpoint: str): def download_weights():
""" """
Clears local cache & downloads riffusion weights Clears local cache & downloads riffusion weights
""" """
for folder in [MODEL_CACHE, UNET_CACHE]: if os.path.exists(MODEL_CACHE):
if os.path.exists(folder): shutil.rmtree(MODEL_CACHE)
shutil.rmtree(folder) os.makedirs(MODEL_CACHE)
os.makedirs(folder)
model, unet_file = _load_model(checkpoint, local_only=False) pred = RiffusionPredictor()
return model, unet_file pred.setup(local_files_only=False)
def _load_model(checkpoint: str, local_only=False):
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
cache_dir=MODEL_CACHE,
local_files_only=local_only,
)
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1",
filename="unet_traced.pt",
subfolder="unet_traced",
cache_dir=UNET_CACHE,
local_files_only=local_only,
)
return model, unet_file
if __name__ == "__main__": if __name__ == "__main__":
@ -178,4 +155,4 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
if args.download_weights: if args.download_weights:
download_weights(MODEL_ID) download_weights()

View File

@ -70,6 +70,7 @@ class RiffusionPipeline(DiffusionPipeline):
device: str = "cuda", device: str = "cuda",
local_files_only: bool = False, local_files_only: bool = False,
low_cpu_mem_usage: bool = False, low_cpu_mem_usage: bool = False,
cache_dir: T.Optional[str] = None,
) -> RiffusionPipeline: ) -> RiffusionPipeline:
""" """
Load the riffusion model pipeline. Load the riffusion model pipeline.
@ -97,6 +98,7 @@ class RiffusionPipeline(DiffusionPipeline):
safety_checker=lambda images, **kwargs: (images, False), safety_checker=lambda images, **kwargs: (images, False),
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
local_files_only=local_files_only, local_files_only=local_files_only,
cache_dir=cache_dir,
).to(device) ).to(device)
if channels_last: if channels_last:
@ -111,6 +113,8 @@ class RiffusionPipeline(DiffusionPipeline):
in_channels=pipeline.unet.in_channels, in_channels=pipeline.unet.in_channels,
dtype=dtype, dtype=dtype,
device=device, device=device,
local_files_only=local_files_only,
cache_dir=cache_dir,
) )
if traced_unet is not None: if traced_unet is not None:
@ -128,6 +132,8 @@ class RiffusionPipeline(DiffusionPipeline):
in_channels: int, in_channels: int,
dtype: torch.dtype, dtype: torch.dtype,
device: str = "cuda", device: str = "cuda",
local_files_only=False,
cache_dir: T.Optional[str] = None,
) -> T.Optional[torch.nn.Module]: ) -> T.Optional[torch.nn.Module]:
""" """
Load a traced unet from the huggingface hub. This can improve performance. Load a traced unet from the huggingface hub. This can improve performance.
@ -141,6 +147,8 @@ class RiffusionPipeline(DiffusionPipeline):
checkpoint, checkpoint,
subfolder=subfolder, subfolder=subfolder,
filename=filename, filename=filename,
local_files_only=local_files_only,
cache_dir=cache_dir,
) )
unet_traced = torch.jit.load(unet_file) unet_traced = torch.jit.load(unet_file)