[textual_inversion] use tokenizer.add_tokens to add placeholder_token (#357)
use add_tokens
This commit is contained in:
parent
9ea9c6d1c2
commit
55d6453fce
|
@ -357,15 +357,18 @@ def main():
|
|||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.tokenizer_name, additional_special_tokens=[args.placeholder_token]
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
additional_special_tokens=[args.placeholder_token],
|
||||
subfolder="tokenizer",
|
||||
use_auth_token=args.use_auth_token,
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
|
|
Loading…
Reference in New Issue