cog using common load fxn

This commit is contained in:
daanelson 2022-12-31 04:42:04 +00:00 committed by Hayk Martiros
parent 19fea1294b
commit 2695af9706
4 changed files with 20 additions and 35 deletions

View File

@ -22,7 +22,7 @@ build:
- "flask_cors==3.0.10"
- "flask==1.1.2"
- "numpy==1.19.4"
- "pillow==8.2.0"
- "pillow==9.1.0"
- "pydub==0.25.1"
- "scipy==1.6.3"
- "torch==1.13.0"

View File

@ -11,7 +11,7 @@ This package contains integrations of Riffusion into third party apps and deploy
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
download the model weights:
cog run python -m riffusion.cog_riffusion --download_weights
cog run python -m integrations.cog_riffusion --download_weights
Then you can run predictions:

View File

@ -12,7 +12,6 @@ import numpy as np
import PIL
import torch
from cog import BaseModel, BasePredictor, Input, Path
from huggingface_hub import hf_hub_download
from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.riffusion_pipeline import RiffusionPipeline
@ -21,7 +20,6 @@ from riffusion.spectrogram_params import SpectrogramParams
MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_CACHE = "riffusion-cache"
UNET_CACHE = "unet-cache"
# Where built-in seed images are stored
SEED_IMAGES_DIR = Path("./seed_images")
@ -46,7 +44,7 @@ class RiffusionPredictor(BasePredictor):
See README & https://github.com/replicate/cog for details
"""
def setup(self):
def setup(self, local_files_only=True):
"""
Loads the model onto GPU from local cache.
"""
@ -56,7 +54,8 @@ class RiffusionPredictor(BasePredictor):
checkpoint=MODEL_ID,
use_traced_unet=True,
device=self.device,
local_files_only=True,
local_files_only=local_files_only,
cache_dir=MODEL_CACHE,
)
def predict(
@ -137,38 +136,16 @@ class RiffusionPredictor(BasePredictor):
# RiffusionPipeline.load_checkpoint?
def download_weights(checkpoint: str):
def download_weights():
"""
Clears local cache & downloads riffusion weights
"""
for folder in [MODEL_CACHE, UNET_CACHE]:
if os.path.exists(folder):
shutil.rmtree(folder)
os.makedirs(folder)
if os.path.exists(MODEL_CACHE):
shutil.rmtree(MODEL_CACHE)
os.makedirs(MODEL_CACHE)
model, unet_file = _load_model(checkpoint, local_only=False)
return model, unet_file
def _load_model(checkpoint: str, local_only=False):
model = RiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
cache_dir=MODEL_CACHE,
local_files_only=local_only,
)
unet_file = hf_hub_download(
"riffusion/riffusion-model-v1",
filename="unet_traced.pt",
subfolder="unet_traced",
cache_dir=UNET_CACHE,
local_files_only=local_only,
)
return model, unet_file
pred = RiffusionPredictor()
pred.setup(local_files_only=False)
if __name__ == "__main__":
@ -178,4 +155,4 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.download_weights:
download_weights(MODEL_ID)
download_weights()

View File

@ -70,6 +70,7 @@ class RiffusionPipeline(DiffusionPipeline):
device: str = "cuda",
local_files_only: bool = False,
low_cpu_mem_usage: bool = False,
cache_dir: T.Optional[str] = None,
) -> RiffusionPipeline:
"""
Load the riffusion model pipeline.
@ -97,6 +98,7 @@ class RiffusionPipeline(DiffusionPipeline):
safety_checker=lambda images, **kwargs: (images, False),
low_cpu_mem_usage=low_cpu_mem_usage,
local_files_only=local_files_only,
cache_dir=cache_dir,
).to(device)
if channels_last:
@ -111,6 +113,8 @@ class RiffusionPipeline(DiffusionPipeline):
in_channels=pipeline.unet.in_channels,
dtype=dtype,
device=device,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
if traced_unet is not None:
@ -128,6 +132,8 @@ class RiffusionPipeline(DiffusionPipeline):
in_channels: int,
dtype: torch.dtype,
device: str = "cuda",
local_files_only=False,
cache_dir: T.Optional[str] = None,
) -> T.Optional[torch.nn.Module]:
"""
Load a traced unet from the huggingface hub. This can improve performance.
@ -141,6 +147,8 @@ class RiffusionPipeline(DiffusionPipeline):
checkpoint,
subfolder=subfolder,
filename=filename,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
unet_traced = torch.jit.load(unet_file)