Assign default value for MAX_CAPTION_LENGTH

This commit is contained in:
Augusto de la Torre 2023-03-15 21:47:33 +01:00
parent b9f4a6d657
commit 48f132554c
3 changed files with 8 additions and 7 deletions

View File

@ -12,6 +12,8 @@ from typing import TypeVar, Iterable
from tqdm import tqdm from tqdm import tqdm
DEFAULT_MAX_CAPTION_LENGTH = 2048
def overlay(overlay, base): def overlay(overlay, base):
return overlay if overlay is not None else base return overlay if overlay is not None else base
@ -235,7 +237,7 @@ class Dataset:
rating=config.rating or 1.0, rating=config.rating or 1.0,
tags=tags, tags=tags,
tag_weights=tag_weights, tag_weights=tag_weights,
max_target_length=config.max_caption_length, max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
use_weights=use_weights) use_weights=use_weights)
item = ImageTrainItem( item = ImageTrainItem(

View File

@ -29,7 +29,6 @@ from torchvision import transforms
_RANDOM_TRIM = 0.04 _RANDOM_TRIM = 0.04
DEFAULT_MAX_CAPTION_LENGTH = 2048
OptionalImageCaption = typing.Optional['ImageCaption'] OptionalImageCaption = typing.Optional['ImageCaption']

View File

@ -1,5 +1,5 @@
import os import os
from data.dataset import Dataset, ImageConfig, Tag from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH
from textwrap import dedent from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase from pyfakefs.fake_filesystem_unittest import TestCase
@ -276,8 +276,8 @@ class TestDataset(TestCase):
self.assertEqual(actual[0].cond_dropout, 0.01) self.assertEqual(actual[0].cond_dropout, 0.01)
self.assertEqual(actual[0].caption.rating(), 1.1) self.assertEqual(actual[0].caption.rating(), 1.1)
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2") self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
# Can't test this self.assertTrue(actual[0].caption._ImageCaption__use_weights)
# self.assertTrue(actual[0].caption.__use_weights) self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg')) self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
self.assertEqual(actual[1].multiplier, 1.0) self.assertEqual(actual[1].multiplier, 1.0)
@ -285,5 +285,5 @@ class TestDataset(TestCase):
self.assertIsNone(actual[1].cond_dropout) self.assertIsNone(actual[1].cond_dropout)
self.assertEqual(actual[1].caption.rating(), 1.0) self.assertEqual(actual[1].caption.rating(), 1.0)
self.assertEqual(actual[1].caption.get_caption(), "single caption") self.assertEqual(actual[1].caption.get_caption(), "single caption")
# Can't test this self.assertFalse(actual[1].caption._ImageCaption__use_weights)
# self.assertFalse(actual[1].caption.__use_weights) self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)