use penultimate layer
from NAI blogpost and https://arxiv.org/pdf/2205.11487.pdf
This commit is contained in:
parent
e4736c11f5
commit
4f9070af3c
|
@ -67,6 +67,8 @@ model:
|
||||||
|
|
||||||
cond_stage_config:
|
cond_stage_config:
|
||||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
params:
|
||||||
|
penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
|
||||||
|
|
||||||
data:
|
data:
|
||||||
target: main.DataModuleFromConfig
|
target: main.DataModuleFromConfig
|
||||||
|
|
|
@ -136,12 +136,13 @@ class SpatialRescaler(nn.Module):
|
||||||
|
|
||||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
|
@ -155,7 +156,11 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
outputs = self.transformer(input_ids=tokens)
|
outputs = self.transformer(input_ids=tokens)
|
||||||
|
|
||||||
|
if self.penultimate:
|
||||||
|
z = outputs.hidden_states[-2] # simple enough
|
||||||
|
else:
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
|
|
Loading…
Reference in New Issue