mirror of
https://github.com/yt-dlp/yt-dlp
synced 2025-01-16 03:40:50 +01:00
revert back to init_subclass, add guard against multiple imports of same plugin
This commit is contained in:
parent
42771dde1c
commit
a19dd28fdc
3 changed files with 24 additions and 54 deletions
|
@ -116,7 +116,7 @@ class TestPlugins(unittest.TestCase):
|
|||
for module_name in tuple(sys.modules):
|
||||
if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
|
||||
del sys.modules[module_name]
|
||||
plugins_ie = load_plugins(PluginType.EXTRACTORS)
|
||||
load_plugins(PluginType.EXTRACTORS)
|
||||
|
||||
from yt_dlp.extractor.generic import GenericIE
|
||||
|
||||
|
@ -124,6 +124,11 @@ class TestPlugins(unittest.TestCase):
|
|||
self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')
|
||||
|
||||
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
|
||||
importlib.invalidate_caches()
|
||||
# test that loading a second time doesn't wrap a second time
|
||||
load_plugins(PluginType.EXTRACTORS)
|
||||
from yt_dlp.extractor.generic import GenericIE
|
||||
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -6,6 +6,7 @@ import hashlib
|
|||
import http.client
|
||||
import http.cookiejar
|
||||
import http.cookies
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
|
@ -21,6 +22,7 @@ import urllib.parse
|
|||
import urllib.request
|
||||
import xml.etree.ElementTree
|
||||
|
||||
from .._globals import plugin_overrides
|
||||
from ..compat import (
|
||||
compat_etree_fromstring,
|
||||
compat_expanduser,
|
||||
|
@ -3933,7 +3935,19 @@ class InfoExtractor:
|
|||
@classmethod
|
||||
def __init_subclass__(cls, *, plugin_name=None, **kwargs):
|
||||
if plugin_name is not None:
|
||||
cls._plugin_name = plugin_name
|
||||
mro = inspect.getmro(cls)
|
||||
next_mro_class = super_class = mro[mro.index(cls) + 1]
|
||||
|
||||
while getattr(super_class, '__wrapped__', None):
|
||||
super_class = super_class.__wrapped__
|
||||
|
||||
if not any(override.PLUGIN_NAME == plugin_name for override in plugin_overrides.get()[super_class]):
|
||||
cls.__wrapped__ = next_mro_class
|
||||
cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key
|
||||
cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}'
|
||||
|
||||
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
|
||||
plugin_overrides.get()[super_class].append(cls)
|
||||
return super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ import os
|
|||
import pkgutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
import zipimport
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
|
@ -22,7 +21,8 @@ from ._globals import (
|
|||
plugin_dirs,
|
||||
plugin_ies,
|
||||
plugin_pps,
|
||||
postprocessors, plugin_overrides, ALL_PLUGINS_LOADED,
|
||||
postprocessors,
|
||||
ALL_PLUGINS_LOADED,
|
||||
)
|
||||
|
||||
from .compat import functools # isort: split
|
||||
|
@ -170,52 +170,24 @@ def get_regular_classes(module, module_name, suffix):
|
|||
and obj.__module__.startswith(module_name)
|
||||
and not obj.__name__.startswith('_')
|
||||
and obj.__name__ in getattr(module, '__all__', [obj.__name__])
|
||||
and getattr(obj, '_plugin_name', None) is None
|
||||
and getattr(obj, 'PLUGIN_NAME', None) is None
|
||||
))
|
||||
|
||||
|
||||
def get_override_classes(module, module_name, suffix):
|
||||
# Find override plugin classes
|
||||
def predicate(obj):
|
||||
if not inspect.isclass(obj):
|
||||
return False
|
||||
mro = inspect.getmro(obj)
|
||||
return (
|
||||
obj.__module__.startswith(module_name)
|
||||
and getattr(obj, '_plugin_name', None) is not None
|
||||
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
|
||||
)
|
||||
return inspect.getmembers(module, predicate)
|
||||
|
||||
|
||||
def configure_ie_override_class(klass, super_class, plugin_name):
|
||||
ie_key = getattr(super_class, 'ie_key', None)
|
||||
if not ie_key:
|
||||
warnings.warn(f'Override plugin {klass} is not an extractor')
|
||||
return False
|
||||
klass.ie_key = ie_key
|
||||
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _PluginTypeConfig:
|
||||
destination: ContextVar
|
||||
plugin_destination: ContextVar
|
||||
# Function to configure the override class. Return False to skip the class
|
||||
# Takes (klass, super_class, plugin_name) as arguments
|
||||
configure_override_func: callable = lambda *args: None
|
||||
|
||||
|
||||
_plugin_type_lookup = {
|
||||
PluginType.POSTPROCESSORS: _PluginTypeConfig(
|
||||
destination=postprocessors,
|
||||
plugin_destination=plugin_pps,
|
||||
configure_override_func=None,
|
||||
),
|
||||
PluginType.EXTRACTORS: _PluginTypeConfig(
|
||||
destination=extractors,
|
||||
plugin_destination=plugin_ies,
|
||||
configure_override_func=configure_ie_override_class,
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -224,7 +196,6 @@ def load_plugins(plugin_type: PluginType):
|
|||
plugin_config = _plugin_type_lookup[plugin_type]
|
||||
name, suffix = plugin_type.value
|
||||
regular_classes = {}
|
||||
override_classes = {}
|
||||
if os.environ.get('YTDLP_NO_PLUGINS'):
|
||||
return regular_classes
|
||||
|
||||
|
@ -248,7 +219,6 @@ def load_plugins(plugin_type: PluginType):
|
|||
)
|
||||
continue
|
||||
regular_classes.update(get_regular_classes(module, module_name, suffix))
|
||||
override_classes.update(get_override_classes(module, module_name, suffix))
|
||||
|
||||
# Compat: old plugin system using __init__.py
|
||||
# Note: plugins imported this way do not show up in directories()
|
||||
|
@ -264,25 +234,6 @@ def load_plugins(plugin_type: PluginType):
|
|||
spec.loader.exec_module(plugins)
|
||||
regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
|
||||
|
||||
# Configure override classes
|
||||
for _, klass in override_classes.items():
|
||||
plugin_name = getattr(klass, '_plugin_name', None)
|
||||
if not plugin_name:
|
||||
# these should always have plugin_name
|
||||
continue
|
||||
|
||||
mro = inspect.getmro(klass)
|
||||
super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
|
||||
klass.PLUGIN_NAME = plugin_name
|
||||
|
||||
if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
|
||||
continue
|
||||
|
||||
while getattr(super_class, '__wrapped__', None):
|
||||
super_class = super_class.__wrapped__
|
||||
setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
|
||||
plugin_overrides.get()[super_class].append(klass)
|
||||
|
||||
# Add the classes into the global plugin lookup for that type
|
||||
plugin_config.plugin_destination.set(regular_classes)
|
||||
# We want to prepend to the main lookup for that type
|
||||
|
|
Loading…
Reference in a new issue