[utils] traverse_obj: Allow re.Match objects (#5174)

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2022-10-09 03:31:37 +02:00 committed by GitHub
parent f99bbfc983
commit 7b0127e1e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 3 deletions

View file

@ -2,6 +2,7 @@
# Allow direct execution # Allow direct execution
import os import os
import re
import sys import sys
import unittest import unittest
@ -2080,6 +2081,25 @@ Line 1
with self.assertRaises(TypeError, msg='too many params should result in error'): with self.assertRaises(TypeError, msg='too many params should result in error'):
traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
# Test re.Match as input obj
mobj = re.fullmatch(r'0(12)(?P<group>3)(4)?', '0123')
self.assertEqual(traverse_obj(mobj, ...), [x for x in mobj.groups() if x is not None],
msg='`...` on a `re.Match` should give its `groups()`')
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'],
msg='function on a `re.Match` should give groupno, value starting at 0')
self.assertEqual(traverse_obj(mobj, 'group'), '3',
msg='str key on a `re.Match` should give group with that name')
self.assertEqual(traverse_obj(mobj, 2), '3',
msg='int key on a `re.Match` should give group with that name')
self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3',
msg='str key on a `re.Match` should respect casesense')
self.assertEqual(traverse_obj(mobj, 'fail'), None,
msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None,
msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 8), None,
msg='failing int key on a `re.Match` should return `default`')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

@ -5305,13 +5305,14 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned. Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found. The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
A value of None is treated as the absence of a value. A value of None is treated as the absence of a value.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
The keys in the path can be one of: The keys in the path can be one of:
- `None`: Return the current object. - `None`: Return the current object.
- `str`/`int`: Return `obj[key]`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`. - `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values. - `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values. - `tuple`/`list`: Branch out and return a list of all matching values.
@ -5322,7 +5323,7 @@ def traverse_obj(
- `dict` Transform the current object and return a matching dict. - `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
`tuple`, `list`, and `dict` all support nested paths and branches `tuple`, `list`, and `dict` all support nested paths and branches.
@params paths Paths which to traverse by. @params paths Paths which to traverse by.
@param default Value to return if the paths do not match. @param default Value to return if the paths do not match.
@ -5370,6 +5371,8 @@ def traverse_obj(
yield from obj.values() yield from obj.values()
elif is_sequence(obj): elif is_sequence(obj):
yield from obj yield from obj
elif isinstance(obj, re.Match):
yield from obj.groups()
elif traverse_string: elif traverse_string:
yield from str(obj) yield from str(obj)
@ -5378,6 +5381,8 @@ def traverse_obj(
iter_obj = enumerate(obj) iter_obj = enumerate(obj)
elif isinstance(obj, collections.abc.Mapping): elif isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif isinstance(obj, re.Match):
iter_obj = enumerate((obj.group(), *obj.groups()))
elif traverse_string: elif traverse_string:
iter_obj = enumerate(str(obj)) iter_obj = enumerate(str(obj))
else: else:
@ -5389,10 +5394,21 @@ def traverse_obj(
yield {k: v if v is not None else default for k, v in iter_obj yield {k: v if v is not None else default for k, v in iter_obj
if v is not None or default is not NO_DEFAULT} if v is not None or default is not NO_DEFAULT}
elif isinstance(obj, dict): elif isinstance(obj, collections.abc.Mapping):
yield (obj.get(key) if casesense or (key in obj) yield (obj.get(key) if casesense or (key in obj)
else next((v for k, v in obj.items() if casefold(k) == key), None)) else next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match):
if isinstance(key, int) or casesense:
with contextlib.suppress(IndexError):
yield obj.group(key)
return
if not isinstance(key, str):
return
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
else: else:
if is_user_input: if is_user_input:
key = (int_or_none(key) if ':' not in key key = (int_or_none(key) if ':' not in key