cog using common load fxn
This commit is contained in:
parent
19fea1294b
commit
2695af9706
2
cog.yaml
2
cog.yaml
|
@ -22,7 +22,7 @@ build:
|
||||||
- "flask_cors==3.0.10"
|
- "flask_cors==3.0.10"
|
||||||
- "flask==1.1.2"
|
- "flask==1.1.2"
|
||||||
- "numpy==1.19.4"
|
- "numpy==1.19.4"
|
||||||
- "pillow==8.2.0"
|
- "pillow==9.1.0"
|
||||||
- "pydub==0.25.1"
|
- "pydub==0.25.1"
|
||||||
- "scipy==1.6.3"
|
- "scipy==1.6.3"
|
||||||
- "torch==1.13.0"
|
- "torch==1.13.0"
|
||||||
|
|
|
@ -11,7 +11,7 @@ This package contains integrations of Riffusion into third party apps and deploy
|
||||||
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
|
To run riffusion as a Cog model, first, [install Cog](https://github.com/replicate/cog) and
|
||||||
download the model weights:
|
download the model weights:
|
||||||
|
|
||||||
cog run python -m riffusion.cog_riffusion --download_weights
|
cog run python -m integrations.cog_riffusion --download_weights
|
||||||
|
|
||||||
Then you can run predictions:
|
Then you can run predictions:
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from cog import BaseModel, BasePredictor, Input, Path
|
from cog import BaseModel, BasePredictor, Input, Path
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
from riffusion.datatypes import InferenceInput, PromptInput
|
from riffusion.datatypes import InferenceInput, PromptInput
|
||||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||||
|
@ -21,7 +20,6 @@ from riffusion.spectrogram_params import SpectrogramParams
|
||||||
|
|
||||||
MODEL_ID = "riffusion/riffusion-model-v1"
|
MODEL_ID = "riffusion/riffusion-model-v1"
|
||||||
MODEL_CACHE = "riffusion-cache"
|
MODEL_CACHE = "riffusion-cache"
|
||||||
UNET_CACHE = "unet-cache"
|
|
||||||
|
|
||||||
# Where built-in seed images are stored
|
# Where built-in seed images are stored
|
||||||
SEED_IMAGES_DIR = Path("./seed_images")
|
SEED_IMAGES_DIR = Path("./seed_images")
|
||||||
|
@ -46,7 +44,7 @@ class RiffusionPredictor(BasePredictor):
|
||||||
See README & https://github.com/replicate/cog for details
|
See README & https://github.com/replicate/cog for details
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setup(self):
|
def setup(self, local_files_only=True):
|
||||||
"""
|
"""
|
||||||
Loads the model onto GPU from local cache.
|
Loads the model onto GPU from local cache.
|
||||||
"""
|
"""
|
||||||
|
@ -56,7 +54,8 @@ class RiffusionPredictor(BasePredictor):
|
||||||
checkpoint=MODEL_ID,
|
checkpoint=MODEL_ID,
|
||||||
use_traced_unet=True,
|
use_traced_unet=True,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
local_files_only=True,
|
local_files_only=local_files_only,
|
||||||
|
cache_dir=MODEL_CACHE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
|
@ -137,38 +136,16 @@ class RiffusionPredictor(BasePredictor):
|
||||||
# RiffusionPipeline.load_checkpoint?
|
# RiffusionPipeline.load_checkpoint?
|
||||||
|
|
||||||
|
|
||||||
def download_weights(checkpoint: str):
|
def download_weights():
|
||||||
"""
|
"""
|
||||||
Clears local cache & downloads riffusion weights
|
Clears local cache & downloads riffusion weights
|
||||||
"""
|
"""
|
||||||
for folder in [MODEL_CACHE, UNET_CACHE]:
|
if os.path.exists(MODEL_CACHE):
|
||||||
if os.path.exists(folder):
|
shutil.rmtree(MODEL_CACHE)
|
||||||
shutil.rmtree(folder)
|
os.makedirs(MODEL_CACHE)
|
||||||
os.makedirs(folder)
|
|
||||||
|
|
||||||
model, unet_file = _load_model(checkpoint, local_only=False)
|
pred = RiffusionPredictor()
|
||||||
return model, unet_file
|
pred.setup(local_files_only=False)
|
||||||
|
|
||||||
|
|
||||||
def _load_model(checkpoint: str, local_only=False):
|
|
||||||
model = RiffusionPipeline.from_pretrained(
|
|
||||||
checkpoint,
|
|
||||||
revision="main",
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
# Disable the NSFW filter, causes incorrect false positives
|
|
||||||
safety_checker=lambda images, **kwargs: (images, False),
|
|
||||||
cache_dir=MODEL_CACHE,
|
|
||||||
local_files_only=local_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
unet_file = hf_hub_download(
|
|
||||||
"riffusion/riffusion-model-v1",
|
|
||||||
filename="unet_traced.pt",
|
|
||||||
subfolder="unet_traced",
|
|
||||||
cache_dir=UNET_CACHE,
|
|
||||||
local_files_only=local_only,
|
|
||||||
)
|
|
||||||
return model, unet_file
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -178,4 +155,4 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.download_weights:
|
if args.download_weights:
|
||||||
download_weights(MODEL_ID)
|
download_weights()
|
||||||
|
|
|
@ -70,6 +70,7 @@ class RiffusionPipeline(DiffusionPipeline):
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
low_cpu_mem_usage: bool = False,
|
low_cpu_mem_usage: bool = False,
|
||||||
|
cache_dir: T.Optional[str] = None,
|
||||||
) -> RiffusionPipeline:
|
) -> RiffusionPipeline:
|
||||||
"""
|
"""
|
||||||
Load the riffusion model pipeline.
|
Load the riffusion model pipeline.
|
||||||
|
@ -97,6 +98,7 @@ class RiffusionPipeline(DiffusionPipeline):
|
||||||
safety_checker=lambda images, **kwargs: (images, False),
|
safety_checker=lambda images, **kwargs: (images, False),
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
cache_dir=cache_dir,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
if channels_last:
|
if channels_last:
|
||||||
|
@ -111,6 +113,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
||||||
in_channels=pipeline.unet.in_channels,
|
in_channels=pipeline.unet.in_channels,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
if traced_unet is not None:
|
if traced_unet is not None:
|
||||||
|
@ -128,6 +132,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
|
local_files_only=False,
|
||||||
|
cache_dir: T.Optional[str] = None,
|
||||||
) -> T.Optional[torch.nn.Module]:
|
) -> T.Optional[torch.nn.Module]:
|
||||||
"""
|
"""
|
||||||
Load a traced unet from the huggingface hub. This can improve performance.
|
Load a traced unet from the huggingface hub. This can improve performance.
|
||||||
|
@ -141,6 +147,8 @@ class RiffusionPipeline(DiffusionPipeline):
|
||||||
checkpoint,
|
checkpoint,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
filename=filename,
|
filename=filename,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
cache_dir=cache_dir,
|
||||||
)
|
)
|
||||||
unet_traced = torch.jit.load(unet_file)
|
unet_traced = torch.jit.load(unet_file)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue