fix sorting of tree list

This commit is contained in:
Sj-Si 2024-04-15 16:10:12 -04:00
parent 49d146393c
commit 2a8403e0a0
1 changed files with 35 additions and 19 deletions

View File

@ -14,8 +14,7 @@ from fastapi.exceptions import HTTPException
from PIL import Image
from starlette.responses import FileResponse, JSONResponse, Response
from modules import (errors, extra_networks, shared,
ui_extra_networks_user_metadata, util)
from modules import errors, extra_networks, shared, ui_extra_networks_user_metadata, util
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.infotext_utils import image_from_url_text
@ -134,9 +133,9 @@ class DirectoryTreeNode:
"""Flattens the keys/values of the tree nodes into a dictionary.
Args:
res: The dictionary result updated in place. On initial call, should be passed
as an empty dictionary.
dirs_only: Whether to only add directories to the result.
res: The dictionary result updated in place. On initial call,
should be passed as an empty dictionary.
dirs_only: Whether to only add directories to the result.
Raises:
KeyError: If any nodes in the tree have the same ID.
@ -150,6 +149,25 @@ class DirectoryTreeNode:
for child in self.children:
child.flatten(res, dirs_only)
def to_sorted_list(self, res: list) -> None:
"""Sorts the tree by absolute path and groups by directories/files.
Since we are sorting a directory tree, we always want the directories to come
before the files. So we have to sort these two lists separately.
Args:
res: The list result updated in place. On initial call, should be passed
as an empty list.
"""
res.append(self)
dir_children = [x for x in self.children if x.is_dir]
file_children = [x for x in self.children if not x.is_dir]
for child in sorted(dir_children, key=lambda x: shared.natural_sort_key(x.abspath)):
child.to_sorted_list(res)
for child in sorted(file_children, key=lambda x: shared.natural_sort_key(x.abspath)):
child.to_sorted_list(res)
def apply(self, fn: Callable) -> None:
"""Recursively calls passed function with instance for entire tree."""
fn(self)
@ -693,20 +711,21 @@ class ExtraNetworksPage:
if not self.tree_roots:
return {}
# Flatten each root into a single dict
tree = {}
# Flatten roots into a single sorted list of nodes.
# Directories always come before files. After that, natural sort is used.
sorted_nodes = []
for node in self.tree_roots.values():
subtree = {}
node.flatten(subtree)
tree.update(subtree)
_sorted_nodes = []
node.to_sorted_list(_sorted_nodes)
sorted_nodes.extend(_sorted_nodes)
path_to_div_id = {}
div_id_to_node = {} # reverse mapping
# First assign div IDs to each node. Used for parent ID lookup later.
for i, path in enumerate(sorted(tree.keys(), key=shared.natural_sort_key)):
for i, node in enumerate(sorted_nodes):
div_id = str(i)
path_to_div_id[path] = div_id
div_id_to_node[div_id] = tree[path]
path_to_div_id[node.abspath] = div_id
div_id_to_node[div_id] = node
show_files = shared.opts.extra_networks_tree_view_show_files is True
for div_id, node in div_id_to_node.items():
@ -966,12 +985,9 @@ def initialize():
def register_default_pages():
from modules.ui_extra_networks_checkpoints import \
ExtraNetworksPageCheckpoints
from modules.ui_extra_networks_hypernets import \
ExtraNetworksPageHypernetworks
from modules.ui_extra_networks_textual_inversion import \
ExtraNetworksPageTextualInversion
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
register_page(ExtraNetworksPageTextualInversion())
register_page(ExtraNetworksPageHypernetworks())