mirror of
https://github.com/yt-dlp/yt-dlp
synced 2025-01-18 10:26:48 +01:00
[utils] `traverse_obj`: More fixes (#6959)
- Fix result when branching with `traverse_string`
- Fix `slice` path on `dict`s
- Fix tests and docstrings from 21b5ec86c2
- Add `is_iterable_like` helper function
Authored by: Grub4K
This commit is contained in:
parent
4d9280c9c8
commit
b079c26f0a
2 changed files with 37 additions and 12 deletions
|
@ -2016,7 +2016,7 @@ Line 1
|
||||||
msg='nested `...` queries should work')
|
msg='nested `...` queries should work')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
|
||||||
msg='`...` query result should be flattened')
|
msg='`...` query result should be flattened')
|
||||||
self.assertEqual(traverse_obj(range(4), ...), list(range(4)),
|
self.assertEqual(traverse_obj(iter(range(4)), ...), list(range(4)),
|
||||||
msg='`...` should accept iterables')
|
msg='`...` should accept iterables')
|
||||||
|
|
||||||
# Test function as key
|
# Test function as key
|
||||||
|
@ -2025,7 +2025,7 @@ Line 1
|
||||||
msg='function as query key should perform a filter based on (key, value)')
|
msg='function as query key should perform a filter based on (key, value)')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
|
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
|
||||||
msg='exceptions in the query function should be catched')
|
msg='exceptions in the query function should be catched')
|
||||||
self.assertEqual(traverse_obj(range(4), lambda _, x: x % 2 == 0), [0, 2],
|
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
|
||||||
msg='function key should accept iterables')
|
msg='function key should accept iterables')
|
||||||
if __debug__:
|
if __debug__:
|
||||||
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
|
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
|
||||||
|
@ -2051,6 +2051,17 @@ Line 1
|
||||||
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
|
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
|
||||||
traverse_obj(_TEST_DATA, {str.upper, str})
|
traverse_obj(_TEST_DATA, {str.upper, str})
|
||||||
|
|
||||||
|
# Test `slice` as a key
|
||||||
|
_SLICE_DATA = [0, 1, 2, 3, 4]
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
|
||||||
|
msg='slice on a dictionary should not throw')
|
||||||
|
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
|
||||||
|
msg='slice key should apply slice to sequence')
|
||||||
|
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
|
||||||
|
msg='slice key should apply slice to sequence')
|
||||||
|
self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
|
||||||
|
msg='slice key should apply slice to sequence')
|
||||||
|
|
||||||
# Test alternative paths
|
# Test alternative paths
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
|
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
|
||||||
msg='multiple `paths` should be treated as alternative paths')
|
msg='multiple `paths` should be treated as alternative paths')
|
||||||
|
@ -2234,6 +2245,12 @@ Line 1
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
||||||
traverse_string=True), ['s', 'r'],
|
traverse_string=True), ['s', 'r'],
|
||||||
msg='branching should result in list if `traverse_string`')
|
msg='branching should result in list if `traverse_string`')
|
||||||
|
self.assertEqual(traverse_obj({}, (0, ...), traverse_string=True), [],
|
||||||
|
msg='branching should result in list if `traverse_string`')
|
||||||
|
self.assertEqual(traverse_obj({}, (0, lambda x, y: True), traverse_string=True), [],
|
||||||
|
msg='branching should result in list if `traverse_string`')
|
||||||
|
self.assertEqual(traverse_obj({}, (0, slice(1)), traverse_string=True), [],
|
||||||
|
msg='branching should result in list if `traverse_string`')
|
||||||
|
|
||||||
# Test is_user_input behavior
|
# Test is_user_input behavior
|
||||||
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
|
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
|
||||||
|
|
|
@ -3273,8 +3273,14 @@ def multipart_encode(data, boundary=None):
|
||||||
return out, content_type
|
return out, content_type
|
||||||
|
|
||||||
|
|
||||||
def variadic(x, allowed_types=(str, bytes, dict)):
|
def is_iterable_like(x, allowed_types=collections.abc.Iterable, blocked_types=NO_DEFAULT):
|
||||||
return x if isinstance(x, collections.abc.Iterable) and not isinstance(x, allowed_types) else (x,)
|
if blocked_types is NO_DEFAULT:
|
||||||
|
blocked_types = (str, bytes, collections.abc.Mapping)
|
||||||
|
return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
|
||||||
|
|
||||||
|
|
||||||
|
def variadic(x, allowed_types=NO_DEFAULT):
|
||||||
|
return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
|
||||||
|
|
||||||
|
|
||||||
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
|
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
|
||||||
|
@ -5467,7 +5473,7 @@ def traverse_obj(
|
||||||
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
|
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
|
||||||
casesense=True, is_user_input=False, traverse_string=False):
|
casesense=True, is_user_input=False, traverse_string=False):
|
||||||
"""
|
"""
|
||||||
Safely traverse nested `dict`s and `Sequence`s
|
Safely traverse nested `dict`s and `Iterable`s
|
||||||
|
|
||||||
>>> obj = [{}, {"key": "value"}]
|
>>> obj = [{}, {"key": "value"}]
|
||||||
>>> traverse_obj(obj, (1, "key"))
|
>>> traverse_obj(obj, (1, "key"))
|
||||||
|
@ -5475,7 +5481,7 @@ def traverse_obj(
|
||||||
|
|
||||||
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
||||||
The next path will also be tested if the path branched but no results could be found.
|
The next path will also be tested if the path branched but no results could be found.
|
||||||
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
|
Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
|
||||||
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
|
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
|
||||||
|
|
||||||
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
||||||
|
@ -5492,7 +5498,7 @@ def traverse_obj(
|
||||||
Read as: `[traverse_obj(obj, branch) for branch in branches]`.
|
Read as: `[traverse_obj(obj, branch) for branch in branches]`.
|
||||||
- `function`: Branch out and return values filtered by the function.
|
- `function`: Branch out and return values filtered by the function.
|
||||||
Read as: `[value for key, value in obj if function(key, value)]`.
|
Read as: `[value for key, value in obj if function(key, value)]`.
|
||||||
For `Sequence`s, `key` is the index of the value.
|
For `Iterable`s, `key` is the index of the value.
|
||||||
For `re.Match`es, `key` is the group number (0 = full match)
|
For `re.Match`es, `key` is the group number (0 = full match)
|
||||||
as well as additionally any group names, if given.
|
as well as additionally any group names, if given.
|
||||||
- `dict` Transform the current object and return a matching dict.
|
- `dict` Transform the current object and return a matching dict.
|
||||||
|
@ -5540,7 +5546,9 @@ def traverse_obj(
|
||||||
result = None
|
result = None
|
||||||
|
|
||||||
if obj is None and traverse_string:
|
if obj is None and traverse_string:
|
||||||
pass
|
if key is ... or callable(key) or isinstance(key, slice):
|
||||||
|
branching = True
|
||||||
|
result = ()
|
||||||
|
|
||||||
elif key is None:
|
elif key is None:
|
||||||
result = obj
|
result = obj
|
||||||
|
@ -5563,7 +5571,7 @@ def traverse_obj(
|
||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
result = obj.values()
|
result = obj.values()
|
||||||
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
|
elif is_iterable_like(obj):
|
||||||
result = obj
|
result = obj
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
result = obj.groups()
|
result = obj.groups()
|
||||||
|
@ -5577,7 +5585,7 @@ def traverse_obj(
|
||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
iter_obj = obj.items()
|
iter_obj = obj.items()
|
||||||
elif isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes)):
|
elif is_iterable_like(obj):
|
||||||
iter_obj = enumerate(obj)
|
iter_obj = enumerate(obj)
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
iter_obj = itertools.chain(
|
iter_obj = itertools.chain(
|
||||||
|
@ -5601,7 +5609,7 @@ def traverse_obj(
|
||||||
} or None
|
} or None
|
||||||
|
|
||||||
elif isinstance(obj, collections.abc.Mapping):
|
elif isinstance(obj, collections.abc.Mapping):
|
||||||
result = (obj.get(key) if casesense or (key in obj) else
|
result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
|
||||||
next((v for k, v in obj.items() if casefold(k) == key), None))
|
next((v for k, v in obj.items() if casefold(k) == key), None))
|
||||||
|
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
|
@ -5613,7 +5621,7 @@ def traverse_obj(
|
||||||
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
||||||
|
|
||||||
elif isinstance(key, (int, slice)):
|
elif isinstance(key, (int, slice)):
|
||||||
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
|
if is_iterable_like(obj, collections.abc.Sequence):
|
||||||
branching = isinstance(key, slice)
|
branching = isinstance(key, slice)
|
||||||
with contextlib.suppress(IndexError):
|
with contextlib.suppress(IndexError):
|
||||||
result = obj[key]
|
result = obj[key]
|
||||||
|
|
Loading…
Reference in a new issue