From 3699eeb67cad333272b14a42dd3843d93fda1a2e Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Sat, 30 Mar 2024 19:54:43 +0100 Subject: [PATCH] [utils] `traverse_obj`: Allow unbranching using `all` and `any` (#9571) Authored by: Grub4K --- test/test_traversal.py | 32 ++++++++++++++++++++++++++++++++ yt_dlp/utils/traversal.py | 9 +++++++++ 2 files changed, 41 insertions(+) diff --git a/test/test_traversal.py b/test/test_traversal.py index 3b247d0597..0b2f3fb5da 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -377,3 +377,35 @@ class TestTraversal: 'special transformations should act on current element' assert traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})) == [1, 2008, 141100], \ 'special transformations should act on current element' + + def test_traversal_unbranching(self): + assert traverse_obj(_TEST_DATA, [(100, 1.2), all]) == [100, 1.2], \ + '`all` should give all results as list' + assert traverse_obj(_TEST_DATA, [(100, 1.2), any]) == 100, \ + '`any` should give the first result' + assert traverse_obj(_TEST_DATA, [100, all]) == [100], \ + '`all` should give list if non branching' + assert traverse_obj(_TEST_DATA, [100, any]) == 100, \ + '`any` should give single item if non branching' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]) == [100], \ + '`all` should filter `None` and empty dict' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]) == 100, \ + '`any` should filter `None` and empty dict' + assert traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }]) == {'all': [100, 1.2], 'any': 100}, \ + '`all`/`any` should apply to each dict path separately' + assert traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }], get_all=False) == {'all': [100, 1.2], 'any': 100}, \ + '`all`/`any` should apply to dict regardless of `get_all`' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, {float}]) is None, \ + '`all` should reset branching status' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, {float}]) is None, \ + '`any` should reset branching status' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, ..., {float}]) == [1.2], \ + '`all` should allow further branching' + assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \ + '`any` should allow further branching' diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 8938f4c782..926a3d0a13 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -228,6 +228,15 @@ def traverse_obj( if not casesense and isinstance(key, str): key = key.casefold() + if key in (any, all): + has_branched = False + filtered_objs = (obj for obj in objs if obj not in (None, {})) + if key is any: + objs = (next(filtered_objs, None),) + else: + objs = (list(filtered_objs),) + continue + if __debug__ and callable(key): # Verify function signature inspect.signature(key).bind(None, None)