From 4f9070af3c94e4d5d6d9819646d3ce3f776e0393 Mon Sep 17 00:00:00 2001 From: harubaru Date: Wed, 12 Oct 2022 17:55:56 -0700 Subject: [PATCH] use penultimate layer from NAI blogpost and https://arxiv.org/pdf/2205.11487.pdf --- configs/stable-diffusion/v1-4-finetune-test.yaml | 2 ++ ldm/modules/encoders/modules.py | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/configs/stable-diffusion/v1-4-finetune-test.yaml b/configs/stable-diffusion/v1-4-finetune-test.yaml index 93dd2a9..bd446c9 100644 --- a/configs/stable-diffusion/v1-4-finetune-test.yaml +++ b/configs/stable-diffusion/v1-4-finetune-test.yaml @@ -67,6 +67,8 @@ model: cond_stage_config: target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + params: + penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1 data: target: main.DataModuleFromConfig diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index ededbe4..2041738 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -136,12 +136,13 @@ class SpatialRescaler(nn.Module): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) self.device = device self.max_length = max_length + self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf self.freeze() def freeze(self): @@ -155,7 +156,11 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state + if self.penultimate: + z = outputs.hidden_states[-2] # simple enough + else: + z = outputs.last_hidden_state + return z def encode(self, text):