to cuda
This commit is contained in:
parent
79239a719d
commit
55651db0fc
|
@ -7,15 +7,15 @@ For more on the Truss file format, see https://truss.baseten.co/
|
|||
"""
|
||||
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
import dacite
|
||||
import json
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ def load_model(checkpoint: str):
|
|||
torch_dtype=torch.float16,
|
||||
# Disable the NSFW filter, causes incorrect false positives
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
)
|
||||
).to("cuda")
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
|
|
Loading…
Reference in New Issue