Cog: use the common load function
This commit is contained in:
parent
19fea1294b
commit
2695af9706
2
cog.yaml
2
cog.yaml
|
@ -22,7 +22,7 @@ build:
|
|||
- "flask_cors==3.0.10"
|
||||
- "flask==1.1.2"
|
||||
- "numpy==1.19.4"
|
||||
- "pillow==8.2.0"
|
||||
- "pillow==9.1.0"
|
||||
- "pydub==0.25.1"
|
||||
- "scipy==1.6.3"
|
||||
- "torch==1.13.0"
|
||||
|
|
|
@ -11,7 +11,7 @@ This package contains integrations of Riffusion into third party apps and deploy
|
|||
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
|
||||
download the model weights:
|
||||
|
||||
cog run python -m riffusion.cog_riffusion --download_weights
|
||||
cog run python -m integrations.cog_riffusion --download_weights
|
||||
|
||||
Then you can run predictions:
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from cog import BaseModel, BasePredictor, Input, Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from riffusion.datatypes import InferenceInput, PromptInput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
|
@ -21,7 +20,6 @@ from riffusion.spectrogram_params import SpectrogramParams
|
|||
|
||||
MODEL_ID = "riffusion/riffusion-model-v1"
|
||||
MODEL_CACHE = "riffusion-cache"
|
||||
UNET_CACHE = "unet-cache"
|
||||
|
||||
# Where built-in seed images are stored
|
||||
SEED_IMAGES_DIR = Path("./seed_images")
|
||||
|
@ -46,7 +44,7 @@ class RiffusionPredictor(BasePredictor):
|
|||
See README & https://github.com/replicate/cog for details
|
||||
"""
|
||||
|
||||
def setup(self):
|
||||
def setup(self, local_files_only=True):
|
||||
"""
|
||||
Loads the model onto GPU from local cache.
|
||||
"""
|
||||
|
@ -56,7 +54,8 @@ class RiffusionPredictor(BasePredictor):
|
|||
checkpoint=MODEL_ID,
|
||||
use_traced_unet=True,
|
||||
device=self.device,
|
||||
local_files_only=True,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=MODEL_CACHE,
|
||||
)
|
||||
|
||||
def predict(
|
||||
|
@ -137,38 +136,16 @@ class RiffusionPredictor(BasePredictor):
|
|||
# RiffusionPipeline.load_checkpoint?
|
||||
|
||||
|
||||
def download_weights(checkpoint: str):
|
||||
def download_weights():
|
||||
"""
|
||||
Clears local cache & downloads riffusion weights
|
||||
"""
|
||||
for folder in [MODEL_CACHE, UNET_CACHE]:
|
||||
if os.path.exists(folder):
|
||||
shutil.rmtree(folder)
|
||||
os.makedirs(folder)
|
||||
if os.path.exists(MODEL_CACHE):
|
||||
shutil.rmtree(MODEL_CACHE)
|
||||
os.makedirs(MODEL_CACHE)
|
||||
|
||||
model, unet_file = _load_model(checkpoint, local_only=False)
|
||||
return model, unet_file
|
||||
|
||||
|
||||
def _load_model(checkpoint: str, local_only=False):
|
||||
model = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
cache_dir=MODEL_CACHE,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
|
||||
unet_file = hf_hub_download(
|
||||
"riffusion/riffusion-model-v1",
|
||||
filename="unet_traced.pt",
|
||||
subfolder="unet_traced",
|
||||
cache_dir=UNET_CACHE,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
return model, unet_file
|
||||
pred = RiffusionPredictor()
|
||||
pred.setup(local_files_only=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -178,4 +155,4 @@ if __name__ == "__main__":
|
|||
)
|
||||
args = parser.parse_args()
|
||||
if args.download_weights:
|
||||
download_weights(MODEL_ID)
|
||||
download_weights()
|
||||
|
|
|
@ -70,6 +70,7 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
device: str = "cuda",
|
||||
local_files_only: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
cache_dir: T.Optional[str] = None,
|
||||
) -> RiffusionPipeline:
|
||||
"""
|
||||
Load the riffusion model pipeline.
|
||||
|
@ -97,6 +98,7 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
).to(device)
|
||||
|
||||
if channels_last:
|
||||
|
@ -111,6 +113,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
in_channels=pipeline.unet.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
if traced_unet is not None:
|
||||
|
@ -128,6 +132,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
in_channels: int,
|
||||
dtype: torch.dtype,
|
||||
device: str = "cuda",
|
||||
local_files_only=False,
|
||||
cache_dir: T.Optional[str] = None,
|
||||
) -> T.Optional[torch.nn.Module]:
|
||||
"""
|
||||
Load a traced unet from the huggingface hub. This can improve performance.
|
||||
|
@ -141,6 +147,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
|||
checkpoint,
|
||||
subfolder=subfolder,
|
||||
filename=filename,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
|
|
Loading…
Reference in New Issue