The following line contains Chinese characters and special symbols
第二行 ♪♪
Third Line
Lines with invalid timestamps are ignored
Ignore, two
Ignored, three
'''.encode()
srt_data = '''1
00:00:00,000 --> 00:00:01,000
The following line contains Chinese characters and special symbols
2
00:00:01,000 --> 00:00:02,000
第二行
♪♪
3
00:00:02,000 --> 00:00:03,000
Third
Line
'''
self.assertEqual(dfxp2srt(dfxp_data), srt_data)
dfxp_data_no_default_namespace = b'''
The first line
'''
srt_data = '''1
00:00:00,000 --> 00:00:01,000
The first line
'''
self.assertEqual(dfxp2srt(dfxp_data_no_default_namespace), srt_data)
dfxp_data_with_style = b'''
default stylecustom style
part 1 part 2
line 3 part 3
inner style
'''
srt_data = '''1
00:00:02,080 --> 00:00:05,840
default stylecustom style
2
00:00:02,080 --> 00:00:05,840
part 1
part 2
3
00:00:05,840 --> 00:00:09,560
line 3
part 3
4
00:00:09,560 --> 00:00:12,360
inner
style
'''
self.assertEqual(dfxp2srt(dfxp_data_with_style), srt_data)
dfxp_data_non_utf8 = '''
this should be returned
this should also be returned
this should also be returned
closing tag above should not trick, so this should also be returned
but this text should not be returned
'''
# Expected results for the tag-based extraction test, cut out of the stripped
# fixture string by fixed character offsets. The offsets are tied to the exact
# text of GET_ELEMENT_BY_TAG_TEST_STRING and must be updated together with it.
GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[32:276]
# Drops the first 5 and last 6 characters of the HTML slice above
# (presumably the enclosing `<div>` / `</div>` tags -- confirm against the fixture).
GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT = GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML[5:-6]
GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML = GET_ELEMENT_BY_TAG_TEST_STRING.strip()[78:119]
# Drops the first 6 and last 7 characters of the HTML slice above
# (presumably the enclosing `<span>` / `</span>` tags -- confirm against the fixture).
GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT = GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML[6:-7]
def test_get_element_text_and_html_by_tag(self):
    # Each tag must yield the (text, html) pair precomputed in the class constants
    document = self.GET_ELEMENT_BY_TAG_TEST_STRING
    expectations = (
        ('div', self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_TEXT, self.GET_ELEMENT_BY_TAG_RES_OUTERDIV_HTML),
        ('span', self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_TEXT, self.GET_ELEMENT_BY_TAG_RES_INNERSPAN_HTML),
    )
    for tag, text, markup in expectations:
        self.assertEqual(get_element_text_and_html_by_tag(tag, document), (text, markup))

    # Looking up `article` in this fixture is expected to raise a parse error
    with self.assertRaises(compat_HTMLParseError):
        get_element_text_and_html_by_tag('article', document)
def test_iri_to_uri(self):
    # Pairs of (IRI passed in, URI expected back), checked in order
    conversions = (
        ('https://www.google.com/search?q=foo&ie=utf-8&oe=utf-8&client=firefox-b',
         'https://www.google.com/search?q=foo&ie=utf-8&oe=utf-8&client=firefox-b'),  # Same
        ('https://www.google.com/search?q=Käsesoßenrührlöffel',  # German for cheese sauce stirring spoon
         'https://www.google.com/search?q=K%C3%A4seso%C3%9Fenr%C3%BChrl%C3%B6ffel'),
        ('https://www.google.com/search?q=lt<+gt>+eq%3D+amp%26+percent%25+hash%23+colon%3A+tilde~#trash=?&garbage=#',
         'https://www.google.com/search?q=lt%3C+gt%3E+eq%3D+amp%26+percent%25+hash%23+colon%3A+tilde~#trash=?&garbage=#'),
        ('http://правозащита38.рф/category/news/',
         'http://xn--38-6kcaak9aj5chl4a3g.xn--p1ai/category/news/'),
        ('http://www.правозащита38.рф/category/news/',
         'http://www.xn--38-6kcaak9aj5chl4a3g.xn--p1ai/category/news/'),
        ('https://i❤.ws/emojidomain/👍👏🤝💪',
         'https://xn--i-7iq.ws/emojidomain/%F0%9F%91%8D%F0%9F%91%8F%F0%9F%A4%9D%F0%9F%92%AA'),
        ('http://日本語.jp/',
         'http://xn--wgv71a119e.jp/'),
        ('http://导航.中国/',
         'http://xn--fet810g.xn--fiqs8s/'),
    )
    for iri, expected_uri in conversions:
        self.assertEqual(iri_to_uri(iri), expected_uri)
def test_clean_podcast_url(self):
    # Pairs of (URL passed in, URL expected back from `clean_podcast_url`)
    samples = (
        ('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3',
         'https://traffic.megaphone.fm/HSW7835899191.mp3'),
        ('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3',
         'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'),
    )
    for tracked_url, clean_url in samples:
        self.assertEqual(clean_podcast_url(tracked_url), clean_url)
def test_LazyList(self):
    # A LazyList wrapping a plain list is compared against that list throughout
    reference = list(range(10))

    # Full materialisation and single-item access
    self.assertEqual(list(LazyList(reference)), reference)
    self.assertEqual(LazyList(reference).exhaust(), reference)
    self.assertEqual(LazyList(reference)[5], reference[5])

    # Slicing (open-ended, stepped and negative-step) must match list slicing
    windows = (
        slice(5, None),
        slice(None, 5),
        slice(None, None, 2),
        slice(1, None, 2),
        slice(5, None, -1),
        slice(6, 2, -2),
        slice(None, None, -1),
    )
    for window in windows:
        self.assertEqual(LazyList(reference)[window], reference[window])

    # Truthiness, length and string forms
    self.assertTrue(LazyList(reference))
    self.assertFalse(LazyList(range(0)))
    self.assertEqual(len(LazyList(reference)), len(reference))
    self.assertEqual(repr(LazyList(reference)), repr(reference))
    self.assertEqual(str(LazyList(reference)), str(reference))

    # Reversal via the constructor flag and via `reversed()`
    self.assertEqual(list(LazyList(reference, reverse=True)), reference[::-1])
    self.assertEqual(list(reversed(LazyList(reference))[::-1]), reference)
    self.assertEqual(list(reversed(LazyList(reference))[1:3:7]), reference[::-1][1:3:7])
def test_LazyList_laziness(self):
    # After every lookup the private `_cache` is compared with the exact
    # prefix of the source that the lookup is expected to have consumed.

    def check(lazy, index, expected_value, expected_cache):
        self.assertEqual(lazy[index], expected_value)
        self.assertEqual(lazy._cache, list(expected_cache))

    forward = LazyList(range(10))
    check(forward, 0, 0, range(1))
    check(forward, 5, 5, range(6))
    check(forward, -3, 7, range(10))

    backward = LazyList(range(10), reverse=True)
    check(backward, -1, 0, range(1))
    check(backward, 3, 6, range(10))

    # Unbounded source: a positive index caches up to that index only
    endless = LazyList(itertools.count())
    check(endless, 10, 10, range(11))
    # Reversed view of the same unbounded source, indexed from the end
    endless_reversed = reversed(endless)
    check(endless_reversed, -15, 14, range(15))
def test_format_bytes(self):
    # Pairs of (byte count, expected rendering), checked in order
    renderings = (
        (0, '0.00B'),
        (1000, '1000.00B'),
        (1024, '1.00KiB'),
        (1024**2, '1.00MiB'),
        (1024**3, '1.00GiB'),
        (1024**4, '1.00TiB'),
        (1024**5, '1.00PiB'),
        (1024**6, '1.00EiB'),
        (1024**7, '1.00ZiB'),
        (1024**8, '1.00YiB'),
        # One step past the largest unit still renders in that unit
        (1024**9, '1024.00YiB'),
    )
    for size, rendered in renderings:
        self.assertEqual(format_bytes(size), rendered)
def test_hide_login_info(self):
    # Pairs of (argument list passed in, argument list expected back)
    scenarios = (
        (['-u', 'foo', '-p', 'bar'], ['-u', 'PRIVATE', '-p', 'PRIVATE']),
        # Option given without a following value
        (['-u'], ['-u']),
        # Same option repeated
        (['-u', 'foo', '-u', 'bar'], ['-u', 'PRIVATE', '-u', 'PRIVATE']),
        # `--option=value` form
        (['--username=foo'], ['--username=PRIVATE']),
    )
    for argv, masked in scenarios:
        self.assertEqual(Config.hide_login_info(argv), masked)
def test_locked_file(self):
    # Holds a lock on a scratch file in each mode in turn and, while it is held,
    # tries to take a second lock on the same file in every mode. The scratch
    # file is removed afterwards regardless of the outcome.
    TEXT = 'test_locked_file\n'
    FILE = 'test_locked_file.ytdl'
    MODES = 'war'  # Order is important

    try:
        for lock_mode in MODES:
            # NOTE(review): the third positional argument `False` presumably asks for a
            # non-blocking lock -- confirm against the `locked_file` signature.
            with locked_file(FILE, lock_mode, False) as f:
                if lock_mode == 'r':
                    # 'r' comes last in MODES; the two earlier passes each wrote TEXT,
                    # and the assertion expects both writes to be present
                    self.assertEqual(f.read(), TEXT * 2, 'Wrong file content')
                else:
                    f.write(TEXT)
                # The outer lock is still held here
                for test_mode in MODES:
                    testing_write = test_mode != 'r'
                    try:
                        with locked_file(FILE, test_mode, False):
                            pass
                    except (BlockingIOError, PermissionError):
                        # Second lock was refused: tolerated (and only reported) for read mode
                        if not testing_write:  # FIXME
                            print(f'Known issue: Exclusive lock ({lock_mode}) blocks read access ({test_mode})')
                            continue
                        self.assertTrue(testing_write, f'{test_mode} is blocked by {lock_mode}')
                    else:
                        # Second lock was granted: this fails the test unless it was read mode
                        self.assertFalse(testing_write, f'{test_mode} is not blocked by {lock_mode}')
    finally:
        # Best-effort cleanup; any OSError from removing the file is ignored
        with contextlib.suppress(OSError):
            os.remove(FILE)
def test_determine_file_encoding(self):
    # Pairs of (raw bytes, expected 2-tuple returned by `determine_file_encoding`)
    probes = (
        # Nothing to detect
        (b'', (None, 0)),
        (b'--verbose -x --audio-format mkv\n', (None, 0)),

        # Leading byte-order marks
        (b'\xef\xbb\xbf', ('utf-8', 3)),
        (b'\x00\x00\xfe\xff', ('utf-32-be', 4)),
        (b'\xff\xfe', ('utf-16-le', 2)),
        # Byte-order mark followed by a coding line
        (b'\xff\xfe# coding: utf-8\n--verbose', ('utf-16-le', 2)),

        # `# coding:` lines in several spellings
        (b'# coding: utf-8\n--verbose', ('utf-8', 0)),
        (b'# coding: someencodinghere-12345\n--verbose', ('someencodinghere-12345', 0)),
        (b'#coding:utf-8\n--verbose', ('utf-8', 0)),
        (b'# coding: utf-8 \r\n--verbose', ('utf-8', 0)),

        # Coding lines that are themselves stored in a multi-byte encoding
        ('# coding: utf-32-be'.encode('utf-32-be'), ('utf-32-be', 0)),
        ('# coding: utf-16-le'.encode('utf-16-le'), ('utf-16-le', 0)),
    )
    for raw, detected in probes:
        self.assertEqual(determine_file_encoding(raw), detected)
def test_get_compatible_ext(self):
    # Pairs of (keyword arguments for `get_compatible_ext`, expected extension)
    scenarios = (
        # Codecs unknown: only the source extensions are given
        (dict(vcodecs=[None], acodecs=[None, None], vexts=['mp4'], aexts=['m4a', 'm4a']), 'mkv'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['flv'], aexts=['flv']), 'flv'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['m4a']), 'mp4'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['mp4'], aexts=['webm']), 'mkv'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['m4a']), 'mkv'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['webm']), 'webm'),
        (dict(vcodecs=[None], acodecs=[None], vexts=['webm'], aexts=['weba']), 'webm'),

        # Codecs known
        (dict(vcodecs=['h264'], acodecs=['mp4a'], vexts=['mov'], aexts=['m4a']), 'mp4'),
        (dict(vcodecs=['av01.0.12M.08'], acodecs=['opus'], vexts=['mp4'], aexts=['webm']), 'webm'),

        # Caller-supplied `preferences`
        (dict(vcodecs=['vp9'], acodecs=['opus'], vexts=['webm'], aexts=['webm'], preferences=['flv', 'mp4']), 'mp4'),
        (dict(vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv'),
    )
    for call_kwargs, expected_ext in scenarios:
        self.assertEqual(get_compatible_ext(**call_kwargs), expected_ext)
def test_try_call(self):
    # Helper that sums positional and keyword arguments, so that both the
    # `args` and the `kwargs` forwarding of `try_call` can be exercised.
    def total(*x, **kwargs):
        return sum(x) + sum(kwargs.values())

    self.assertEqual(try_call(None), None,
                     msg='not a fn should give None')
    self.assertEqual(try_call(lambda: 1), 1,
                     msg='int fn with no expected_type should give int')
    self.assertEqual(try_call(lambda: 1, expected_type=int), 1,
                     msg='int fn with expected_type int should give int')
    self.assertEqual(try_call(lambda: 1, expected_type=dict), None,
                     msg='int fn with wrong expected_type should give None')
    self.assertEqual(try_call(total, args=(0, 1, 0, ), expected_type=int), 1,
                     msg='fn should accept arglist')
    self.assertEqual(try_call(total, kwargs={'a': 0, 'b': 1, 'c': 0}, expected_type=int), 1,
                     msg='fn should accept kwargs')
    # The first callable returns a dict (wrong type), so the result of the
    # second callable is the one expected back
    self.assertEqual(try_call(lambda x: {}, total, args=(42, ), expected_type=int), 42,
                     msg='expect first int result with expected_type int')
def test_variadic(self):
    # With default arguments, `None` and a plain string each come back wrapped in a 1-tuple
    for lone_value in (None, 'spam'):
        self.assertEqual(variadic(lone_value), (lone_value, ))

    # With `allowed_types=dict` the string is expected back unchanged
    self.assertEqual(variadic('spam', allowed_types=dict), 'spam')

    # Same expectation when the types are passed as a list; any warning
    # emitted by that call is silenced for the duration of the check
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
def test_traverse_obj(self):
    # Exercises `traverse_obj` feature by feature: plain paths, `...` branching,
    # function / set / slice / dict keys, alternative paths, `default`,
    # `expected_type`, `get_all`, `casesense`, `traverse_string`,
    # `is_user_input` and `re.Match` objects as input.
    _TEST_DATA = {
        100: 100,
        1.2: 1.2,
        'str': 'str',
        'None': None,
        '...': ...,
        'urls': [
            {'index': 0, 'url': 'https://www.example.com/0'},
            {'index': 1, 'url': 'https://www.example.com/1'},
        ],
        'data': (
            {'index': 2},
            {'index': 3},
        ),
        'dict': {},
    }

    # Test base functionality
    self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
                     msg='allow tuple path')
    self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str',
                     msg='allow list path')
    self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str',
                     msg='allow iterable path')
    self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str',
                     msg='single items should be treated as a path')
    self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA)
    self.assertEqual(traverse_obj(_TEST_DATA, 100), 100)
    self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2)

    # Test Ellipsis behavior
    self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
                          (item for item in _TEST_DATA.values() if item not in (None, {})),
                          msg='`...` should give all non discarded values')
    self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
                          msg='`...` selection for dicts should select all values')
    self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
                     ['https://www.example.com/0', 'https://www.example.com/1'],
                     msg='nested `...` queries should work')
    self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
                          msg='`...` query result should be flattened')
    self.assertEqual(traverse_obj(iter(range(4)), ...), list(range(4)),
                     msg='`...` should accept iterables')

    # Test function as key
    self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
                     [_TEST_DATA['urls']],
                     msg='function as query key should perform a filter based on (key, value)')
    self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
                          msg='exceptions in the query function should be catched')
    self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
                     msg='function key should accept iterables')
    if __debug__:
        with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
            traverse_obj(_TEST_DATA, lambda a: ...)
        with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
            traverse_obj(_TEST_DATA, lambda a, b, c: ...)

    # Test set as key (transformation/type, like `expected_type`)
    self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'],
                     msg='Function in set should be a transformation')
    self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'],
                     msg='Type in set should be a type filter')
    self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA,
                     msg='A single set should be wrapped into a path')
    self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'],
                     msg='Transformation function should not raise')
    self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
                     [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
                     msg='Function in set should be a transformation')
    if __debug__:
        with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
            traverse_obj(_TEST_DATA, set())
        with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
            traverse_obj(_TEST_DATA, {str.upper, str})

    # Test `slice` as a key
    _SLICE_DATA = [0, 1, 2, 3, 4]
    self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
                     msg='slice on a dictionary should not throw')
    self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
                     msg='slice key should apply slice to sequence')
    self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
                     msg='slice key should apply slice to sequence')
    self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
                     msg='slice key should apply slice to sequence')

    # Test alternative paths
    self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
                     msg='multiple `paths` should be treated as alternative paths')
    self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
                     msg='alternatives should exit early')
    self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
                     msg='alternatives should return `default` if exhausted')
    self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100,
                     msg='alternatives should track their own branching return')
    self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']),
                     msg='alternatives on empty objects should search further')

    # Test branch and path nesting
    self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
                     msg='tuple as key should be treated as branches')
    self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'],
                     msg='list as key should be treated as branches')
    self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'],
                     msg='double nesting in path should be treated as paths')
    self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1],
                     msg='do not fail early on branching')
    self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))),
                          ['https://www.example.com/0', 'https://www.example.com/1'],
                          msg='tripple nesting in path should be treated as branches')
    self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))),
                     ['https://www.example.com/0', 'https://www.example.com/1'],
                     msg='ellipsis as branch path start gets flattened')

    # Test dictionary as key
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2},
                     msg='dict key should result in a dict with the same keys')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}),
                     {0: 'https://www.example.com/0'},
                     msg='dict key should allow paths')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}),
                     {0: ['https://www.example.com/0']},
                     msg='tuple in dict path should be treated as branches')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}),
                     {0: ['https://www.example.com/0']},
                     msg='double nesting in dict path should be treated as paths')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
                     {0: ['https://www.example.com/1', 'https://www.example.com/0']},
                     msg='tripple nesting in dict path should be treated as branches')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
                     msg='remove `None` values when top level dict key fails')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
                     msg='use `default` if key fails and `default`')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
                     msg='remove empty values when dict key')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...},
                     msg='use `default` when dict key and `default`')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
                     msg='remove empty values when nested dict key fails')
    self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
                     msg='default to dict if pruned')
    self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {0: ...},
                     msg='default to dict if pruned and default is given')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
                     msg='use nested `default` when nested dict key fails and `default`')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {},
                     msg='remove key if branch in dict key not successful')

    # Testing default parameter behavior
    _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
    self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None,
                     msg='default value should be `None`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ...,
                     msg='chained fails should result in default')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0,
                     msg='should not short cirquit on `None`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1,
                     msg='invalid dict key should result in `default`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1,
                     msg='`None` is a deliberate sentinel and should become `default`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
                     msg='`IndexError` should result in `default`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
                     msg='if branched but not successful return `default` if defined, not `[]`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None,
                     msg='if branched but not successful return `default` even if `default` is `None`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [],
                     msg='if branched but not successful return `[]`, not `default`')
    self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
                     msg='if branched but object is empty return `[]`, not `default`')
    self.assertEqual(traverse_obj(None, ...), [],
                     msg='if branched but object is `None` return `[]`, not `default`')
    self.assertEqual(traverse_obj({0: None}, (0, ...)), [],
                     msg='if branched but state is `None` return `[]`, not `default`')

    # Paths that branch either before or after a (possibly long) run of failing keys
    branching_paths = [
        ('fail', ...),
        (..., 'fail'),
        100 * ('fail',) + (...,),
        (...,) + 100 * ('fail',),
    ]
    for branching_path in branching_paths:
        self.assertEqual(traverse_obj({}, branching_path), [],
                         msg='if branched but state is `None`, return `[]` (not `default`)')
        self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
                         msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
        self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
                         msg='if branching in last alternative and previous did match, return single value')
        self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
                         msg='if branching in first alternative and non-branching path does match, return single value')
        self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
                         msg='if branching in first alternative and non-branching path does not match, return `default`')

    # Testing expected_type behavior
    _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
    self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
                     'str', msg='accept matching `expected_type` type')
    self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
                     None, msg='reject non matching `expected_type` type')
    self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
                     '0', msg='transform type using type function')
    self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
                     None, msg='wrap expected_type fuction in try_call')
    self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str),
                     ['str'], msg='eliminate items that expected_type fails on')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
                     {0: 100}, msg='type as expected_type should filter dict values')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
                     {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
    self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int),
                     1, msg='expected_type should not filter non final dict values')
    self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
                     {0: {0: 100}}, msg='expected_type should transform deep dict values')
    self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)),
                     [{0: ...}, {0: ...}], msg='expected_type should transform branched dict values')
    self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
                     [4], msg='expected_type regression for type matching in tuple branching')
    self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int),
                     [], msg='expected_type regression for type matching in dict result')

    # Test get_all behavior
    _GET_ALL_DATA = {'key': [0, 1, 2]}
    self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0,
                     msg='if not `get_all`, return only first matching value')
    self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2],
                     msg='do not overflatten if not `get_all`')

    # Test casesense behavior
    _CASESENSE_DATA = {
        'KeY': 'value0',
        0: {
            'KeY': 'value1',
            0: {'KeY': 'value2'},
        },
    }
    self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None,
                     msg='dict keys should be case sensitive unless `casesense`')
    self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY',
                                  casesense=False), 'value0',
                     msg='allow non matching key case if `casesense`')
    self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)),
                                  casesense=False), ['value1'],
                     msg='allow non matching key case in branch if `casesense`')
    self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)),
                                  casesense=False), ['value2'],
                     msg='allow non matching key case in branch path if `casesense`')

    # Test traverse_string behavior
    _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None,
                     msg='do not traverse into string if not `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0),
                                  traverse_string=True), 's',
                     msg='traverse into string if `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1),
                                  traverse_string=True), '.',
                     msg='traverse into converted data if `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
                                  traverse_string=True), 'str',
                     msg='`...` should result in string (same value) if `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
                                  traverse_string=True), 'sr',
                     msg='`slice` should result in string if `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
                                  traverse_string=True), 'str',
                     msg='function should result in string if `traverse_string`')
    self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
                                  traverse_string=True), ['s', 'r'],
                     msg='branching should result in list if `traverse_string`')
    self.assertEqual(traverse_obj({}, (0, ...), traverse_string=True), [],
                     msg='branching should result in list if `traverse_string`')
    self.assertEqual(traverse_obj({}, (0, lambda x, y: True), traverse_string=True), [],
                     msg='branching should result in list if `traverse_string`')
    self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
                     msg='branching should result in list if `traverse_string`')

    # Test is_user_input behavior
    _IS_USER_INPUT_DATA = {'range8': list(range(8))}
    self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
                                  is_user_input=True), 3,
                     msg='allow for string indexing if `is_user_input`')
    self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
                                       is_user_input=True), tuple(range(8))[3:],
                          msg='allow for string slice if `is_user_input`')
    self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
                                       is_user_input=True), tuple(range(8))[:4:2],
                          msg='allow step in string slice if `is_user_input`')
    self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
                                       is_user_input=True), range(8),
                          msg='`:` should be treated as `...` if `is_user_input`')
    with self.assertRaises(TypeError, msg='too many params should result in error'):
        traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)

    # Test re.Match as input obj
    # Group 1 is unnamed, group 2 is named `group`, group 3 is optional and
    # does not participate in this match (so its value is `None`).
    mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
    self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
                     msg='`...` on a `re.Match` should give its `groups()`')
    self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'],
                     msg='function on a `re.Match` should give groupno, value starting at 0')
    self.assertEqual(traverse_obj(mobj, 'group'), '3',
                     msg='str key on a `re.Match` should give group with that name')
    self.assertEqual(traverse_obj(mobj, 2), '3',
                     msg='int key on a `re.Match` should give group with that name')
    self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3',
                     msg='str key on a `re.Match` should respect casesense')
    self.assertEqual(traverse_obj(mobj, 'fail'), None,
                     msg='failing str key on a `re.Match` should return `default`')
    self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None,
                     msg='failing str key on a `re.Match` should return `default`')
    self.assertEqual(traverse_obj(mobj, 8), None,
                     msg='failing int key on a `re.Match` should return `default`')
    self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
                     msg='function on a `re.Match` should give group name as well')
# Run the tests through unittest's command-line runner when this file is executed directly
if __name__ == '__main__':
    unittest.main()