[LDMTextToImagePipeline] make text model generic (#162)
make text model generic
This commit is contained in:
parent
75b6c16567
commit
543ee1e092
|
@ -45,11 +45,11 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||
# get unconditional embeddings for classifier free guidance
|
||||
if guidance_scale != 1.0:
|
||||
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
|
||||
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))[0]
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
|
||||
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
|
||||
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
|
@ -618,5 +618,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
|
|||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
return sequence_output
|
||||
return outputs
|
||||
|
|
Loading…
Reference in New Issue