From 42666586028c50c1283e181f49363364cf50ee18 Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Sat, 19 Oct 2024 14:58:33 +1300 Subject: [PATCH] Get plugin overrides working --- yt_dlp/extractor/common.py | 15 +---- yt_dlp/plugins.py | 131 +++++++++++++++++++++++++------------ 2 files changed, 91 insertions(+), 55 deletions(-) diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py index c435abdcb..68909d9d6 100644 --- a/yt_dlp/extractor/common.py +++ b/yt_dlp/extractor/common.py @@ -6,7 +6,6 @@ import hashlib import http.client import http.cookiejar import http.cookies -import inspect import itertools import json import math @@ -31,7 +30,6 @@ from ..compat import ( from ..cookies import LenientSimpleCookie from ..downloader.f4m import get_base_url, remove_encrypted_media from ..downloader.hls import HlsFD -from ..globals import plugin_overrides from ..networking import HEADRequest, Request from ..networking.exceptions import ( HTTPError, @@ -3934,17 +3932,8 @@ class InfoExtractor: @classmethod def __init_subclass__(cls, *, plugin_name=None, **kwargs): - if plugin_name: - mro = inspect.getmro(cls) - super_class = cls.__wrapped__ = mro[mro.index(cls) + 1] - cls.PLUGIN_NAME, cls.ie_key = plugin_name, super_class.ie_key - cls.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}' - while getattr(super_class, '__wrapped__', None): - super_class = super_class.__wrapped__ - setattr(sys.modules[super_class.__module__], super_class.__name__, cls) - plugin_overrides.get()[super_class].append(cls) - # if plugin_name is not None: - # cls._plugin_name = plugin_name + if plugin_name is not None: + cls._plugin_name = plugin_name return super().__init_subclass__(**kwargs) diff --git a/yt_dlp/plugins.py b/yt_dlp/plugins.py index 67cf4a869..c0888ee2f 100644 --- a/yt_dlp/plugins.py +++ b/yt_dlp/plugins.py @@ -1,4 +1,5 @@ import contextlib +import dataclasses import enum import importlib import importlib.abc @@ -10,7 +11,9 @@ import os import pkgutil import sys import traceback +import warnings import zipimport +from contextvars import ContextVar from pathlib import Path from zipfile import ZipFile @@ -19,7 +22,7 @@ from .globals import ( plugin_dirs, plugin_ies, plugin_pps, - postprocessors, + postprocessors, plugin_overrides, ) from .compat import functools # isort: split @@ -42,12 +45,6 @@ class PluginType(enum.Enum): EXTRACTORS = ('extractor', 'IE') -_plugin_type_lookup = { - PluginType.POSTPROCESSORS: (postprocessors, plugin_pps), - PluginType.EXTRACTORS: (extractors, plugin_ies), -} - - class PluginLoader(importlib.abc.Loader): """Dummy loader for virtual namespace packages""" @@ -165,22 +162,74 @@ def iter_modules(subpackage): yield from pkgutil.iter_modules(path=pkg.__path__, prefix=f'{fullname}.') -def load_module(module, module_name, suffix): - result = inspect.getmembers(module, lambda obj: ( +def get_regular_modules(module, module_name, suffix): + # Find standard public plugin classes (not overrides) + return inspect.getmembers(module, lambda obj: ( inspect.isclass(obj) and obj.__name__.endswith(suffix) and obj.__module__.startswith(module_name) and not obj.__name__.startswith('_') - and obj.__name__ in getattr(module, '__all__', [obj.__name__]))) - return result + and obj.__name__ in getattr(module, '__all__', [obj.__name__]) + and getattr(obj, '_plugin_name', None) is None + )) + + +load_module = get_regular_modules + + +def get_override_modules(module, module_name, suffix): + # Find override plugin classes + def predicate(obj): + if not inspect.isclass(obj): + return False + mro = inspect.getmro(obj) + return ( + obj.__module__.startswith(module_name) + and getattr(obj, '_plugin_name', None) is not None + and mro[mro.index(obj) + 1].__name__.endswith(suffix) + ) + return inspect.getmembers(module, predicate) + + +def configure_ie_override_class(klass, super_class, plugin_name): + ie_key = getattr(super_class, 'ie_key', None) + if not ie_key: + warnings.warn(f'Override plugin {klass} is not an extractor') + return False + klass.ie_key = ie_key + klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}' + + +@dataclasses.dataclass +class _PluginTypeConfig: + destination: ContextVar + plugin_destination: ContextVar + # Function to configure the override class. Return False to skip the class + # Takes (klass, super_class, plugin_name) as arguments + configure_override_func: callable = lambda *args: None + + +_plugin_type_lookup = { + PluginType.POSTPROCESSORS: _PluginTypeConfig( + destination=postprocessors, + plugin_destination=plugin_pps, + configure_override_func=None, + ), + PluginType.EXTRACTORS: _PluginTypeConfig( + destination=extractors, + plugin_destination=plugin_ies, + configure_override_func=configure_ie_override_class, + ), +} def load_plugins(plugin_type: PluginType): - destination, plugin_destination = _plugin_type_lookup[plugin_type] + plugin_config = _plugin_type_lookup[plugin_type] name, suffix = plugin_type.value - classes = {} + regular_classes = {} + override_classes = {} if os.environ.get('YTDLP_NO_PLUGINS'): - return classes + return regular_classes for finder, module_name, _ in iter_modules(name): if any(x.startswith('_') for x in module_name.split('.')): @@ -201,7 +250,8 @@ def load_plugins(plugin_type: PluginType): f'Error while importing module {module_name!r}\n{traceback.format_exc(limit=-1)}', ) continue - classes.update(load_module(module, module_name, suffix)) + regular_classes.update(get_regular_modules(module, module_name, suffix)) + override_classes.update(get_override_modules(module, module_name, suffix)) # Compat: old plugin system using __init__.py # Note: plugins imported this way do not show up in directories() @@ -215,41 +265,38 @@ def load_plugins(plugin_type: PluginType): plugins = importlib.util.module_from_spec(spec) sys.modules[spec.name] = plugins spec.loader.exec_module(plugins) - classes.update(load_module(plugins, spec.name, suffix)) + regular_classes.update(get_regular_modules(plugins, spec.name, suffix)) - # regular_plugins = {} - # __init_subclass__ was removed so we manually add overrides - # for name, klass in classes.items(): - # plugin_name = getattr(klass, '_plugin_name', None) - # if not plugin_name: - # regular_plugins[name] = klass - # continue + # Configure override classes + for name, klass in override_classes.items(): + plugin_name = getattr(klass, '_plugin_name', None) + if not plugin_name: + # these should always have plugin_name + continue - # FIXME: Most likely something wrong here - # This does not work as plugin overrides are not available here. They are not imported in plugin_ies. + mro = inspect.getmro(klass) + super_class = klass.__wrapped__ = mro[mro.index(klass) + 1] + klass.PLUGIN_NAME = plugin_name - # mro = inspect.getmro(klass) - # super_class = klass.__wrapped__ = mro[mro.index(klass) + 1] - # klass.PLUGIN_NAME, klass.ie_key = plugin_name, super_class.ie_key - # klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}' - # while getattr(super_class, '__wrapped__', None): - # super_class = super_class.__wrapped__ - # setattr(sys.modules[super_class.__module__], super_class.__name__, klass) - # plugin_overrides.get()[super_class].append(klass) + if plugin_config.configure_override_func(klass, super_class, plugin_name) is False: + continue - # Add the classes into the global plugin lookup - plugin_destination.set(classes) - # # We want to prepend to the main lookup - destination.set(merge_dicts(destination.get(), classes)) + while getattr(super_class, '__wrapped__', None): + super_class = super_class.__wrapped__ + setattr(sys.modules[super_class.__module__], super_class.__name__, klass) + plugin_overrides.get()[super_class].append(klass) - return classes + # Add the classes into the global plugin lookup for that type + plugin_config.plugin_destination.set(regular_classes) + # We want to prepend to the main lookup for that type + plugin_config.destination.set(merge_dicts(plugin_config.destination.get(), regular_classes)) + + return regular_classes def load_all_plugin_types(): - # for plugin_type in PluginType: - # load_plugins(plugin_type) - load_plugins(PluginType.EXTRACTORS) - + for plugin_type in PluginType: + load_plugins(plugin_type) sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.extractor', f'{PACKAGE_NAME}.postprocessor'))