diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index b36b09c..f1ad9a0 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -570,7 +570,7 @@ class AspectDataset(torch.utils.data.Dataset): input_ids = z else: for i, x in enumerate(input_ids): - input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)] + input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.tokenizer.model_max_length - len(x) - 1), self.tokenizer.eos_token_id)] if args.clip_penultimate: input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2]) else: