diff --git a/test/test_utils.py b/test/test_utils.py
index c3e387cd0d..09c648cf89 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -2340,6 +2340,58 @@ Line 1
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
+ # Test xml.etree.ElementTree.Element as input obj
+ etree = xml.etree.ElementTree.fromstring('''
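+ <data>
+     <country name="Liechtenstein">
+         <rank>1</rank>
+         <year>2008</year>
+         <gdppc>141100</gdppc>
+         <neighbor name="Austria" direction="E"/>
+         <neighbor name="Switzerland" direction="W"/>
+     </country>
+     <country name="Singapore">
+         <rank>4</rank>
+         <year>2011</year>
+         <gdppc>59900</gdppc>
+         <neighbor name="Malaysia" direction="N"/>
+     </country>
+     <country name="Panama">
+         <rank>68</rank>
+         <year>2011</year>
+         <gdppc>13600</gdppc>
+         <neighbor name="Costa Rica" direction="W"/>
+         <neighbor name="Colombia" direction="E"/>
+     </country>
+ </data>''')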
+ self.assertEqual(traverse_obj(etree, ''), etree,
+ msg='empty str key should return the element itself')
+ self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+ msg='str key should lead all children with that tag name')
+ self.assertEqual(traverse_obj(etree, ...), list(etree),
+ msg='`...` as key should return all children')
+ self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+ msg='function as key should get element as value')
+ self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+ msg='function as key should get index as key')
+ self.assertEqual(traverse_obj(etree, 0), etree[0],
+ msg='int key should return the nth child')
+ self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+ ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+ msg='`@` at end of path should give that attribute')
+ self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+ msg='`@` at end of path should give `None`')
+ self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+ msg='`@` should give the full attribute dict')
+ self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+ msg='`text()` at end of path should give the inner text')
+ self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+ msg='full python xpath features should be supported')
+ self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+ msg='special transformations should act on current element')
+ self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
+ msg='special transformations should act on current element')
+
def test_http_header_dict(self):
headers = HTTPHeaderDict()
headers['ytdl-test'] = b'0'
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
index 5a2f69fccd..8938f4c782 100644
--- a/yt_dlp/utils/traversal.py
+++ b/yt_dlp/utils/traversal.py
@@ -3,6 +3,7 @@ import contextlib
import inspect
import itertools
import re
+import xml.etree.ElementTree
from ._utils import (
IDENTITY,
@@ -118,7 +119,7 @@ def traverse_obj(
branching = True
if isinstance(obj, collections.abc.Mapping):
result = obj.values()
- elif is_iterable_like(obj):
+ elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
result = obj
elif isinstance(obj, re.Match):
result = obj.groups()
@@ -132,7 +133,7 @@ def traverse_obj(
branching = True
if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
- elif is_iterable_like(obj):
+ elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
iter_obj = enumerate(obj)
elif isinstance(obj, re.Match):
iter_obj = itertools.chain(
@@ -168,7 +169,7 @@ def traverse_obj(
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
elif isinstance(key, (int, slice)):
- if is_iterable_like(obj, collections.abc.Sequence):
+ if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
branching = isinstance(key, slice)
with contextlib.suppress(IndexError):
result = obj[key]
@@ -176,6 +177,34 @@ def traverse_obj(
with contextlib.suppress(IndexError):
result = str(obj)[key]
+ elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
+ xpath, _, special = key.rpartition('/')
+ if not special.startswith('@') and special != 'text()':
+ xpath = key
+ special = None
+
+ # Allow abbreviated relative paths; ElementTree errors on absolute paths, so anchor a leading `/` to the current element
+ if xpath.startswith('/'):
+ xpath = f'.{xpath}'
+ elif xpath and not xpath.startswith('./'):
+ xpath = f'./{xpath}'
+
+ def apply_specials(element):
+     """Apply the trailing special selector parsed from the key to `element`.
+
+     Returns `element` unchanged when no special selector was given, its
+     attribute mapping for a bare `@`, the value of a single attribute for
+     `@name` (`None` when the attribute is missing), and its text for `text()`.
+     """
+     if special is None:
+         return element
+     if special == '@':
+         return element.attrib
+     if special.startswith('@'):
+         return try_call(element.attrib.get, args=(special[1:],))
+     if special == 'text()':
+         return element.text
+     # Raise explicitly instead of `assert False`: asserts are stripped under
+     # `python -O`, which would turn an unhandled selector into a silent `None`
+     raise AssertionError(f'apply_specials is missing case for {special!r}')
+
+ if xpath:
+ result = list(map(apply_specials, obj.iterfind(xpath)))
+ else:
+ result = apply_specials(obj)
+
return branching, result if branching else (result,)
def lazy_last(iterable):