358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""
|
|
Copyright [2022] Victor C Hall
|
|
|
|
Licensed under the GNU Affero General Public License;
|
|
You may not use this code except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
https://www.gnu.org/licenses/agpl-3.0.en.html
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
import bisect
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import typing
|
|
import yaml
|
|
|
|
import PIL
|
|
import PIL.Image as Image
|
|
import PIL.ImageOps as ImageOps
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
|
|
# Type alias for "an ImageCaption or None"; the forward reference is a string because the class is defined below.
OptionalImageCaption = typing.Optional["ImageCaption"]
|
|
|
|
class ImageCaption:
    """
    Represents the various parts of an image caption: an always-included main prompt plus a
    list of tags that can be emitted in order or shuffled (optionally weighted).
    """
    def __init__(self, main_prompt: str, rating: float, tags: list[str], tag_weights: list[float], max_target_length: int, use_weights: bool):
        """
        :param main_prompt: The part of the caption which should always be included
        :param rating: rating value for the image, returned unchanged by rating()
        :param tags: list of tags to pick from to fill the caption
        :param tag_weights: weights to indicate which tags are more desired and should be picked preferably.
            The list is copied; the caller's list is never modified.
        :param max_target_length: The desired maximum length of a generated caption; a falsy value means 2048
        :param use_weights: if true, weights are considered when shuffling tags
        """
        self.__main_prompt = main_prompt
        self.__rating = rating
        self.__tags = tags
        self.__max_target_length = max_target_length or 2048
        self.__use_weights = use_weights

        # Work on a private copy so the padding/truncation below cannot mutate the caller's list.
        weights = list(tag_weights) if tag_weights is not None else []
        if use_weights:
            # Align weights with tags: missing weights default to 1.0, surplus weights are dropped.
            if len(tags) > len(weights):
                weights.extend([1.0] * (len(tags) - len(weights)))
            elif len(weights) > len(tags):
                del weights[len(tags):]
        self.__tag_weights = weights

    def rating(self) -> float:
        """Returns the rating passed to the constructor."""
        return self.__rating

    def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
        """
        Returns the caption as a string with a random selection of the tags in random order.

        With weights enabled, tags are drawn by weight until the tag text reaches the length left
        over after the main prompt. Otherwise all tags are used, the first ``keep_tags`` in their
        original order and the rest shuffled.

        :param seed: used to initialize the randomizer
        :param keep_tags: number of leading tags kept in place (unweighted mode only); negative or None means 0
        :return: generated caption string; empty parts are not joined, so no stray ", " appears
        """
        if not self.__tags:
            return self.__main_prompt

        if self.__use_weights:
            # `or ""` (not `or 0`) so an empty/None prompt counts as length 0 instead of raising TypeError.
            max_target_tag_length = self.__max_target_length - len(self.__main_prompt or "")
            tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
        else:
            tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)

        return self.__join_caption(self.__main_prompt, tags_caption)

    def get_caption(self) -> str:
        """Returns the main prompt followed by all tags in their original order."""
        if not self.__tags:
            return self.__main_prompt
        return self.__join_caption(self.__main_prompt, ", ".join(self.__tags))

    @staticmethod
    def __join_caption(main_prompt: str, tags_caption: str) -> str:
        """Joins the main prompt and the tag text with ", ", skipping whichever part is empty."""
        return ", ".join(part for part in (main_prompt, tags_caption) if part)

    @staticmethod
    def __get_weighted_shuffled_tags(seed: int, tags: list[str], weights: list[float], max_target_tag_length: int) -> str:
        """
        Draws tags without replacement, each draw choosing a tag with probability proportional to
        its weight, until every tag is used or the joined text reaches max_target_tag_length.
        The limit is checked before each draw, so the result may overshoot it by one tag.
        """
        picker = random.Random(seed)
        remaining_tags = tags.copy()
        remaining_weights = weights.copy()

        picked = []
        caption_length = 0  # length of ", ".join(picked), tracked incrementally
        while remaining_tags and caption_length < max_target_tag_length:
            cum_weights = []
            weight_sum = 0.0
            for weight in remaining_weights:
                weight_sum += weight
                cum_weights.append(weight_sum)

            # Map a uniform point on [0, weight_sum] back to the tag whose cumulative interval contains it.
            point = picker.uniform(0, weight_sum)
            pos = bisect.bisect_left(cum_weights, point)

            remaining_weights.pop(pos)
            tag = remaining_tags.pop(pos)

            if picked:
                caption_length += 2  # the ", " separator
            caption_length += len(tag)
            picked.append(tag)

        return ", ".join(picked)

    @staticmethod
    def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
        """
        Keeps the first ``keep_tags`` tags in place and shuffles the rest deterministically from ``seed``.
        """
        # Clamp to >= 0 (the previous min() made every positive keep_tags collapse to 0).
        keep_tags = max(keep_tags or 0, 0)
        fixed_tags = list(tags[:keep_tags])
        rest = list(tags[keep_tags:])
        random.Random(seed).shuffle(rest)
        return ", ".join(fixed_tags + rest)
|
|
|
|
class ImageTrainItem:
    """
    One training image together with the metadata needed to load, crop and resize it on demand.

    image: an already-loaded image, or None / an empty container when the image is loaded later by hydrate()
    caption: ImageCaption describing the image
    aspects: candidate (width, height) target sizes; the one whose aspect ratio is closest to the image's is used
    pathname: path to the image file
    flip_p: probability of flipping the image horizontally (0.0 to 1.0)
    multiplier, cond_dropout, shuffle_tags: stored on the instance; not used inside this class
    batch_id: stored on the instance, defaults to DEFAULT_BATCH_ID
    loss_scale: stored on the instance, defaults to 1
    """
    def __init__(self,
                 image: "PIL.Image.Image",
                 caption: "ImageCaption",
                 aspects: list[tuple[int, int]],
                 pathname: str,
                 flip_p=0.0,
                 multiplier: float=1.0,
                 cond_dropout=None,
                 shuffle_tags=False,
                 batch_id: str=None,
                 loss_scale: float=None
                 ):
        self.caption = caption
        self.aspects = aspects
        self.pathname = pathname
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.cropped_img = None
        self.runt_size = 0
        self.multiplier = multiplier
        self.cond_dropout = cond_dropout
        self.shuffle_tags = shuffle_tags
        self.batch_id = batch_id or DEFAULT_BATCH_ID
        self.loss_scale = 1 if loss_scale is None else loss_scale
        self.target_wh = None

        self.image_size = None
        # PIL images define no __len__, so only probe len() on objects that have one;
        # None or an empty container means "not loaded yet".
        if image is None or (hasattr(image, "__len__") and len(image) == 0):
            self.image = []
        else:
            self.image = image
            self.image_size = image.size

        self.is_undersized = False
        self.error = None
        self.__compute_target_width_height()

    def load_image(self):
        """
        Opens the image at self.pathname, converts it to RGB and applies its EXIF orientation.

        :return: the loaded PIL image
        :raises SyntaxError: a SyntaxError raised by PIL while opening/converting is logged and re-raised
        """
        try:
            # convert() fully loads the pixels and returns an independent image,
            # so the source file can be closed as soon as the with-block exits.
            with PIL.Image.open(self.pathname) as raw_image:
                image = raw_image.convert('RGB')
        except SyntaxError as e:
            # Previously swallowed, which left `image` unbound and surfaced as an UnboundLocalError.
            logging.error(f"Error loading image {self.pathname}: {e}")
            raise
        return self._try_transpose(image, print_error=False)

    def _try_transpose(self, image, print_error=False):
        """
        Applies the image's EXIF orientation via ImageOps.exif_transpose.
        If that fails, the image is returned unchanged (and a warning is logged when print_error is true).
        """
        try:
            return ImageOps.exif_transpose(image)
        except Exception as e:
            if print_error:
                logging.warning(F"Error rotating image: {e} on {self.pathname}, image will be loaded as is, EXIF may be corrupt")
            return image

    def _needs_transpose(self, image, print_error=False):
        """
        Returns True when the EXIF orientation tag (0x0112) is 5, 6, 7 or 8, i.e. an orientation whose
        transpose swaps width and height; False otherwise, including when EXIF cannot be read.
        """
        # Orientation -> transpose mapping, per ImageOps.exif_transpose
        # (https://pillow.readthedocs.io/en/stable/_modules/PIL/ImageOps.html#exif_transpose):
        #   2: FLIP_LEFT_RIGHT, 3: ROTATE_180, 4: FLIP_TOP_BOTTOM,
        #   5: TRANSPOSE, 6: ROTATE_270, 7: TRANSVERSE, 8: ROTATE_90
        try:
            orientation = image.getexif().get(0x0112)
        except Exception as e:
            if print_error:
                logging.warning(F"Error rotating image: {e} on {self.pathname}, image will be loaded as is, EXIF may be corrupt")
            return False
        return orientation in (5, 6, 7, 8)

    def _percent_random_crop(self, image, crop_jitter=0.02):
        """
        Randomly crops each of the four sides by 0..crop_jitter * min(width, height) pixels.
        """
        width, height = image.size
        max_crop_pixels = int(min(width, height) * crop_jitter)

        # Draw order (left, right, top, bottom) is kept stable so seeded runs are reproducible.
        left_crop_pixels = int(round(random.uniform(0, max_crop_pixels)))
        right_crop_pixels = int(round(random.uniform(0, max_crop_pixels)))
        top_crop_pixels = int(round(random.uniform(0, max_crop_pixels)))
        bottom_crop_pixels = int(round(random.uniform(0, max_crop_pixels)))

        box = (left_crop_pixels, top_crop_pixels, width - right_crop_pixels, height - bottom_crop_pixels)
        return image.crop(box)

    def _debug_save_image(self, image, folder=""):
        """Best-effort save of `image` to test/output/<folder>/<basename>; errors are printed, not raised."""
        base_name = os.path.basename(self.pathname)
        target_dir = os.path.join('test/output', folder)
        target_file = os.path.join(target_dir, base_name)

        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(target_dir, exist_ok=True)

        try:
            image.save(target_file)
        except Exception as e:
            print(f"error for debug saving image of {self.pathname}: {e}")

    def _trim_to_aspect(self, image, target_wh):
        """
        Crops the image along its too-long axis so its aspect ratio matches target_wh[0] / target_wh[1].
        The crop offset is drawn from a triangular distribution peaking at the centre.
        """
        try:
            width, height = image.size
            target_aspect = target_wh[0] / target_wh[1]
            image_aspect = width / height

            if image_aspect > target_aspect:
                # Too wide: keep full height, crop horizontally.
                target_width = int(height * target_aspect)
                overwidth = width - target_width
                l = int(min(max(0, random.triangular(0, overwidth)), overwidth))
                r = width - overwidth + l
                image = image.crop((l, 0, r, height))
            elif target_aspect > image_aspect:
                # Too tall: keep full width, crop vertically.
                target_height = int(width / target_aspect)
                overheight = height - target_height
                t = int(min(max(0, random.triangular(0, overheight)), overheight))
                b = height - overheight + t
                image = image.crop((0, t, width, b))
        except Exception as e:
            logging.error(f"fatal error trimming image {self.pathname}: {e}")
            raise
        return image

    def hydrate(self, save=False, crop_jitter=0.02):
        """
        Loads the image from disk and prepares it for training, storing the result in self.image.

        Steps: small random crop on each side (at most crop_jitter, and never more than the image's
        relative surplus over target_wh), random trim to the target aspect ratio, resize to target_wh,
        random horizontal flip, and conversion to a uint8 numpy array.

        :param save: if true, also writes the final PIL image under test/output/final for debugging
        :param crop_jitter: upper bound on the fraction of the shorter side cropped from each edge
        :return: self
        """
        # NOTE: relies on target_wh set by __compute_target_width_height; it stays None when that
        # probe failed (see self.error), in which case the arithmetic below raises.
        image = self.load_image()

        width, height = image.size

        img_jitter = min((width-self.target_wh[0])/self.target_wh[0], (height-self.target_wh[1])/self.target_wh[1])
        img_jitter = min(img_jitter, crop_jitter)
        img_jitter = max(img_jitter, 0.0)

        if img_jitter > 0.0:
            image = self._percent_random_crop(image, img_jitter)

        image = self._trim_to_aspect(image, self.target_wh)

        self.image = image.resize(self.target_wh)

        self.image = self.flip(self.image)

        if save:
            self._debug_save_image(self.image, "final")

        self.image = np.array(self.image).astype(np.uint8)

        return self

    def __compute_target_width_height(self):
        """
        Reads the image header to choose target_wh (the closest aspect in self.aspects) and set
        is_undersized. Any exception is stored in self.error instead of being raised.
        """
        self.target_wh = None
        try:
            with PIL.Image.open(self.pathname) as image:
                # EXIF orientations 5-8 swap the displayed width and height.
                if self._needs_transpose(image):
                    height, width = image.size
                else:
                    width, height = image.size

                image_aspect = width / height
                target_wh = min(self.aspects, key=lambda aspect: abs(aspect[0]/aspect[1] - image_aspect))

                # Undersized: neither side matches exactly and the pixel count is below (1.02 * target)^2.
                self.is_undersized = (width != target_wh[0] and height != target_wh[1]) and (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)

                self.target_wh = target_wh
                # NOTE(review): this is the stored (pre-EXIF-rotation) size, not the swapped one.
                self.image_size = image.size
        except Exception as e:
            self.error = e

    @staticmethod
    def __autocrop(image: "PIL.Image.Image", q=.404):
        """
        Crops the image to a square of side min(width, height), offset along the long axis by a
        half-normal random amount (sigma = q * excess length, at least 1) capped at half the excess.
        """
        x, y = image.size
        if x == y:
            return image

        excess = abs(x - y)
        sigma = max(excess * q, 1)
        offset = math.trunc(min(abs(random.gauss(0, sigma)), excess / 2))
        x_crop, y_crop = (offset, 0) if x > y else (0, offset)

        min_xy = min(x, y)
        return image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
|
|
|
|
|
|
# Batch id assigned to an ImageTrainItem when no batch_id (or a falsy one) is given.
DEFAULT_BATCH_ID: str = "default_batch"
|