Move away from contextvars

This commit is contained in:
coletdjnz 2024-11-30 12:57:30 +13:00
parent 2699951172
commit 51f3740030
No known key found for this signature in database
GPG key ID: 91984263BB39894A
12 changed files with 90 additions and 91 deletions

View file

@ -49,7 +49,7 @@ def main():
' _module = None',
*extra_ie_code(DummyInfoExtractor),
'\nclass LazyLoadSearchExtractor(LazyLoadExtractor):\n pass\n',
*build_ies(list(extractors.get().values()), (InfoExtractor, SearchInfoExtractor), DummyInfoExtractor),
*build_ies(list(extractors.value.values()), (InfoExtractor, SearchInfoExtractor), DummyInfoExtractor),
))
write_file(lazy_extractors_filename, f'{module_src}\n')

View file

@ -1399,9 +1399,9 @@ class TestYoutubeDL(unittest.TestCase):
def test_load_plugins_compat(self):
# Should try to reload plugins if they haven't already been loaded
all_plugins_loaded.set(False)
all_plugins_loaded.value = False
FakeYDL().close()
assert all_plugins_loaded.get()
assert all_plugins_loaded.value
if __name__ == '__main__':

View file

@ -5,7 +5,7 @@ import sys
import unittest
from pathlib import Path
import yt_dlp._globals
from yt_dlp.plugins import set_plugin_dirs, add_plugin_dirs, PluginDirs, disable_plugins
from yt_dlp.plugins import set_plugin_dirs, add_plugin_dirs, disable_plugins
from yt_dlp.utils import YoutubeDLError
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -37,12 +37,12 @@ class TestPlugins(unittest.TestCase):
TEST_PLUGIN_DIR = TEST_DATA_DIR / PACKAGE_NAME
def setUp(self):
plugin_ies.set({})
plugin_pps.set({})
plugin_dirs.set((PluginDirs.DEFAULT_EXTERNAL,))
plugin_specs.set({})
all_plugins_loaded.set(False)
plugins_enabled.set(True)
plugin_ies.value = {}
plugin_pps.value = {}
plugin_dirs.value = ['external']
plugin_specs.value = {}
all_plugins_loaded.value = False
plugins_enabled.value = True
importlib.invalidate_caches()
# Clearing override plugins is probably difficult
for module_name in tuple(sys.modules):
@ -64,29 +64,29 @@ class TestPlugins(unittest.TestCase):
f'{PACKAGE_NAME}.extractor._ignore' in sys.modules,
'loaded module beginning with underscore')
self.assertNotIn('IgnorePluginIE', plugins_ie.keys())
self.assertNotIn('IgnorePluginIE', plugin_ies.get())
self.assertNotIn('IgnorePluginIE', plugin_ies.value)
# Don't load extractors with underscore prefix
self.assertNotIn('_IgnoreUnderscorePluginIE', plugins_ie.keys())
self.assertNotIn('_IgnoreUnderscorePluginIE', plugin_ies.get())
self.assertNotIn('_IgnoreUnderscorePluginIE', plugin_ies.value)
# Don't load extractors not specified in __all__ (if supplied)
self.assertNotIn('IgnoreNotInAllPluginIE', plugins_ie.keys())
self.assertNotIn('IgnoreNotInAllPluginIE', plugin_ies.get())
self.assertNotIn('IgnoreNotInAllPluginIE', plugin_ies.value)
self.assertIn('InAllPluginIE', plugins_ie.keys())
self.assertIn('InAllPluginIE', plugin_ies.get())
self.assertIn('InAllPluginIE', plugin_ies.value)
# Don't load override extractors
self.assertNotIn('OverrideGenericIE', plugins_ie.keys())
self.assertNotIn('OverrideGenericIE', plugin_ies.get())
self.assertNotIn('OverrideGenericIE', plugin_ies.value)
self.assertNotIn('_UnderscoreOverrideGenericIE', plugins_ie.keys())
self.assertNotIn('_UnderscoreOverrideGenericIE', plugin_ies.get())
self.assertNotIn('_UnderscoreOverrideGenericIE', plugin_ies.value)
def test_postprocessor_classes(self):
plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
self.assertIn('NormalPluginPP', plugins_pp.keys())
self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
self.assertIn('NormalPluginPP', plugin_pps.get())
self.assertIn('NormalPluginPP', plugin_pps.value)
def test_importing_zipped_module(self):
zip_path = TEST_DATA_DIR / 'zipped_plugins.zip'
@ -130,7 +130,7 @@ class TestPlugins(unittest.TestCase):
plugins_ie['NormalPluginIE'].REPLACED,
msg='Reloading has not replaced original extractor plugin')
self.assertTrue(
extractors.get()['NormalPluginIE'].REPLACED,
extractors.value['NormalPluginIE'].REPLACED,
msg='Reloading has not replaced original extractor plugin globally')
plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
@ -138,7 +138,7 @@ class TestPlugins(unittest.TestCase):
self.assertTrue(plugins_pp['NormalPluginPP'].REPLACED,
msg='Reloading has not replaced original postprocessor plugin')
self.assertTrue(
postprocessors.get()['NormalPluginPP'].REPLACED,
postprocessors.value['NormalPluginPP'].REPLACED,
msg='Reloading has not replaced original postprocessor plugin globally')
finally:
@ -172,7 +172,7 @@ class TestPlugins(unittest.TestCase):
register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)
load_all_plugins()
self.assertTrue(yt_dlp._globals.all_plugins_loaded.get())
self.assertTrue(yt_dlp._globals.all_plugins_loaded.value)
self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
@ -182,36 +182,36 @@ class TestPlugins(unittest.TestCase):
custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages')
set_plugin_dirs(custom_plugin_dir)
self.assertEqual(plugin_dirs.get(), (custom_plugin_dir, ))
self.assertNotIn('external', plugin_dirs.get())
self.assertEqual(plugin_dirs.value, [custom_plugin_dir])
self.assertNotIn('external', plugin_dirs.value)
load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
self.assertIn('PackagePluginIE', plugin_ies.get())
self.assertIn('PackagePluginIE', plugin_ies.value)
def test_add_plugin_dirs(self):
custom_plugin_dir = str(TEST_DATA_DIR / 'plugin_packages')
self.assertEqual(plugin_dirs.get(), (PluginDirs.DEFAULT_EXTERNAL,))
self.assertEqual(plugin_dirs.value, ['external'])
add_plugin_dirs(custom_plugin_dir)
self.assertEqual(plugin_dirs.get(), (PluginDirs.DEFAULT_EXTERNAL, custom_plugin_dir))
self.assertEqual(plugin_dirs.value, ['external', custom_plugin_dir])
load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
self.assertIn('PackagePluginIE', plugin_ies.get())
self.assertIn('PackagePluginIE', plugin_ies.value)
def test_disable_plugins(self):
disable_plugins()
ies = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertEqual(ies, {})
self.assertNotIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertNotIn('NormalPluginIE', plugin_ies.get())
self.assertNotIn('NormalPluginIE', plugin_ies.value)
pps = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
self.assertEqual(pps, {})
self.assertNotIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
self.assertNotIn('NormalPluginPP', plugin_pps.get())
self.assertNotIn('NormalPluginPP', plugin_pps.value)
def test_disable_plugins_already_loaded(self):
register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
@ -221,7 +221,7 @@ class TestPlugins(unittest.TestCase):
with self.assertRaises(YoutubeDLError):
disable_plugins()
self.assertTrue(plugins_enabled.get())
self.assertTrue(plugins_enabled.value)
ies = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn('NormalPluginIE', ies)

View file

@ -648,7 +648,7 @@ class YoutubeDL:
self.__header_cookies = []
# compat for API: load plugins if they have not already
if not all_plugins_loaded.get():
if not all_plugins_loaded.value:
load_all_plugins()
stdout = sys.stderr if self.params.get('logtostderr') else sys.stdout
@ -4032,14 +4032,14 @@ class YoutubeDL:
_make_label(ORIGIN, CHANNEL.partition('@')[2] or __version__, __version__),
f'[{RELEASE_GIT_HEAD[:9]}]' if RELEASE_GIT_HEAD else '',
'' if source == 'unknown' else f'({source})',
'' if IN_CLI.get() else 'API' if klass == YoutubeDL else f'API:{self.__module__}.{klass.__qualname__}',
'' if IN_CLI.value else 'API' if klass == YoutubeDL else f'API:{self.__module__}.{klass.__qualname__}',
delim=' '))
if not IN_CLI.get():
if not IN_CLI.value:
write_debug(f'params: {self.params}')
import_extractors()
lazy_extractors = LAZY_EXTRACTORS.get()
lazy_extractors = LAZY_EXTRACTORS.value
if lazy_extractors is None:
write_debug('Lazy loading extractors is disabled')
elif not lazy_extractors:
@ -4079,19 +4079,19 @@ class YoutubeDL:
for plugin_type, plugins in (('Extractor', plugin_ies), ('Post-Processor', plugin_pps)):
display_list = [
klass.__name__ if klass.__name__ == name else f'{klass.__name__} as {name}'
for name, klass in plugins.get().items()]
for name, klass in plugins.value.items()]
if plugin_type == 'Extractor':
display_list.extend(f'{plugins[-1].IE_NAME.partition("+")[2]} ({parent.__name__})'
for parent, plugins in plugin_overrides.get().items())
for parent, plugins in plugin_overrides.value.items())
if not display_list:
continue
write_debug(f'{plugin_type} Plugins: {", ".join(sorted(display_list))}')
if not plugins_enabled.get():
if not plugins_enabled.value:
write_debug('Plugins are disabled')
plugin_dirs = plugin_directories()
if plugin_dirs and plugins_enabled.get():
if plugin_dirs and plugins_enabled.value:
write_debug(f'Plugin directories: {plugin_dirs}')
@functools.cached_property

View file

@ -23,7 +23,6 @@ from ._globals import IN_CLI as _IN_CLI
from .options import parseOpts
from .plugins import load_all_plugins as _load_all_plugins
from .plugins import disable_plugins as _disable_plugins
from .plugins import PluginDirs as _PluginDirs
from .plugins import set_plugin_dirs as _set_plugin_dirs
from .postprocessor import (
FFmpegExtractAudioPP,
@ -433,7 +432,7 @@ def validate_options(opts):
# Other options
opts.plugin_dirs = opts.plugin_dirs or []
if 'no-external' not in opts.plugin_dirs:
opts.plugin_dirs.append(_PluginDirs.DEFAULT_EXTERNAL)
opts.plugin_dirs.append('external')
if opts.playlist_items is not None:
try:
@ -1096,7 +1095,7 @@ def _real_main(argv=None):
def main(argv=None):
_IN_CLI.set(True)
_IN_CLI.value = True
try:
_exit(*variadic(_real_main(argv)))
except (CookieLoadError, DownloadError):

View file

@ -1,25 +1,32 @@
from collections import defaultdict
from contextvars import ContextVar
class Indirect:
def __init__(self, initial, /):
self.value = initial
def __repr__(self, /):
return f'{type(self).__name__}({self.value!r})'
# Internal only - no backwards compatibility guaranteed
postprocessors = ContextVar('postprocessors', default={})
extractors = ContextVar('extractors', default={})
IN_CLI = ContextVar('IN_CLI', default=False)
postprocessors = Indirect({})
extractors = Indirect({})
IN_CLI = Indirect(False)
# `False`=force, `None`=disabled, `True`=enabled
LAZY_EXTRACTORS = ContextVar('LAZY_EXTRACTORS', default=False)
LAZY_EXTRACTORS = Indirect(False)
# Plugins
plugin_specs = ContextVar('plugin_specs', default={})
plugin_specs = Indirect({})
# Whether plugins have been loaded once
all_plugins_loaded = ContextVar('all_plugins_loaded', default=False)
all_plugins_loaded = Indirect(False)
plugins_enabled = ContextVar('plugins_enabled', default=True)
plugins_enabled = Indirect(True)
plugin_dirs = ContextVar('plugin_dirs', default=('external', ))
plugin_ies = ContextVar('plugin_ies', default={})
plugin_overrides = ContextVar('plugin_overrides', default=defaultdict(list))
plugin_pps = ContextVar('plugin_pps', default={})
plugin_dirs = Indirect(['external'])
plugin_ies = Indirect({})
plugin_overrides = Indirect(defaultdict(list))
plugin_pps = Indirect({})

View file

@ -19,7 +19,7 @@ def gen_extractor_classes():
The order does matter; the first extractor matched is the one handling the URL.
"""
import_extractors()
return list(_extractors_context.get().values())
return list(_extractors_context.value.values())
def gen_extractors():
@ -47,7 +47,7 @@ def list_extractors(age_limit=None):
def get_info_extractor(ie_name):
"""Returns the info extractor class with the given ie_name"""
import_extractors()
return _extractors_context.get()[f'{ie_name}IE']
return _extractors_context.value[f'{ie_name}IE']
def import_extractors():

View file

@ -3969,13 +3969,13 @@ class InfoExtractor:
while getattr(super_class, '__wrapped__', None):
super_class = super_class.__wrapped__
if not any(override.PLUGIN_NAME == plugin_name for override in _plugin_overrides.get()[super_class]):
if not any(override.PLUGIN_NAME == plugin_name for override in _plugin_overrides.value[super_class]):
cls.__wrapped__ = next_mro_class
cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key
cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}'
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
_plugin_overrides.get()[super_class].append(cls)
_plugin_overrides.value[super_class].append(cls)
return super().__init_subclass__(**kwargs)

View file

@ -8,9 +8,9 @@ _CLASS_LOOKUP = None
if not os.environ.get('YTDLP_NO_LAZY_EXTRACTORS'):
try:
from .lazy_extractors import _CLASS_LOOKUP
LAZY_EXTRACTORS.set(True)
LAZY_EXTRACTORS = True
except ImportError:
LAZY_EXTRACTORS.set(None)
LAZY_EXTRACTORS = False
if not _CLASS_LOOKUP:
from . import _extractors
@ -23,7 +23,7 @@ if not _CLASS_LOOKUP:
_CLASS_LOOKUP['GenericIE'] = _extractors.GenericIE
# We want to append to the main lookup
_current = _extractors_context.get()
_current = _extractors_context.value
for name, ie in _CLASS_LOOKUP.items():
_current.setdefault(name, ie)

View file

@ -1,6 +1,5 @@
import contextlib
import dataclasses
import enum
import importlib
import importlib.abc
import importlib.machinery
@ -12,7 +11,7 @@ import pkgutil
import sys
import traceback
import zipimport
from contextvars import ContextVar
import functools
from pathlib import Path
from zipfile import ZipFile
@ -21,11 +20,10 @@ from ._globals import (
all_plugins_loaded,
plugin_specs,
plugins_enabled,
Indirect,
)
from .compat import functools # isort: split
from .utils import (
Config,
get_executable_path,
get_system_config_dirs,
get_user_config_dirs,
@ -50,23 +48,18 @@ __all__ = [
'add_plugin_dirs',
'set_plugin_dirs',
'disable_plugins',
'PluginDirs',
'get_plugin_spec',
'PACKAGE_NAME',
'COMPAT_PACKAGE_NAME',
]
class PluginDirs(enum.Enum):
DEFAULT_EXTERNAL = 'external' # The default external plugin directories
@dataclasses.dataclass
class PluginSpec:
module_name: str
suffix: str
destination: ContextVar
plugin_destination: ContextVar
destination: Indirect
plugin_destination: Indirect
class PluginLoader(importlib.abc.Loader):
@ -139,8 +132,8 @@ class PluginFinder(importlib.abc.MetaPathFinder):
def search_locations(self, fullname):
candidate_locations = itertools.chain.from_iterable(
external_plugin_paths() if candidate == PluginDirs.DEFAULT_EXTERNAL else Path(candidate).iterdir()
for candidate in plugin_dirs.get()
external_plugin_paths() if candidate == 'external' else Path(candidate).iterdir()
for candidate in plugin_dirs.value
)
parts = Path(*fullname.split('.'))
@ -201,7 +194,7 @@ def get_regular_classes(module, module_name, suffix):
def load_plugins(plugin_spec: PluginSpec):
name, suffix = plugin_spec.module_name, plugin_spec.suffix
regular_classes = {}
if os.environ.get('YTDLP_NO_PLUGINS') or plugins_enabled.get() is False:
if os.environ.get('YTDLP_NO_PLUGINS') or plugins_enabled.value is False:
return regular_classes
for finder, module_name, _ in iter_modules(name):
@ -228,7 +221,7 @@ def load_plugins(plugin_spec: PluginSpec):
# Compat: old plugin system using __init__.py
# Note: plugins imported this way do not show up in directories()
# nor are considered part of the yt_dlp_plugins namespace package
if PluginDirs.DEFAULT_EXTERNAL in plugin_dirs.get():
if 'external' in plugin_dirs.value:
with contextlib.suppress(FileNotFoundError):
spec = importlib.util.spec_from_file_location(
name,
@ -240,46 +233,46 @@ def load_plugins(plugin_spec: PluginSpec):
regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
# Add the classes into the global plugin lookup for that type
plugin_spec.plugin_destination.set(regular_classes)
plugin_spec.plugin_destination.value = regular_classes
# We want to prepend to the main lookup for that type
plugin_spec.destination.set(merge_dicts(regular_classes, plugin_spec.destination.get()))
plugin_spec.destination.value = merge_dicts(regular_classes, plugin_spec.destination.value)
return regular_classes
def load_all_plugins():
for plugin_spec in plugin_specs.get().values():
for plugin_spec in plugin_specs.value.values():
load_plugins(plugin_spec)
all_plugins_loaded.set(True)
all_plugins_loaded.value = True
def register_plugin_spec(plugin_spec: PluginSpec):
# If the plugin spec for a module is already registered, it will not be added again
if plugin_spec.module_name not in plugin_specs.get():
plugin_specs.get()[plugin_spec.module_name] = plugin_spec
if plugin_spec.module_name not in plugin_specs.value:
plugin_specs.value[plugin_spec.module_name] = plugin_spec
sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.{plugin_spec.module_name}'))
def add_plugin_dirs(*paths):
"""Add external plugin dirs to the existing ones"""
plugin_dirs.set((*plugin_dirs.get(), *paths))
plugin_dirs.value.extend(paths)
def set_plugin_dirs(*paths):
"""Set external plugin dirs, overriding the default ones"""
plugin_dirs.set(tuple(paths))
plugin_dirs.value = list(paths)
def get_plugin_spec(module_name):
return plugin_specs.get().get(module_name)
return plugin_specs.value.get(module_name)
def disable_plugins():
if (
all_plugins_loaded.get()
or any(len(plugin_spec.plugin_destination.get()) != 0 for plugin_spec in plugin_specs.get().values())
all_plugins_loaded.value
or any(len(plugin_spec.plugin_destination.value) != 0 for plugin_spec in plugin_specs.value.values())
):
# note: we can't detect all cases when plugins are loaded (e.g. if spec isn't registered)
raise YoutubeDLError('Plugins have already been loaded. Cannot disable plugins after loading plugins.')
plugins_enabled.set(False)
plugins_enabled.value = False

View file

@ -39,7 +39,7 @@ from ..utils import deprecation_warning
def __getattr__(name):
lookup = plugin_pps.get()
lookup = plugin_pps.value
if name in lookup:
deprecation_warning(
f'Importing a plugin Post-Processor from {__name__} is deprecated. '
@ -50,7 +50,7 @@ def __getattr__(name):
def get_postprocessor(key):
return postprocessors.get()[key + 'PP']
return postprocessors.value[key + 'PP']
register_plugin_spec(PluginSpec(
@ -65,6 +65,6 @@ _default_pps = {
for name, value in globals().items()
if name.endswith('PP') or name in ('PostProcessor', 'FFmpegPostProcessor')
}
postprocessors.set(_default_pps)
postprocessors.value.update(_default_pps)
__all__ = list(_default_pps.values())

View file

@ -1484,7 +1484,7 @@ def write_string(s, out=None, encoding=None):
# TODO: Use global logger
def deprecation_warning(msg, *, printer=None, stacklevel=0, **kwargs):
if IN_CLI.get():
if IN_CLI.value:
if msg in deprecation_warning._cache:
return
deprecation_warning._cache.add(msg)