This commit is contained in:
Hayk Martiros 2022-12-13 06:55:07 +00:00
parent 79239a719d
commit 55651db0fc
2 changed files with 4 additions and 4 deletions

View File

@ -7,15 +7,15 @@ For more on the Truss file format, see https://truss.baseten.co/
"""
import base64
from typing import Dict, List
import dataclasses
import json
import io
from pathlib import Path
from typing import Dict, List
import PIL
import torch
import dacite
import json
from huggingface_hub import hf_hub_download, snapshot_download

View File

@ -84,7 +84,7 @@ def load_model(checkpoint: str):
torch_dtype=torch.float16,
# Disable the NSFW filter, causes incorrect false positives
safety_checker=lambda images, **kwargs: (images, False),
)
).to("cuda")
@dataclasses.dataclass
class UNet2DConditionOutput: