Merge pull request #113 from qslug/config-fix

Assign default value for max_caption_length
This commit is contained in:
Victor Hall 2023-03-15 16:55:11 -04:00 committed by GitHub
commit f665a188f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 7 deletions

View File

@ -12,6 +12,8 @@ from typing import TypeVar, Iterable
from tqdm import tqdm from tqdm import tqdm
DEFAULT_MAX_CAPTION_LENGTH = 2048
def overlay(overlay, base): def overlay(overlay, base):
return overlay if overlay is not None else base return overlay if overlay is not None else base
@ -235,7 +237,7 @@ class Dataset:
rating=config.rating or 1.0, rating=config.rating or 1.0,
tags=tags, tags=tags,
tag_weights=tag_weights, tag_weights=tag_weights,
max_target_length=config.max_caption_length, max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
use_weights=use_weights) use_weights=use_weights)
item = ImageTrainItem( item = ImageTrainItem(

View File

@ -29,7 +29,6 @@ from torchvision import transforms
_RANDOM_TRIM = 0.04 _RANDOM_TRIM = 0.04
DEFAULT_MAX_CAPTION_LENGTH = 2048
OptionalImageCaption = typing.Optional['ImageCaption'] OptionalImageCaption = typing.Optional['ImageCaption']

View File

@ -1,5 +1,5 @@
import os import os
from data.dataset import Dataset, ImageConfig, Tag from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH
from textwrap import dedent from textwrap import dedent
from pyfakefs.fake_filesystem_unittest import TestCase from pyfakefs.fake_filesystem_unittest import TestCase
@ -276,8 +276,8 @@ class TestDataset(TestCase):
self.assertEqual(actual[0].cond_dropout, 0.01) self.assertEqual(actual[0].cond_dropout, 0.01)
self.assertEqual(actual[0].caption.rating(), 1.1) self.assertEqual(actual[0].caption.rating(), 1.1)
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2") self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
# Can't test this self.assertTrue(actual[0].caption._ImageCaption__use_weights)
# self.assertTrue(actual[0].caption.__use_weights) self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg')) self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
self.assertEqual(actual[1].multiplier, 1.0) self.assertEqual(actual[1].multiplier, 1.0)
@ -285,5 +285,5 @@ class TestDataset(TestCase):
self.assertIsNone(actual[1].cond_dropout) self.assertIsNone(actual[1].cond_dropout)
self.assertEqual(actual[1].caption.rating(), 1.0) self.assertEqual(actual[1].caption.rating(), 1.0)
self.assertEqual(actual[1].caption.get_caption(), "single caption") self.assertEqual(actual[1].caption.get_caption(), "single caption")
# Can't test this self.assertFalse(actual[1].caption._ImageCaption__use_weights)
# self.assertFalse(actual[1].caption.__use_weights) self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)