Merge pull request #113 from qslug/config-fix
Assign default value for max_caption_length
This commit is contained in:
commit
f665a188f3
|
@ -12,6 +12,8 @@ from typing import TypeVar, Iterable
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||||
|
|
||||||
def overlay(overlay, base):
|
def overlay(overlay, base):
|
||||||
return overlay if overlay is not None else base
|
return overlay if overlay is not None else base
|
||||||
|
|
||||||
|
@ -235,7 +237,7 @@ class Dataset:
|
||||||
rating=config.rating or 1.0,
|
rating=config.rating or 1.0,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
tag_weights=tag_weights,
|
tag_weights=tag_weights,
|
||||||
max_target_length=config.max_caption_length,
|
max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
|
||||||
use_weights=use_weights)
|
use_weights=use_weights)
|
||||||
|
|
||||||
item = ImageTrainItem(
|
item = ImageTrainItem(
|
||||||
|
|
|
@ -29,7 +29,6 @@ from torchvision import transforms
|
||||||
|
|
||||||
_RANDOM_TRIM = 0.04
|
_RANDOM_TRIM = 0.04
|
||||||
|
|
||||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
|
||||||
|
|
||||||
OptionalImageCaption = typing.Optional['ImageCaption']
|
OptionalImageCaption = typing.Optional['ImageCaption']
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from data.dataset import Dataset, ImageConfig, Tag
|
from data.dataset import Dataset, ImageConfig, Tag, DEFAULT_MAX_CAPTION_LENGTH
|
||||||
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from pyfakefs.fake_filesystem_unittest import TestCase
|
from pyfakefs.fake_filesystem_unittest import TestCase
|
||||||
|
@ -276,8 +276,8 @@ class TestDataset(TestCase):
|
||||||
self.assertEqual(actual[0].cond_dropout, 0.01)
|
self.assertEqual(actual[0].cond_dropout, 0.01)
|
||||||
self.assertEqual(actual[0].caption.rating(), 1.1)
|
self.assertEqual(actual[0].caption.rating(), 1.1)
|
||||||
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
|
self.assertEqual(actual[0].caption.get_caption(), "first caption, tag, tag_2")
|
||||||
# Can't test this
|
self.assertTrue(actual[0].caption._ImageCaption__use_weights)
|
||||||
# self.assertTrue(actual[0].caption.__use_weights)
|
self.assertEqual(actual[0].caption._ImageCaption__max_target_length, 1024)
|
||||||
|
|
||||||
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
|
self.assertEqual(actual[1].pathname, os.path.abspath('2.jpg'))
|
||||||
self.assertEqual(actual[1].multiplier, 1.0)
|
self.assertEqual(actual[1].multiplier, 1.0)
|
||||||
|
@ -285,5 +285,5 @@ class TestDataset(TestCase):
|
||||||
self.assertIsNone(actual[1].cond_dropout)
|
self.assertIsNone(actual[1].cond_dropout)
|
||||||
self.assertEqual(actual[1].caption.rating(), 1.0)
|
self.assertEqual(actual[1].caption.rating(), 1.0)
|
||||||
self.assertEqual(actual[1].caption.get_caption(), "single caption")
|
self.assertEqual(actual[1].caption.get_caption(), "single caption")
|
||||||
# Can't test this
|
self.assertFalse(actual[1].caption._ImageCaption__use_weights)
|
||||||
# self.assertFalse(actual[1].caption.__use_weights)
|
self.assertEqual(actual[1].caption._ImageCaption__max_target_length, DEFAULT_MAX_CAPTION_LENGTH)
|
Loading…
Reference in New Issue