"""Generate an image from a text prompt with the GLIDE diffusion pipeline.

Loads the ``fusing/glide-base`` pipeline, runs text-conditioned denoising plus
upscaling for a fixed prompt with a fixed RNG seed, and saves the result as a
PNG file.
"""

import torch
import PIL.Image
from diffusers import DiffusionPipeline

MODEL_ID = "fusing/glide-base"
PROMPT = "a crayon drawing of a corgi"
SEED = 0  # fixed seed so repeated runs are reproducible
OUTPUT_PATH = "test.png"


def _to_pil(img):
    """Convert a single-image pipeline output tensor to a PIL image.

    Drops the leading size-1 batch dimension, then maps values from [-1, 1]
    to [0, 255], rounds, clamps out-of-range values and casts to uint8 before
    handing the array to PIL.

    NOTE(review): ``PIL.Image.fromarray`` expects channels-last (H, W, C)
    data, so this assumes the pipeline returns a (1, H, W, C) tensor. A
    channels-first output would need ``img.permute(1, 2, 0)`` after the
    squeeze. TODO: confirm against the pipeline implementation.
    """
    img = img.squeeze(0)
    # (x + 1) * 127.5 maps the [-1, 1] model range onto [0, 255]. The clamp
    # guards against samples that overshoot that range.
    img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
    return PIL.Image.fromarray(img)


def main():
    """Load the pipeline, generate an image for PROMPT and save it to OUTPUT_PATH."""
    # manual_seed returns the generator itself, so creation and seeding chain.
    generator = torch.Generator().manual_seed(SEED)

    # load model and scheduler
    pipeline = DiffusionPipeline.from_pretrained(MODEL_ID)

    # run inference (text-conditioned denoising + upscaling)
    img = pipeline(PROMPT, generator)

    # process image to PIL and save it
    _to_pil(img).save(OUTPUT_PATH)


if __name__ == "__main__":
    main()