From f99bbfc9838d98d81027dddb18ace0af66acdf6d Mon Sep 17 00:00:00 2001 From: Simon Sawicki <37424085+Grub4K@users.noreply.github.com> Date: Sun, 9 Oct 2022 03:27:32 +0200 Subject: [PATCH] [utils] `traverse_obj`: Always return list when branching (#5170) Fixes #5162 Authored by: Grub4K --- test/test_utils.py | 27 +++++++++++++++++++++++---- yt_dlp/utils.py | 22 ++++++++++++++-------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 69313564a1..6f3f6cb914 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1890,6 +1890,7 @@ Line 1 {'index': 2}, {'index': 3}, ), + 'dict': {}, } # Test base functionality @@ -1926,11 +1927,15 @@ Line 1 # Test alternative paths self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', - msg='multiple `path_list` should be treated as alternative paths') + msg='multiple `paths` should be treated as alternative paths') self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', msg='alternatives should exit early') self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, msg='alternatives should return `default` if exhausted') + self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100, + msg='alternatives should track their own branching return') + self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']), + msg='alternatives on empty objects should search further') # Test branch and path nesting self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], @@ -1963,8 +1968,16 @@ Line 1 self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), {0: ['https://www.example.com/1', 'https://www.example.com/0']}, msg='tripple nesting in dict path should be treated as branches') - self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...}, - msg='do not remove `None` values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, + msg='remove `None` values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...}, + msg='do not remove `None` values if `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, + msg='do not remove empty values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}}, + msg='do not remove empty values when dict key and a default') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []}, + msg='if branch in dict key not successful, return `[]`') # Testing default parameter behavior _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} @@ -1981,7 +1994,13 @@ Line 1 self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, msg='`IndexError` should result in `default`') self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1, - msg='if branched but not successfull return `default`, not `[]`') + msg='if branched but not successful return `default` if defined, not `[]`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None, + msg='if branched but not successful return `default` even if `default` is `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [], + msg='if branched but not successful return `[]`, not `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [], + msg='if branched but object is empty return `[]`, not `default`') # Testing expected_type behavior _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index d0be7f19ef..7d8e971626 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5294,7 +5294,7 @@ def load_plugins(name, suffix, namespace): def traverse_obj( - obj, *paths, default=None, expected_type=None, get_all=True, + obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True, casesense=True, is_user_input=False, traverse_string=False): """ Safely traverse nested `dict`s and `Sequence`s @@ -5304,6 +5304,7 @@ def traverse_obj( "value" Each of the provided `paths` is tested and the first producing a valid result will be returned. + The next path will also be tested if the path branched but no results could be found. A value of None is treated as the absence of a value. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -5342,6 +5343,7 @@ def traverse_obj( @returns The result of the object traversal. If successful, `get_all=True`, and the path branches at least once, then a list of results is returned instead. + A list is always returned if the last path branches and no `default` is given. """ is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes)) casefold = lambda k: k.casefold() if isinstance(k, str) else k @@ -5385,7 +5387,7 @@ def traverse_obj( elif isinstance(key, dict): iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) yield {k: v if v is not None else default for k, v in iter_obj - if v is not None or default is not None} + if v is not None or default is not NO_DEFAULT} elif isinstance(obj, dict): yield (obj.get(key) if casesense or (key in obj) @@ -5426,18 +5428,22 @@ def traverse_obj( return has_branched, objs - def _traverse_obj(obj, path): + def _traverse_obj(obj, path, use_list=True): has_branched, results = apply_path(obj, path) results = LazyList(x for x in map(type_test, results) if x is not None) - if results: - return results.exhaust() if get_all and has_branched else results[0] - for path in paths: - result = _traverse_obj(obj, path) + if get_all and has_branched: + return results.exhaust() if results or use_list else None + + return results[0] if results else None + + for index, path in enumerate(paths, 1): + use_list = default is NO_DEFAULT and index == len(paths) + result = _traverse_obj(obj, path, use_list) if result is not None: return result - return default + return None if default is NO_DEFAULT else default def traverse_dict(dictn, keys, casesense=True):