mirror of
https://github.com/yt-dlp/yt-dlp
synced 2025-01-29 20:35:06 +01:00
revert back to init_subclass, add guard against multiple imports of same plugin
This commit is contained in:
parent
42771dde1c
commit
a19dd28fdc
3 changed files with 24 additions and 54 deletions
|
@ -116,7 +116,7 @@ class TestPlugins(unittest.TestCase):
|
||||||
for module_name in tuple(sys.modules):
|
for module_name in tuple(sys.modules):
|
||||||
if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
|
if module_name.startswith(f'{PACKAGE_NAME}.extractor'):
|
||||||
del sys.modules[module_name]
|
del sys.modules[module_name]
|
||||||
plugins_ie = load_plugins(PluginType.EXTRACTORS)
|
load_plugins(PluginType.EXTRACTORS)
|
||||||
|
|
||||||
from yt_dlp.extractor.generic import GenericIE
|
from yt_dlp.extractor.generic import GenericIE
|
||||||
|
|
||||||
|
@ -124,6 +124,11 @@ class TestPlugins(unittest.TestCase):
|
||||||
self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')
|
self.assertEqual(GenericIE.SECONDARY_TEST_FIELD, 'underscore-override')
|
||||||
|
|
||||||
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
|
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
# test that loading a second time doesn't wrap a second time
|
||||||
|
load_plugins(PluginType.EXTRACTORS)
|
||||||
|
from yt_dlp.extractor.generic import GenericIE
|
||||||
|
self.assertEqual(GenericIE.IE_NAME, 'generic+override+underscore-override')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -6,6 +6,7 @@ import hashlib
|
||||||
import http.client
|
import http.client
|
||||||
import http.cookiejar
|
import http.cookiejar
|
||||||
import http.cookies
|
import http.cookies
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
@ -21,6 +22,7 @@ import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import xml.etree.ElementTree
|
import xml.etree.ElementTree
|
||||||
|
|
||||||
|
from .._globals import plugin_overrides
|
||||||
from ..compat import (
|
from ..compat import (
|
||||||
compat_etree_fromstring,
|
compat_etree_fromstring,
|
||||||
compat_expanduser,
|
compat_expanduser,
|
||||||
|
@ -3933,7 +3935,19 @@ class InfoExtractor:
|
||||||
@classmethod
|
@classmethod
|
||||||
def __init_subclass__(cls, *, plugin_name=None, **kwargs):
|
def __init_subclass__(cls, *, plugin_name=None, **kwargs):
|
||||||
if plugin_name is not None:
|
if plugin_name is not None:
|
||||||
cls._plugin_name = plugin_name
|
mro = inspect.getmro(cls)
|
||||||
|
next_mro_class = super_class = mro[mro.index(cls) + 1]
|
||||||
|
|
||||||
|
while getattr(super_class, '__wrapped__', None):
|
||||||
|
super_class = super_class.__wrapped__
|
||||||
|
|
||||||
|
if not any(override.PLUGIN_NAME == plugin_name for override in plugin_overrides.get()[super_class]):
|
||||||
|
cls.__wrapped__ = next_mro_class
|
||||||
|
cls.PLUGIN_NAME, cls.ie_key = plugin_name, next_mro_class.ie_key
|
||||||
|
cls.IE_NAME = f'{next_mro_class.IE_NAME}+{plugin_name}'
|
||||||
|
|
||||||
|
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
|
||||||
|
plugin_overrides.get()[super_class].append(cls)
|
||||||
return super().__init_subclass__(**kwargs)
|
return super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,6 @@ import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
|
||||||
import zipimport
|
import zipimport
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -22,7 +21,8 @@ from ._globals import (
|
||||||
plugin_dirs,
|
plugin_dirs,
|
||||||
plugin_ies,
|
plugin_ies,
|
||||||
plugin_pps,
|
plugin_pps,
|
||||||
postprocessors, plugin_overrides, ALL_PLUGINS_LOADED,
|
postprocessors,
|
||||||
|
ALL_PLUGINS_LOADED,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .compat import functools # isort: split
|
from .compat import functools # isort: split
|
||||||
|
@ -170,52 +170,24 @@ def get_regular_classes(module, module_name, suffix):
|
||||||
and obj.__module__.startswith(module_name)
|
and obj.__module__.startswith(module_name)
|
||||||
and not obj.__name__.startswith('_')
|
and not obj.__name__.startswith('_')
|
||||||
and obj.__name__ in getattr(module, '__all__', [obj.__name__])
|
and obj.__name__ in getattr(module, '__all__', [obj.__name__])
|
||||||
and getattr(obj, '_plugin_name', None) is None
|
and getattr(obj, 'PLUGIN_NAME', None) is None
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
def get_override_classes(module, module_name, suffix):
|
|
||||||
# Find override plugin classes
|
|
||||||
def predicate(obj):
|
|
||||||
if not inspect.isclass(obj):
|
|
||||||
return False
|
|
||||||
mro = inspect.getmro(obj)
|
|
||||||
return (
|
|
||||||
obj.__module__.startswith(module_name)
|
|
||||||
and getattr(obj, '_plugin_name', None) is not None
|
|
||||||
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
|
|
||||||
)
|
|
||||||
return inspect.getmembers(module, predicate)
|
|
||||||
|
|
||||||
|
|
||||||
def configure_ie_override_class(klass, super_class, plugin_name):
|
|
||||||
ie_key = getattr(super_class, 'ie_key', None)
|
|
||||||
if not ie_key:
|
|
||||||
warnings.warn(f'Override plugin {klass} is not an extractor')
|
|
||||||
return False
|
|
||||||
klass.ie_key = ie_key
|
|
||||||
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class _PluginTypeConfig:
|
class _PluginTypeConfig:
|
||||||
destination: ContextVar
|
destination: ContextVar
|
||||||
plugin_destination: ContextVar
|
plugin_destination: ContextVar
|
||||||
# Function to configure the override class. Return False to skip the class
|
|
||||||
# Takes (klass, super_class, plugin_name) as arguments
|
|
||||||
configure_override_func: callable = lambda *args: None
|
|
||||||
|
|
||||||
|
|
||||||
_plugin_type_lookup = {
|
_plugin_type_lookup = {
|
||||||
PluginType.POSTPROCESSORS: _PluginTypeConfig(
|
PluginType.POSTPROCESSORS: _PluginTypeConfig(
|
||||||
destination=postprocessors,
|
destination=postprocessors,
|
||||||
plugin_destination=plugin_pps,
|
plugin_destination=plugin_pps,
|
||||||
configure_override_func=None,
|
|
||||||
),
|
),
|
||||||
PluginType.EXTRACTORS: _PluginTypeConfig(
|
PluginType.EXTRACTORS: _PluginTypeConfig(
|
||||||
destination=extractors,
|
destination=extractors,
|
||||||
plugin_destination=plugin_ies,
|
plugin_destination=plugin_ies,
|
||||||
configure_override_func=configure_ie_override_class,
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,7 +196,6 @@ def load_plugins(plugin_type: PluginType):
|
||||||
plugin_config = _plugin_type_lookup[plugin_type]
|
plugin_config = _plugin_type_lookup[plugin_type]
|
||||||
name, suffix = plugin_type.value
|
name, suffix = plugin_type.value
|
||||||
regular_classes = {}
|
regular_classes = {}
|
||||||
override_classes = {}
|
|
||||||
if os.environ.get('YTDLP_NO_PLUGINS'):
|
if os.environ.get('YTDLP_NO_PLUGINS'):
|
||||||
return regular_classes
|
return regular_classes
|
||||||
|
|
||||||
|
@ -248,7 +219,6 @@ def load_plugins(plugin_type: PluginType):
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
regular_classes.update(get_regular_classes(module, module_name, suffix))
|
regular_classes.update(get_regular_classes(module, module_name, suffix))
|
||||||
override_classes.update(get_override_classes(module, module_name, suffix))
|
|
||||||
|
|
||||||
# Compat: old plugin system using __init__.py
|
# Compat: old plugin system using __init__.py
|
||||||
# Note: plugins imported this way do not show up in directories()
|
# Note: plugins imported this way do not show up in directories()
|
||||||
|
@ -264,25 +234,6 @@ def load_plugins(plugin_type: PluginType):
|
||||||
spec.loader.exec_module(plugins)
|
spec.loader.exec_module(plugins)
|
||||||
regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
|
regular_classes.update(get_regular_classes(plugins, spec.name, suffix))
|
||||||
|
|
||||||
# Configure override classes
|
|
||||||
for _, klass in override_classes.items():
|
|
||||||
plugin_name = getattr(klass, '_plugin_name', None)
|
|
||||||
if not plugin_name:
|
|
||||||
# these should always have plugin_name
|
|
||||||
continue
|
|
||||||
|
|
||||||
mro = inspect.getmro(klass)
|
|
||||||
super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
|
|
||||||
klass.PLUGIN_NAME = plugin_name
|
|
||||||
|
|
||||||
if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
|
|
||||||
continue
|
|
||||||
|
|
||||||
while getattr(super_class, '__wrapped__', None):
|
|
||||||
super_class = super_class.__wrapped__
|
|
||||||
setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
|
|
||||||
plugin_overrides.get()[super_class].append(klass)
|
|
||||||
|
|
||||||
# Add the classes into the global plugin lookup for that type
|
# Add the classes into the global plugin lookup for that type
|
||||||
plugin_config.plugin_destination.set(regular_classes)
|
plugin_config.plugin_destination.set(regular_classes)
|
||||||
# We want to prepend to the main lookup for that type
|
# We want to prepend to the main lookup for that type
|
||||||
|
|
Loading…
Add table
Reference in a new issue