diff --git a/riffusion/baseten.py b/riffusion/baseten.py index 25b96c6..1966487 100644 --- a/riffusion/baseten.py +++ b/riffusion/baseten.py @@ -7,15 +7,15 @@ For more on the Truss file format, see https://truss.baseten.co/ """ import base64 -from typing import Dict, List - import dataclasses +import json import io from pathlib import Path +from typing import Dict, List + import PIL import torch import dacite -import json from huggingface_hub import hf_hub_download, snapshot_download diff --git a/riffusion/server.py b/riffusion/server.py index 8f263f8..c931f6e 100644 --- a/riffusion/server.py +++ b/riffusion/server.py @@ -84,7 +84,7 @@ def load_model(checkpoint: str): torch_dtype=torch.float16, # Disable the NSFW filter, causes incorrect false positives safety_checker=lambda images, **kwargs: (images, False), - ) + ).to("cuda") @dataclasses.dataclass class UNet2DConditionOutput: