From 32abfb00bdbd119ca675fdc6d1719331f0a2741a Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Mon, 1 Apr 2024 02:12:03 +0200 Subject: [PATCH] [utils] `traverse_obj`: Convenience improvements (#9577) Add support for: - `http.cookies.Morsel` - Multi type filters (`{type, type}`) Authored by: Grub4K --- test/test_traversal.py | 33 ++++++++++++++++++++++++++++++++- yt_dlp/utils/traversal.py | 28 +++++++++++++++++++--------- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/test/test_traversal.py b/test/test_traversal.py index 0b2f3fb5da..ed29d03ad5 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -1,3 +1,4 @@ +import http.cookies import re import xml.etree.ElementTree @@ -94,6 +95,8 @@ class TestTraversal: 'Function in set should be a transformation' assert traverse_obj(_TEST_DATA, (..., {str})) == ['str'], \ 'Type in set should be a type filter' + assert traverse_obj(_TEST_DATA, (..., {str, int})) == [100, 'str'], \ + 'Multiple types in set should be a type filter' assert traverse_obj(_TEST_DATA, {dict}) == _TEST_DATA, \ 'A single set should be wrapped into a path' assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \ @@ -103,7 +106,7 @@ class TestTraversal: 'Function in set should be a transformation' assert traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})) == 'const', \ 'Function in set should always be called' - # Sets with length != 1 should raise in debug + # Sets with length < 1 or > 1 not including only types should raise with pytest.raises(Exception): traverse_obj(_TEST_DATA, set()) with pytest.raises(Exception): @@ -409,3 +412,31 @@ class TestTraversal: '`all` should allow further branching' assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \ '`any` should allow further branching' + + def test_traversal_morsel(self): + values = { + 'expires': 'a', + 'path': 'b', + 'comment': 'c', + 'domain': 'd', + 'max-age': 'e', + 'secure': 'f', + 'httponly': 'g', + 'version': 'h', + 'samesite': 'i', + } + morsel = http.cookies.Morsel() + morsel.set('item_key', 'item_value', 'coded_value') + morsel.update(values) + values['key'] = 'item_key' + values['value'] = 'item_value' + + for key, value in values.items(): + assert traverse_obj(morsel, key) == value, \ + 'Morsel should provide access to all values' + assert traverse_obj(morsel, ...) == list(values.values()), \ + '`...` should yield all values' + assert traverse_obj(morsel, lambda k, v: True) == list(values.values()), \ + 'function key should yield all values' + assert traverse_obj(morsel, [(None,), any]) == morsel, \ + 'Morsel should not be implicitly changed to dict on usage' diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 926a3d0a13..96eb2eddf5 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -1,5 +1,6 @@ import collections.abc import contextlib +import http.cookies import inspect import itertools import re @@ -28,7 +29,8 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, + `xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -36,8 +38,8 @@ def traverse_obj( The keys in the path can be one of: - `None`: Return the current object. - `set`: Requires the only item in the set to be a type or function, - like `{type}`/`{func}`. If a `type`, returns only values - of this type. If a function, returns `func(obj)`. + like `{type}`/`{type, type, ...}/`{func}`. If a `type`, return only + values of this type. If a function, returns `func(obj)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. @@ -48,8 +50,10 @@ def traverse_obj( For `Iterable`s, `key` is the index of the value. For `re.Match`es, `key` is the group number (0 = full match) as well as additionally any group names, if given. - - `dict` Transform the current object and return a matching dict. + - `dict`: Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + - `any`-builtin: Take the first matching object and return it, resetting branching. + - `all`-builtin: Take all matching objects and return them as a list, resetting branching. `tuple`, `list`, and `dict` all support nested paths and branches. @@ -102,10 +106,10 @@ def traverse_obj( result = obj elif isinstance(key, set): - assert len(key) == 1, 'Set should only be used to wrap a single item' item = next(iter(key)) - if isinstance(item, type): - if isinstance(obj, item): + if len(key) > 1 or isinstance(item, type): + assert all(isinstance(item, type) for item in key) + if isinstance(obj, tuple(key)): result = obj else: result = try_call(item, args=(obj,)) @@ -117,6 +121,8 @@ def traverse_obj( elif key is ...: branching = True + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, collections.abc.Mapping): result = obj.values() elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): @@ -131,6 +137,8 @@ def traverse_obj( elif callable(key): branching = True + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): @@ -157,6 +165,8 @@ def traverse_obj( } or None elif isinstance(obj, collections.abc.Mapping): + if isinstance(obj, http.cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else next((v for k, v in obj.items() if casefold(k) == key), None)) @@ -179,7 +189,7 @@ def traverse_obj( elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str): xpath, _, special = key.rpartition('/') - if not special.startswith('@') and special != 'text()': + if not special.startswith('@') and not special.endswith('()'): xpath = key special = None @@ -198,7 +208,7 @@ def traverse_obj( return try_call(element.attrib.get, args=(special[1:],)) if special == 'text()': return element.text - assert False, f'apply_specials is missing case for {special!r}' + raise SyntaxError(f'apply_specials is missing case for {special!r}') if xpath: result = list(map(apply_specials, obj.iterfind(xpath)))