use penultimate layer

from the NovelAI blogpost and the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf, Appendix D.1)
harubaru 2022-10-12 17:55:56 -07:00
parent e4736c11f5
commit 4f9070af3c
2 changed files with 9 additions and 2 deletions
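
The change conditions the model on CLIP's second-to-last text-encoder layer instead of the final one, per the NovelAI blogpost and Imagen Appendix D.1. A minimal standalone sketch of the idea, assuming only the Hugging Face transformers API (model name as in the diff below):

from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

batch = tokenizer(["a photo of a cat"], truncation=True, max_length=77,
                  padding="max_length", return_tensors="pt")
outputs = model(input_ids=batch["input_ids"], output_hidden_states=True)

# hidden_states is a tuple of (num_layers + 1) tensors: the input embeddings
# plus the output of each transformer layer; [-2] is the penultimate layer.
penultimate = outputs.hidden_states[-2]   # shape (1, 77, 768) for ViT-L/14
final = outputs.last_hidden_state         # what the stock embedder returns

One caveat: entries of hidden_states are taken before CLIP's final_layer_norm, so the penultimate embedding is un-normalized; some later implementations re-apply model.text_model.final_layer_norm to it.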

@@ -67,6 +67,8 @@ model:
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
       params:
+        penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
 data:
   target: main.DataModuleFromConfig
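
This key takes effect because ldm builds the conditioning stage from this config block and forwards `params` as constructor keyword arguments. Roughly how instantiate_from_config in ldm/util.py does this (a paraphrased sketch, not the verbatim source):

import importlib

def get_obj_from_str(string):
    # "ldm.modules.encoders.modules.FrozenCLIPEmbedder" -> the class object
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # `target` names the class; `params` is splatted into its constructor,
    # so `penultimate: true` arrives as FrozenCLIPEmbedder(penultimate=True)
    return get_obj_from_str(config["target"])(**config.get("params", dict()))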

ldm/modules/encoders/modules.py

@@ -136,12 +136,13 @@ class SpatialRescaler(nn.Module):

 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
+        self.penultimate = penultimate  # return embeddings from 2nd-to-last layer, see https://arxiv.org/pdf/2205.11487.pdf
         self.freeze()

     def freeze(self):
@@ -155,7 +156,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         tokens = batch_encoding["input_ids"].to(self.device)
-        outputs = self.transformer(input_ids=tokens)
-        z = outputs.last_hidden_state
+        # hidden_states is None unless explicitly requested
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.penultimate)
+        if self.penultimate:
+            z = outputs.hidden_states[-2]  # 2nd-to-last layer, simple enough
+        else:
+            z = outputs.last_hidden_state
         return z

     def encode(self, text):
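
A quick hypothetical smoke test of the patched embedder (assumes the repo is on the import path; device="cpu" keeps it runnable without a GPU):

import torch
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

embedder = FrozenCLIPEmbedder(device="cpu", penultimate=True)
with torch.no_grad():
    z = embedder.encode(["a watercolor painting of a fox"])
print(z.shape)  # expected torch.Size([1, 77, 768]) for ViT-L/14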