[LDMTextToImagePipeline] make text model generic (#162)

make text model generic
Suraj Patil 2022-08-09 19:16:17 +05:30 committed by GitHub
parent 75b6c16567
commit 543ee1e092
1 changed file with 3 additions and 4 deletions


@@ -45,11 +45,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
             uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
-            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]

         # get prompt text embeddings
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
-        text_embeddings = self.bert(text_input.input_ids.to(torch_device))
+        text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]

         latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
@@ -618,5 +618,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        sequence_output = outputs[0]
-        return sequence_output
+        return outputs
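
With this change, LDMBertModel returns its full model output and the pipeline selects the last hidden state with `[0]`, the same indexing convention used by Transformers text encoders in general, so the pipeline no longer depends on an encoder that returns only a bare tensor. The sketch below is illustrative only and not part of this commit: it assumes CLIPTextModel from the transformers library as a stand-in encoder to show why `encoder(input_ids)[0]` is encoder-agnostic.

# Illustrative sketch (not from this commit): CLIPTextModel is an assumed
# example of a Transformers-style encoder; LDMBertModel follows the same
# convention after this change, so `encoder(input_ids)[0]` works for either.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

text_input = tokenizer(
    ["a painting of a squirrel eating a burger"],
    padding="max_length", max_length=77, return_tensors="pt",
)
with torch.no_grad():
    # Index 0 of the model output is the last hidden state,
    # shape (batch_size, sequence_length, hidden_size).
    text_embeddings = text_encoder(text_input.input_ids)[0]
print(text_embeddings.shape)  # torch.Size([1, 77, 512]) for this checkpoint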