Decouple plugins.py from plugin types

This commit is contained in:
coletdjnz 2024-10-20 11:59:29 +13:00
parent 9269248935
commit 21e13bfa84
No known key found for this signature in database
GPG key ID: 91984263BB39894A
8 changed files with 101 additions and 71 deletions

View file

@ -6,7 +6,7 @@ import sys
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from yt_dlp._globals import ALL_PLUGINS_LOADED from yt_dlp._globals import all_plugins_loaded
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@ -1400,9 +1400,9 @@ class TestYoutubeDL(unittest.TestCase):
def test_load_plugins_compat(self): def test_load_plugins_compat(self):
# Should try to reload plugins if they haven't already been loaded # Should try to reload plugins if they haven't already been loaded
ALL_PLUGINS_LOADED.set(False) all_plugins_loaded.set(False)
FakeYDL().close() FakeYDL().close()
assert ALL_PLUGINS_LOADED.get() assert all_plugins_loaded.get()
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -11,8 +11,23 @@ TEST_DATA_DIR = Path(os.path.dirname(os.path.abspath(__file__)), 'testdata')
sys.path.append(str(TEST_DATA_DIR)) sys.path.append(str(TEST_DATA_DIR))
importlib.invalidate_caches() importlib.invalidate_caches()
from yt_dlp.plugins import PACKAGE_NAME, PluginType, directories, load_plugins, load_all_plugin_types from yt_dlp.plugins import PACKAGE_NAME, PluginSpec, directories, load_plugins, load_all_plugins, register_plugin_spec
from yt_dlp._globals import extractors, postprocessors, plugin_dirs, plugin_ies, plugin_pps, ALL_PLUGINS_LOADED from yt_dlp._globals import extractors, postprocessors, plugin_dirs, plugin_ies, plugin_pps, all_plugins_loaded, plugin_specs
EXTRACTOR_PLUGIN_SPEC = PluginSpec(
module_name='extractor',
suffix='IE',
destination=extractors,
plugin_destination=plugin_ies,
)
POSTPROCESSOR_PLUGIN_SPEC = PluginSpec(
module_name='postprocessor',
suffix='PP',
destination=postprocessors,
plugin_destination=plugin_pps,
)
class TestPlugins(unittest.TestCase): class TestPlugins(unittest.TestCase):
@ -23,7 +38,8 @@ class TestPlugins(unittest.TestCase):
plugin_ies.set({}) plugin_ies.set({})
plugin_pps.set({}) plugin_pps.set({})
plugin_dirs.set((...,)) plugin_dirs.set((...,))
ALL_PLUGINS_LOADED.set(False) plugin_specs.set({})
all_plugins_loaded.set(False)
importlib.invalidate_caches() importlib.invalidate_caches()
# Clearing override plugins is probably difficult # Clearing override plugins is probably difficult
for module_name in tuple(sys.modules): for module_name in tuple(sys.modules):
@ -35,7 +51,7 @@ class TestPlugins(unittest.TestCase):
self.assertIn(self.TEST_PLUGIN_DIR, map(Path, directories())) self.assertIn(self.TEST_PLUGIN_DIR, map(Path, directories()))
def test_extractor_classes(self): def test_extractor_classes(self):
plugins_ie = load_plugins(PluginType.EXTRACTORS) plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertIn('NormalPluginIE', plugins_ie.keys()) self.assertIn('NormalPluginIE', plugins_ie.keys())
@ -64,7 +80,7 @@ class TestPlugins(unittest.TestCase):
self.assertNotIn('_UnderscoreOverrideGenericIE', plugin_ies.get()) self.assertNotIn('_UnderscoreOverrideGenericIE', plugin_ies.get())
def test_postprocessor_classes(self): def test_postprocessor_classes(self):
plugins_pp = load_plugins(PluginType.POSTPROCESSORS) plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
self.assertIn('NormalPluginPP', plugins_pp.keys()) self.assertIn('NormalPluginPP', plugins_pp.keys())
self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
self.assertIn('NormalPluginPP', plugin_pps.get()) self.assertIn('NormalPluginPP', plugin_pps.get())
@ -80,10 +96,10 @@ class TestPlugins(unittest.TestCase):
package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}') package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
self.assertIn(zip_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__)) self.assertIn(zip_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))
plugins_ie = load_plugins(PluginType.EXTRACTORS) plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn('ZippedPluginIE', plugins_ie.keys()) self.assertIn('ZippedPluginIE', plugins_ie.keys())
plugins_pp = load_plugins(PluginType.POSTPROCESSORS) plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
self.assertIn('ZippedPluginPP', plugins_pp.keys()) self.assertIn('ZippedPluginPP', plugins_pp.keys())
finally: finally:
@ -93,11 +109,8 @@ class TestPlugins(unittest.TestCase):
def test_reloading_plugins(self): def test_reloading_plugins(self):
reload_plugins_path = TEST_DATA_DIR / 'reload_plugins' reload_plugins_path = TEST_DATA_DIR / 'reload_plugins'
load_plugins(EXTRACTOR_PLUGIN_SPEC)
for plugin_type in ('extractor', 'postprocessor'): load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
load_plugins(PluginType.EXTRACTORS)
load_plugins(PluginType.POSTPROCESSORS)
# Remove default folder and add reload_plugin path # Remove default folder and add reload_plugin path
sys.path.remove(str(TEST_DATA_DIR)) sys.path.remove(str(TEST_DATA_DIR))
@ -108,7 +121,7 @@ class TestPlugins(unittest.TestCase):
package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}') package = importlib.import_module(f'{PACKAGE_NAME}.{plugin_type}')
self.assertIn(reload_plugins_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__)) self.assertIn(reload_plugins_path / PACKAGE_NAME / plugin_type, map(Path, package.__path__))
plugins_ie = load_plugins(PluginType.EXTRACTORS) plugins_ie = load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertIn('NormalPluginIE', plugins_ie.keys()) self.assertIn('NormalPluginIE', plugins_ie.keys())
self.assertTrue( self.assertTrue(
plugins_ie['NormalPluginIE'].REPLACED, plugins_ie['NormalPluginIE'].REPLACED,
@ -117,7 +130,7 @@ class TestPlugins(unittest.TestCase):
extractors.get()['NormalPluginIE'].REPLACED, extractors.get()['NormalPluginIE'].REPLACED,
msg='Reloading has not replaced original extractor plugin globally') msg='Reloading has not replaced original extractor plugin globally')
plugins_pp = load_plugins(PluginType.POSTPROCESSORS) plugins_pp = load_plugins(POSTPROCESSOR_PLUGIN_SPEC)
self.assertIn('NormalPluginPP', plugins_pp.keys()) self.assertIn('NormalPluginPP', plugins_pp.keys())
self.assertTrue(plugins_pp['NormalPluginPP'].REPLACED, self.assertTrue(plugins_pp['NormalPluginPP'].REPLACED,
msg='Reloading has not replaced original postprocessor plugin') msg='Reloading has not replaced original postprocessor plugin')
@ -131,7 +144,7 @@ class TestPlugins(unittest.TestCase):
importlib.invalidate_caches() importlib.invalidate_caches()
def test_extractor_override_plugin(self): def test_extractor_override_plugin(self):
load_plugins(PluginType.EXTRACTORS) load_plugins(EXTRACTOR_PLUGIN_SPEC)
from yt_dlp.extractor.generic import GenericIE from yt_dlp.extractor.generic import GenericIE
@ -141,25 +154,29 @@ class TestPlugins(unittest.TestCase):
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override') self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
importlib.invalidate_caches() importlib.invalidate_caches()
# test that loading a second time doesn't wrap a second time # test that loading a second time doesn't wrap a second time
load_plugins(PluginType.EXTRACTORS) load_plugins(EXTRACTOR_PLUGIN_SPEC)
from yt_dlp.extractor.generic import GenericIE from yt_dlp.extractor.generic import GenericIE
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override') self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
def test_load_all_plugin_types(self): def test_load_all_plugin_types(self):
# no plugin specs registered
load_all_plugins()
self.assertNotIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys()) self.assertNotIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertNotIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys()) self.assertNotIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
load_all_plugin_types() register_plugin_spec(EXTRACTOR_PLUGIN_SPEC)
self.assertTrue(yt_dlp._globals.ALL_PLUGINS_LOADED.get()) register_plugin_spec(POSTPROCESSOR_PLUGIN_SPEC)
load_all_plugins()
self.assertTrue(yt_dlp._globals.all_plugins_loaded.get())
self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.extractor.normal', sys.modules.keys())
self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.postprocessor.normal', sys.modules.keys())
def test_plugin_dirs(self): def test_plugin_dirs(self):
plugin_dirs.set((..., str(TEST_DATA_DIR / 'plugin_packages'))) plugin_dirs.set((..., str(TEST_DATA_DIR / 'plugin_packages')))
load_all_plugin_types() load_plugins(EXTRACTOR_PLUGIN_SPEC)
self.assertTrue(yt_dlp._globals.ALL_PLUGINS_LOADED.get())
self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys()) self.assertIn(f'{PACKAGE_NAME}.extractor.package', sys.modules.keys())
self.assertIn('PackagePluginIE', plugin_ies.get()) self.assertIn('PackagePluginIE', plugin_ies.get())

View file

@ -38,7 +38,8 @@ from ._globals import (
LAZY_EXTRACTORS, LAZY_EXTRACTORS,
plugin_ies, plugin_ies,
plugin_overrides, plugin_overrides,
plugin_pps, ALL_PLUGINS_LOADED, plugin_pps,
all_plugins_loaded,
) )
from .minicurses import format_text from .minicurses import format_text
from .networking import HEADRequest, Request, RequestDirector from .networking import HEADRequest, Request, RequestDirector
@ -51,7 +52,7 @@ from .networking.exceptions import (
network_exceptions, network_exceptions,
) )
from .networking.impersonate import ImpersonateRequestHandler from .networking.impersonate import ImpersonateRequestHandler
from .plugins import directories as plugin_directories, load_all_plugin_types from .plugins import directories as plugin_directories, load_all_plugins
from .postprocessor import ( from .postprocessor import (
EmbedThumbnailPP, EmbedThumbnailPP,
FFmpegFixupDuplicateMoovPP, FFmpegFixupDuplicateMoovPP,
@ -646,8 +647,8 @@ class YoutubeDL:
self.__header_cookies = [] self.__header_cookies = []
# compat for API: load plugins if they have not already # compat for API: load plugins if they have not already
if not ALL_PLUGINS_LOADED.get(): if not all_plugins_loaded.get():
load_all_plugin_types() load_all_plugins()
stdout = sys.stderr if self.params.get('logtostderr') else sys.stdout stdout = sys.stderr if self.params.get('logtostderr') else sys.stdout
self._out_files = Namespace( self._out_files = Namespace(

View file

@ -22,7 +22,7 @@ from .extractor.adobepass import MSO_INFO
from .networking.impersonate import ImpersonateTarget from .networking.impersonate import ImpersonateTarget
from ._globals import IN_CLI, plugin_dirs from ._globals import IN_CLI, plugin_dirs
from .options import parseOpts from .options import parseOpts
from .plugins import load_all_plugin_types from .plugins import load_all_plugins
from .postprocessor import ( from .postprocessor import (
FFmpegExtractAudioPP, FFmpegExtractAudioPP,
FFmpegMergerPP, FFmpegMergerPP,
@ -987,7 +987,7 @@ def _real_main(argv=None):
# load all plugins into the global lookup # load all plugins into the global lookup
plugin_dirs.set(opts.plugin_dirs) plugin_dirs.set(opts.plugin_dirs)
load_all_plugin_types() load_all_plugins()
with YoutubeDL(ydl_opts) as ydl: with YoutubeDL(ydl_opts) as ydl:
pre_process = opts.update_self or opts.rm_cachedir pre_process = opts.update_self or opts.rm_cachedir

View file

@ -9,8 +9,13 @@ IN_CLI = ContextVar('IN_CLI', default=False)
# `False`=force, `None`=disabled, `True`=enabled # `False`=force, `None`=disabled, `True`=enabled
LAZY_EXTRACTORS = ContextVar('LAZY_EXTRACTORS', default=False) LAZY_EXTRACTORS = ContextVar('LAZY_EXTRACTORS', default=False)
# Plugins
plugin_specs = ContextVar('plugin_specs', default={})
# Whether plugins have been loaded once # Whether plugins have been loaded once
ALL_PLUGINS_LOADED = ContextVar('PLUGINS_LOADED', default=False) all_plugins_loaded = ContextVar('all_plugins_loaded', default=False)
# `...`=search default plugin dirs # `...`=search default plugin dirs
plugin_dirs = ContextVar('plugin_dirs', default=(..., )) plugin_dirs = ContextVar('plugin_dirs', default=(..., ))

View file

@ -1,9 +1,18 @@
from .._globals import extractors as _extractors_context from .._globals import extractors as _extractors_context
from .._globals import plugin_ies as _plugin_ies_context
from ..compat.compat_utils import passthrough_module from ..compat.compat_utils import passthrough_module
from ..plugins import PluginSpec, register_plugin_spec
passthrough_module(__name__, '.extractors') passthrough_module(__name__, '.extractors')
del passthrough_module del passthrough_module
register_plugin_spec(PluginSpec(
module_name='extractor',
suffix='IE',
destination=_extractors_context,
plugin_destination=_plugin_ies_context,
))
def gen_extractor_classes(): def gen_extractor_classes():
""" Return a list of supported extractors. """ Return a list of supported extractors.

View file

@ -1,6 +1,5 @@
import contextlib import contextlib
import dataclasses import dataclasses
import enum
import importlib import importlib
import importlib.abc import importlib.abc
import importlib.machinery import importlib.machinery
@ -17,12 +16,9 @@ from pathlib import Path
from zipfile import ZipFile from zipfile import ZipFile
from ._globals import ( from ._globals import (
extractors,
plugin_dirs, plugin_dirs,
plugin_ies, all_plugins_loaded,
plugin_pps, plugin_specs,
postprocessors,
ALL_PLUGINS_LOADED,
) )
from .compat import functools # isort: split from .compat import functools # isort: split
@ -40,9 +36,12 @@ COMPAT_PACKAGE_NAME = 'ytdlp_plugins'
_BASE_PACKAGE_PATH = Path(__file__).parent _BASE_PACKAGE_PATH = Path(__file__).parent
class PluginType(enum.Enum): @dataclasses.dataclass
POSTPROCESSORS = ('postprocessor', 'PP') class PluginSpec:
EXTRACTORS = ('extractor', 'IE') module_name: str
suffix: str
destination: ContextVar
plugin_destination: ContextVar
class PluginLoader(importlib.abc.Loader): class PluginLoader(importlib.abc.Loader):
@ -68,7 +67,7 @@ def dirs_in_zip(archive):
return () return ()
def default_plugin_paths(): def external_plugin_paths():
def _get_package_paths(*root_paths, containing_folder): def _get_package_paths(*root_paths, containing_folder):
for config_dir in orderedSet(map(Path, root_paths), lazy=True): for config_dir in orderedSet(map(Path, root_paths), lazy=True):
# We need to filter the base path added when running __main__.py directly # We need to filter the base path added when running __main__.py directly
@ -115,7 +114,7 @@ class PluginFinder(importlib.abc.MetaPathFinder):
def search_locations(self, fullname): def search_locations(self, fullname):
candidate_locations = itertools.chain.from_iterable( candidate_locations = itertools.chain.from_iterable(
default_plugin_paths() if candidate is ... else Path(candidate).iterdir() external_plugin_paths() if candidate is ... else Path(candidate).iterdir()
for candidate in plugin_dirs.get() for candidate in plugin_dirs.get()
) )
@ -174,27 +173,8 @@ def get_regular_classes(module, module_name, suffix):
)) ))
@dataclasses.dataclass def load_plugins(plugin_spec: PluginSpec):
class _PluginTypeConfig: name, suffix = plugin_spec.module_name, plugin_spec.suffix
destination: ContextVar
plugin_destination: ContextVar
_plugin_type_lookup = {
PluginType.POSTPROCESSORS: _PluginTypeConfig(
destination=postprocessors,
plugin_destination=plugin_pps,
),
PluginType.EXTRACTORS: _PluginTypeConfig(
destination=extractors,
plugin_destination=plugin_ies,
),
}
def load_plugins(plugin_type: PluginType):
plugin_config = _plugin_type_lookup[plugin_type]
name, suffix = plugin_type.value
regular_classes = {} regular_classes = {}
if os.environ.get('YTDLP_NO_PLUGINS'): if os.environ.get('YTDLP_NO_PLUGINS'):
return regular_classes return regular_classes
@ -235,25 +215,36 @@ def load_plugins(plugin_type: PluginType):
regular_classes.update(get_regular_classes(plugins, spec.name, suffix)) regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
# Add the classes into the global plugin lookup for that type # Add the classes into the global plugin lookup for that type
plugin_config.plugin_destination.set(regular_classes) plugin_spec.plugin_destination.set(regular_classes)
# We want to prepend to the main lookup for that type # We want to prepend to the main lookup for that type
plugin_config.destination.set(merge_dicts(regular_classes, plugin_config.destination.get())) plugin_spec.destination.set(merge_dicts(regular_classes, plugin_spec.destination.get()))
return regular_classes return regular_classes
def load_all_plugin_types(): def load_all_plugins():
for plugin_type in PluginType: for plugin_spec in plugin_specs.get().values():
load_plugins(plugin_type) load_plugins(plugin_spec)
ALL_PLUGINS_LOADED.set(True) all_plugins_loaded.set(True)
sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.extractor', f'{PACKAGE_NAME}.postprocessor')) def register_plugin_spec(plugin_spec: PluginSpec):
# If the plugin spec for a module is already registered, it will not be added again
if plugin_spec.module_name not in plugin_specs.get():
plugin_specs.get()[plugin_spec.module_name] = plugin_spec
sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.{plugin_spec.module_name}'))
def get_plugin_spec(module_name):
return plugin_specs.get().get(module_name)
__all__ = [ __all__ = [
'directories', 'directories',
'load_plugins', 'load_plugins',
'load_all_plugin_types', 'load_all_plugins',
'register_plugin_spec',
'get_plugin_spec',
'PACKAGE_NAME', 'PACKAGE_NAME',
'COMPAT_PACKAGE_NAME', 'COMPAT_PACKAGE_NAME',
] ]

View file

@ -34,7 +34,7 @@ from .sponskrub import SponSkrubPP
from .sponsorblock import SponsorBlockPP from .sponsorblock import SponsorBlockPP
from .xattrpp import XAttrMetadataPP from .xattrpp import XAttrMetadataPP
from .._globals import plugin_pps, postprocessors from .._globals import plugin_pps, postprocessors
from ..plugins import PACKAGE_NAME from ..plugins import PACKAGE_NAME, register_plugin_spec, PluginSpec
from ..utils import deprecation_warning from ..utils import deprecation_warning
@ -53,6 +53,13 @@ def get_postprocessor(key):
return postprocessors.get()[key + 'PP'] return postprocessors.get()[key + 'PP']
register_plugin_spec(PluginSpec(
module_name='postprocessor',
suffix='PP',
destination=postprocessors,
plugin_destination=plugin_pps,
))
_default_pps = { _default_pps = {
name: value name: value
for name, value in globals().items() for name, value in globals().items()