fix tokenizers pipeline
This commit is contained in:
parent
dc6324d44b
commit
c6a33e3d24
|
@ -1,19 +1,21 @@
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
from modeling_glide import GLIDE
|
import PIL.Image
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
matplotlib.rcParams['interactive'] = True
|
|
||||||
|
|
||||||
|
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator = generator.manual_seed(0)
|
generator = generator.manual_seed(0)
|
||||||
|
|
||||||
pipeline = GLIDE.from_pretrained("fusing/glide-base")
|
model_id = "fusing/glide-base"
|
||||||
|
|
||||||
img = pipeline("a pencil sketch of a corgi", generator)
|
# load model and scheduler
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(model_id)
|
||||||
|
|
||||||
|
# run inference (text-conditioned denoising + upscaling)
|
||||||
|
img = pipeline("a clip art of a hugging face", generator)
|
||||||
|
|
||||||
|
# process image to PIL
|
||||||
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
||||||
|
image_pil = PIL.Image.fromarray(img)
|
||||||
|
|
||||||
plt.figure(figsize=(8, 8))
|
# save image
|
||||||
plt.imshow(img)
|
image_pil.save("test.png")
|
||||||
plt.show()
|
|
|
@ -42,7 +42,7 @@ LOADABLE_CLASSES = {
|
||||||
"GlideDDIMScheduler": ["save_config", "from_config"],
|
"GlideDDIMScheduler": ["save_config", "from_config"],
|
||||||
},
|
},
|
||||||
"transformers": {
|
"transformers": {
|
||||||
"GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
|
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue