Account when lines are mismatched
This commit is contained in:
parent
ee015a1af6
commit
80f3cf2bb2
|
@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
fixes.append(fix[1])
|
fixes.append(fix[1])
|
||||||
self.hijack.fixes.append(fixes)
|
self.hijack.fixes.append(fixes)
|
||||||
|
|
||||||
z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
|
tokens = []
|
||||||
|
multipliers = []
|
||||||
|
for i in range(len(remade_batch_tokens)):
|
||||||
|
if len(remade_batch_tokens[i]) > 0:
|
||||||
|
tokens.append(remade_batch_tokens[i][:75])
|
||||||
|
multipliers.append(batch_multipliers[i][:75])
|
||||||
|
else:
|
||||||
|
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||||
|
multipliers.append([1.0] * 75)
|
||||||
|
|
||||||
|
z1 = self.process_tokens(tokens, multipliers)
|
||||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||||
|
|
||||||
remade_batch_tokens = rem_tokens
|
remade_batch_tokens = rem_tokens
|
||||||
|
|
Loading…
Reference in New Issue