2024-03-14 17:22:36 -06:00
from __future__ import annotations
2024-01-20 10:21:36 -07:00
from pathlib import Path
2024-01-31 23:40:15 -07:00
from modules import errors
2022-09-09 14:16:02 -06:00
import csv
2022-09-11 08:35:12 -06:00
import os
import typing
import shutil
2022-09-09 14:16:02 -06:00
2022-09-11 08:35:12 -06:00
class PromptStyle(typing.NamedTuple):
    """One named style: text merged into the positive and negative prompts.

    Instances are immutable; use ``_replace`` to derive a modified copy.
    """

    name: str
    # Positive / negative style text; either may contain a "{prompt}" placeholder.
    prompt: typing.Optional[str]
    negative_prompt: typing.Optional[str]
    # Styles file the entry belongs to; None means no file has been assigned yet,
    # and list dividers use the marker "do_not_save".
    path: typing.Optional[str] = None


def merge_prompts ( style_prompt : str , prompt : str ) - > str :
if " {prompt} " in style_prompt :
res = style_prompt . replace ( " {prompt} " , prompt )
else :
parts = filter ( None , ( prompt . strip ( ) , style_prompt . strip ( ) ) )
res = " , " . join ( parts )
2022-09-09 14:16:02 -06:00
2022-09-14 08:56:21 -06:00
return res
2022-09-09 14:16:02 -06:00
2022-09-14 08:56:21 -06:00
def apply_styles_to_prompt(prompt, styles):
    """Merge each style text in *styles* into *prompt*, in order.

    Every step feeds the previous ``merge_prompts`` result back in, so a later
    style wraps (via ``{prompt}``) or is appended after the earlier ones.
    An empty *styles* returns *prompt* unchanged.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged


def extract_style_text_from_prompt(style_text, prompt):
    """Try to remove a style's text from a prompt it was merged into.

    This undoes ``merge_prompts`` for a single style text:

    * If the style text contains ``{prompt}``, the prompt must start with the
      text before the placeholder and end with the text after it, without those
      two parts overlapping. The part in between is returned.
    * Otherwise the style text must end the prompt, either as the whole prompt
      or right after a comma (``merge_prompts`` joins with ``", "``). The text
      before it is returned, with a trailing ``", "`` removed.
    * An empty or ``None`` style text contributes nothing and always matches.
      In that case the stripped prompt is returned.

    Returns ``(True, extracted_text)`` on a match, otherwise ``(False, prompt)``
    with the original, unstripped prompt.

    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
    extract_style_text_from_prompt("masterpiece", "1girl, greatmasterpiece") outputs (False, "1girl, greatmasterpiece")
    """
    stripped_prompt = prompt.strip()
    # None (e.g. a style loaded from a CSV row with a missing cell) acts like "".
    stripped_style_text = (style_text or "").strip()

    if not stripped_style_text:
        # Nothing to remove: an empty style part trivially matches.
        return True, stripped_prompt

    if "{prompt}" in stripped_style_text:
        left, _, right = stripped_style_text.partition("{prompt}")
        # The length check keeps `left` and `right` from overlapping inside a
        # short prompt, which would otherwise yield a bogus (negative) slice.
        if (
            len(stripped_prompt) >= len(left) + len(right)
            and stripped_prompt.startswith(left)
            and stripped_prompt.endswith(right)
        ):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]
    elif stripped_prompt.endswith(stripped_style_text):
        remainder = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]
        # The style text must start at a comma boundary (or be the whole
        # prompt); otherwise "masterpiece" would match "greatmasterpiece".
        if not remainder or remainder.rstrip().endswith(","):
            if remainder.endswith(", "):
                remainder = remainder[:-2]
            return True, remainder

    return False, prompt


def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
    """
    Takes a style and compares it to the prompt and negative prompt. If the style
    matches, returns True plus the prompt and negative prompt with the style text
    removed. Otherwise, returns False with the original prompt and negative prompt.

    A style whose prompt and negative prompt are both empty or None (e.g. a list
    divider) never matches. A single None field is treated as empty text, so it
    does not crash the extraction of the other field.
    """
    # Short CSV rows or divider entries may carry None instead of "".
    style_prompt = style.prompt or ""
    style_negative_prompt = style.negative_prompt or ""

    if not style_prompt and not style_negative_prompt:
        return False, prompt, negative_prompt

    match_positive, extracted_positive = extract_style_text_from_prompt(style_prompt, prompt)
    if not match_positive:
        return False, prompt, negative_prompt

    match_negative, extracted_negative = extract_style_text_from_prompt(style_negative_prompt, negative_prompt)
    if not match_negative:
        return False, prompt, negative_prompt

    return True, extracted_positive, extracted_negative


class StyleDatabase:
    """Named prompt styles loaded from one or more CSV files.

    ``self.styles`` maps a style name to its ``PromptStyle``. When styles come
    from more than one distinct file, a divider entry (both prompt fields None,
    path ``"do_not_save"``) is inserted before each file's styles. Dividers are
    never written back by ``save_styles``.
    """

    def __init__(self, paths: list[str | Path]):
        """Create the database and load every styles file matched by *paths*.

        Args:
            paths: Non-empty list of file paths or patterns. Wildcards ('*'/'?')
                are only recognised in the final path component. The first entry
                determines ``default_path``, the file that styles without a path
                are saved to.
        """
        self.no_style = PromptStyle("None", "", "", None)
        self.styles = {}
        # Copy so inserting the default path below does not mutate the caller's list.
        self.paths = list(paths)
        self.all_styles_files: list[Path] = []

        folder, file = os.path.split(self.paths[0])
        if '*' in file or '?' in file:
            # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
            self.paths.insert(0, self.default_path)
        else:
            self.default_path = Path(self.paths[0])

        # CSV columns written by save_styles: every PromptStyle field except "path".
        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

        self.reload()

    def reload(self):
        """
        Clears the style database and reloads the styles from the CSV file(s)
        matching the path used to initialize the database.
        """
        self.styles.clear()

        # Expand wildcard patterns; plain paths are kept even if the file does not exist yet.
        all_styles_files = []
        for pattern in self.paths:
            folder, file = os.path.split(pattern)
            if '*' in file or '?' in file:
                all_styles_files.extend(Path(folder).glob(file))
            else:
                all_styles_files.append(Path(pattern))

        # Remove duplicates (e.g. the default path is also matched by its own
        # glob) while keeping first-seen order.
        self.all_styles_files = list(dict.fromkeys(all_styles_files))

        # Dividers only make sense when styles come from more than one distinct
        # file, so count the de-duplicated list.
        add_dividers = len(self.all_styles_files) > 1
        for styles_file in self.all_styles_files:
            if add_dividers:
                # '---------------- STYLES ----------------'
                divider = f' {styles_file.stem.upper()} '.center(40, '-')
                self.styles[divider] = PromptStyle(divider, None, None, "do_not_save")
            if styles_file.is_file():
                self.load_from_csv(styles_file)

    def load_from_csv(self, path: str | Path):
        """Add the styles stored in the CSV file at *path* to the database.

        Rows whose name starts with '#' are skipped. Files with a "text" column
        instead of "prompt" (older format) are supported. Any exception is
        reported through ``errors.report`` rather than raised; styles read before
        the failure stay loaded.
        """
        try:
            with open(path, "r", encoding="utf-8-sig", newline="") as file:
                reader = csv.DictReader(file, skipinitialspace=True)
                for row in reader:
                    # Ignore empty rows or rows starting with a comment
                    if not row or row["name"].startswith("#"):
                        continue
                    # Support loading old CSV format with "name, text"-columns
                    prompt = row["prompt"] if "prompt" in row else row["text"]
                    negative_prompt = row.get("negative_prompt", "")
                    # DictReader fills cells missing from short rows with None;
                    # store "" so real styles never carry None prompt text.
                    self.styles[row["name"]] = PromptStyle(
                        row["name"], prompt or "", negative_prompt or "", str(path)
                    )
        except Exception:
            errors.report(f'Error loading styles from {path}: ', exc_info=True)

    def get_style_paths(self) -> set:
        """Returns a set of all distinct paths of files that styles are loaded from.

        Side effect: styles without a path are reassigned to ``default_path``.
        The default path is always included; the divider marker is excluded.
        """
        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
                self.styles[style.name] = style._replace(path=str(self.default_path))

        # Collect all distinct paths, including the default path
        style_paths = {str(self.default_path)}
        style_paths.update(style.path for style in self.styles.values() if style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.discard("do_not_save")

        return style_paths

    def get_style_prompts(self, styles):
        """Return the positive prompt text of each named style (``""`` for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Return the negative prompt text of each named style (``""`` for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Apply the positive text of the named styles to *prompt*, in order."""
        # Unknown names fall back to no_style; dividers store None, which is
        # passed on as an empty style text instead of crashing merge_prompts.
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).prompt or "" for x in styles]
        )

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Apply the negative text of the named styles to *prompt*, in order."""
        # Same None handling as apply_styles_to_prompt.
        return apply_styles_to_prompt(
            prompt, [self.styles.get(x, self.no_style).negative_prompt or "" for x in styles]
        )

    def save_styles(self, path: str | None = None) -> None:
        """Write every style back to the CSV file it belongs to.

        Before overwriting, each existing file is copied to ``<file>.bak``.
        Styles without a path are assigned to ``default_path`` first.

        Args:
            path: Deprecated and ignored; kept for backwards compatibility.
        """
        style_paths = self.get_style_paths()
        csv_names = [os.path.split(style_path)[1].lower() for style_path in style_paths]

        for style_path in style_paths:
            # Always keep a backup file around
            if os.path.exists(style_path):
                shutil.copy(style_path, f"{style_path}.bak")

            # Write the styles to the CSV file
            with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
                writer.writeheader()
                for style in (s for s in self.styles.values() if s.path == style_path):
                    # Skip entries named after a styles file (e.g. "STYLES.CSV"),
                    # presumably an older divider format. Current dividers use the
                    # "do_not_save" path, which get_style_paths already discards.
                    if style.name.lower().strip("#") in csv_names:
                        continue
                    # Write style fields, ignoring the path field
                    writer.writerow(
                        {k: v for k, v in style._asdict().items() if k != "path"}
                    )

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        """Repeatedly strip known styles off the given prompts.

        Each pass removes the first style (in dictionary order) that matches
        both prompts. A style is removed at most once. Stops when no style
        matches.

        Returns:
            ``(names, prompt, negative_prompt)``: the names of the extracted
            styles, in the reverse of the order they were stripped off, plus
            the prompts with their text removed.
        """
        extracted = []
        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_original_prompts(
                    style, prompt, negative_prompt
                )
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if not found_style:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        return list(reversed(extracted)), prompt, negative_prompt