use penultimate CLIP text-encoder layer
following the NAI blogpost and https://arxiv.org/pdf/2205.11487.pdf (D.1)
parent e4736c11f5
commit 4f9070af3c
@@ -67,6 +67,8 @@ model:
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        penultimate: true # use 2nd-to-last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
 
 data:
   target: main.DataModuleFromConfig
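Aside (not part of the diff): a minimal sketch of how a config block like this reaches the encoder. It assumes the latent-diffusion helper `ldm.util.instantiate_from_config`, which resolves `target` to a class and passes `params` as keyword arguments; the dict below just mirrors the YAML above.

# Hedged sketch, assuming the latent-diffusion codebase is on the path.
from ldm.util import instantiate_from_config

cond_stage_config = {
    "target": "ldm.modules.encoders.modules.FrozenCLIPEmbedder",
    "params": {"penultimate": True},  # same flag as the YAML change above
}

# Resolves "target" and calls FrozenCLIPEmbedder(penultimate=True).
embedder = instantiate_from_config(cond_stage_config)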
@@ -136,12 +136,13 @@ class SpatialRescaler(nn.Module):
 
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
+        self.penultimate = penultimate  # return embeddings from the 2nd-to-last layer, see https://arxiv.org/pdf/2205.11487.pdf
         self.freeze()
 
     def freeze(self):
@@ -155,7 +156,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = self.transformer(input_ids=tokens)
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.penultimate)
 
-        z = outputs.last_hidden_state
+        if self.penultimate:
+            z = outputs.hidden_states[-2]  # 2nd-to-last layer; hidden_states is only populated when output_hidden_states=True
+        else:
+            z = outputs.last_hidden_state
 
         return z
 
     def encode(self, text):
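Aside (not part of the diff): a self-contained sketch of the same penultimate-layer trick against Hugging Face transformers directly, useful for checking that `hidden_states[-2]` really differs from `last_hidden_state`. The prompt is arbitrary; the model name matches the default above.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

version = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(version)
model = CLIPTextModel.from_pretrained(version).eval()

batch = tokenizer(["a photo of a cat"], truncation=True, max_length=77,
                  padding="max_length", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=batch["input_ids"], output_hidden_states=True)

# hidden_states holds the embedding output plus every block's output.
# last_hidden_state is hidden_states[-1] passed through the final layer
# norm, so hidden_states[-2] is the raw output of the 2nd-to-last block.
penultimate = out.hidden_states[-2]  # shape (1, 77, 768) for ViT-L/14
final = out.last_hidden_state        # shape (1, 77, 768)
print(torch.allclose(penultimate, final))  # False: the two differ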