Merge pull request #56 from harubaru/extended-mode-fix
Extended Mode Typo Fix and Seed Update
This commit is contained in:
commit
c709257bec
|
@ -570,7 +570,7 @@ class AspectDataset(torch.utils.data.Dataset):
|
||||||
input_ids = z
|
input_ids = z
|
||||||
else:
|
else:
|
||||||
for i, x in enumerate(input_ids):
|
for i, x in enumerate(input_ids):
|
||||||
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
input_ids[i] = [self.tokenizer.bos_token_id, *x, *np.full((self.tokenizer.model_max_length - len(x) - 1), self.tokenizer.eos_token_id)]
|
||||||
if args.clip_penultimate:
|
if args.clip_penultimate:
|
||||||
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
input_ids = self.text_encoder.text_model.final_layer_norm(self.text_encoder(torch.asarray(input_ids).to(self.device), output_hidden_states=True)['hidden_states'][-2])
|
||||||
else:
|
else:
|
||||||
|
@ -717,6 +717,8 @@ def main():
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
|
random.seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
print('RANDOM SEED:', args.seed)
|
print('RANDOM SEED:', args.seed)
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
|
|
Loading…
Reference in New Issue