[Textual Inversion] Do not update other embeddings (#1665)

Authored by Patrick von Platen on 2022-12-12 17:44:39 +01:00; committed by GitHub.
parent 3ce6380d3a
commit 69de9b2eaa
1 changed file with 8 additions and 10 deletions

@@ -548,6 +548,9 @@ def main():
     progress_bar.set_description("Steps")
     global_step = 0
 
+    # keep original embeddings as reference
+    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
+
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
@@ -585,20 +588,15 @@ def main():
                 loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
-                else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
-
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
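
A note on the `.clone()` in the first hunk: `weight.data` alone would alias the live embedding tensor, so the later restore would copy the already-updated weights onto themselves. A minimal sketch of the difference, using a toy embedding layer whose sizes are illustrative assumptions:

```python
import torch

emb = torch.nn.Embedding(4, 2)  # toy vocab of 4 tokens, dim 2 (illustrative)

alias = emb.weight.data             # shares storage with the live weights
snapshot = emb.weight.data.clone()  # independent copy, as in the commit

with torch.no_grad():
    emb.weight += 1.0               # simulate an optimizer update

print(torch.equal(alias, emb.weight.data))     # True: the alias drifted too
print(torch.equal(snapshot, emb.weight.data))  # False: the snapshot kept the originals
```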
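As for why the gradient-zeroing block was removed rather than kept: zeroing a gradient does not stop an optimizer with decoupled weight decay (or accumulated momentum state) from moving the parameter. A minimal sketch, assuming AdamW with illustrative hyperparameters (the script's actual optimizer settings are not shown in this diff):

```python
import torch

w = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.AdamW([w], lr=1e-2, weight_decay=1e-2)

w.sum().backward()  # populate w.grad
w.grad.zero_()      # zero the gradient, as the removed code did for frozen rows
opt.step()

# w is no longer all ones: decoupled weight decay shrank it despite the zero
# gradient. Momentum left over from earlier non-zero grads would move it too.
print(w.data)  # tensor([0.9999, 0.9999, 0.9999, 0.9999])
```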
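The replacement pattern from the second hunk — snapshot the embedding matrix once, then copy every row except the placeholder token back after each `optimizer.step()` — can be exercised end to end in isolation. A minimal sketch; the vocabulary size, embedding dimension, token id, optimizer settings, and dummy loss are assumptions for illustration, not values from the training script:

```python
import torch

vocab_size, emb_dim = 16, 8
placeholder_token_id = 5  # stand-in for the newly added concept token

embeddings = torch.nn.Embedding(vocab_size, emb_dim)
orig_embeds_params = embeddings.weight.data.clone()  # reference snapshot
optimizer = torch.optim.AdamW(embeddings.parameters(), lr=1e-2, weight_decay=1e-2)

index_no_updates = torch.arange(vocab_size) != placeholder_token_id

for _ in range(10):
    tokens = torch.randint(0, vocab_size, (4,))
    loss = embeddings(tokens).pow(2).mean()  # dummy training objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Restore every row except the placeholder token, as the commit does.
    with torch.no_grad():
        embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]

# Frozen rows are bit-identical to the snapshot; the placeholder row has moved.
assert torch.equal(embeddings.weight.data[index_no_updates], orig_embeds_params[index_no_updates])
assert not torch.equal(embeddings.weight.data[placeholder_token_id], orig_embeds_params[placeholder_token_id])
```

Restoring the weights after the step, instead of masking gradients before it, makes the freeze exact regardless of what the optimizer does internally.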