diff --git a/data/image_train_item.py b/data/image_train_item.py index 623c664..394394a 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -48,7 +48,7 @@ class ImageCaption: self.__rating = rating self.__tags = tags self.__tag_weights = tag_weights - self.__max_target_length = max_target_length + self.__max_target_length = max_target_length or 2048 self.__use_weights = use_weights if use_weights and len(tags) > len(tag_weights): self.__tag_weights.extend([1.0] * (len(tags) - len(tag_weights))) @@ -66,7 +66,13 @@ class ImageCaption: :return: generated caption string """ if self.__tags: - max_target_tag_length = self.__max_target_length - len(self.__main_prompt) + try: + max_target_tag_length = self.__max_target_length - len(self.__main_prompt or 0) + except Exception as e: + print() + logging.error(f"Error determining length for: {e} on {self.__main_prompt}") + print() + max_target_tag_length = 2048 if self.__use_weights: tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)