make it possible again to extract styles that have whitespace at the end.

This commit is contained in:
AUTOMATIC1111 2023-12-30 16:51:02 +03:00
parent adcd65ba34
commit 31992eff9b
1 changed files with 17 additions and 30 deletions

View File

@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles):
return prompt return prompt
def unwrap_style_text_from_prompt(style_text, prompt): def extract_style_text_from_prompt(style_text, prompt):
""" """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.
Note that the "cleaned" version of the style text is only used for matching extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
purposes here. It isn't returned; the original style text is not modified. extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
""" """
stripped_prompt = prompt
stripped_style_text = style_text stripped_prompt = prompt.strip()
stripped_style_text = style_text.strip()
if "{prompt}" in stripped_style_text: if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we left, right = stripped_style_text.split("{prompt}", 2)
# return True and the "inner" prompt text that isn't part of the style.
try:
left, right = stripped_style_text.split("{prompt}", 2)
except ValueError as e:
# If the style text has multple "{prompt}"s, we can't split it into
# two parts. This is an error, but we can't do anything about it.
print(f"Unable to compare style text to prompt:\n{style_text}")
print(f"Error: {e}")
return False, prompt
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt return True, prompt
else: else:
# Work out whether the given prompt ends with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text): if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
if prompt.endswith(", "):
if prompt.endswith(', '):
prompt = prompt[:-2] prompt = prompt[:-2]
return True, prompt return True, prompt
return False, prompt return False, prompt
@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt: if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt return False, prompt, negative_prompt
match_positive, extracted_positive = unwrap_style_text_from_prompt( match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
style.prompt, prompt
)
if not match_positive: if not match_positive:
return False, prompt, negative_prompt return False, prompt, negative_prompt
match_negative, extracted_negative = unwrap_style_text_from_prompt( match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
style.negative_prompt, negative_prompt
)
if not match_negative: if not match_negative:
return False, prompt, negative_prompt return False, prompt, negative_prompt