mirror of
https://github.com/yt-dlp/yt-dlp
synced 2025-01-01 06:21:09 +01:00
Get plugin overrides working
This commit is contained in:
parent
9f1f2c5410
commit
4266658602
2 changed files with 91 additions and 55 deletions
|
@ -6,7 +6,6 @@ import hashlib
|
|||
import http.client
|
||||
import http.cookiejar
|
||||
import http.cookies
|
||||
import inspect
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
|
@ -31,7 +30,6 @@ from ..compat import (
|
|||
from ..cookies import LenientSimpleCookie
|
||||
from ..downloader.f4m import get_base_url, remove_encrypted_media
|
||||
from ..downloader.hls import HlsFD
|
||||
from ..globals import plugin_overrides
|
||||
from ..networking import HEADRequest, Request
|
||||
from ..networking.exceptions import (
|
||||
HTTPError,
|
||||
|
@ -3934,17 +3932,8 @@ class InfoExtractor:
|
|||
|
||||
@classmethod
|
||||
def __init_subclass__(cls, *, plugin_name=None, **kwargs):
|
||||
if plugin_name:
|
||||
mro = inspect.getmro(cls)
|
||||
super_class = cls.__wrapped__ = mro[mro.index(cls) + 1]
|
||||
cls.PLUGIN_NAME, cls.ie_key = plugin_name, super_class.ie_key
|
||||
cls.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
|
||||
while getattr(super_class, '__wrapped__', None):
|
||||
super_class = super_class.__wrapped__
|
||||
setattr(sys.modules[super_class.__module__], super_class.__name__, cls)
|
||||
plugin_overrides.get()[super_class].append(cls)
|
||||
# if plugin_name is not None:
|
||||
# cls._plugin_name = plugin_name
|
||||
if plugin_name is not None:
|
||||
cls._plugin_name = plugin_name
|
||||
return super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import contextlib
|
||||
import dataclasses
|
||||
import enum
|
||||
import importlib
|
||||
import importlib.abc
|
||||
|
@ -10,7 +11,9 @@ import os
|
|||
import pkgutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
import zipimport
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
|
@ -19,7 +22,7 @@ from .globals import (
|
|||
plugin_dirs,
|
||||
plugin_ies,
|
||||
plugin_pps,
|
||||
postprocessors,
|
||||
postprocessors, plugin_overrides,
|
||||
)
|
||||
|
||||
from .compat import functools # isort: split
|
||||
|
@ -42,12 +45,6 @@ class PluginType(enum.Enum):
|
|||
EXTRACTORS = ('extractor', 'IE')
|
||||
|
||||
|
||||
_plugin_type_lookup = {
|
||||
PluginType.POSTPROCESSORS: (postprocessors, plugin_pps),
|
||||
PluginType.EXTRACTORS: (extractors, plugin_ies),
|
||||
}
|
||||
|
||||
|
||||
class PluginLoader(importlib.abc.Loader):
|
||||
"""Dummy loader for virtual namespace packages"""
|
||||
|
||||
|
@ -165,22 +162,74 @@ def iter_modules(subpackage):
|
|||
yield from pkgutil.iter_modules(path=pkg.__path__, prefix=f'{fullname}.')
|
||||
|
||||
|
||||
def load_module(module, module_name, suffix):
|
||||
result = inspect.getmembers(module, lambda obj: (
|
||||
def get_regular_modules(module, module_name, suffix):
|
||||
# Find standard public plugin classes (not overrides)
|
||||
return inspect.getmembers(module, lambda obj: (
|
||||
inspect.isclass(obj)
|
||||
and obj.__name__.endswith(suffix)
|
||||
and obj.__module__.startswith(module_name)
|
||||
and not obj.__name__.startswith('_')
|
||||
and obj.__name__ in getattr(module, '__all__', [obj.__name__])))
|
||||
return result
|
||||
and obj.__name__ in getattr(module, '__all__', [obj.__name__])
|
||||
and getattr(obj, '_plugin_name', None) is None
|
||||
))
|
||||
|
||||
|
||||
load_module = get_regular_modules
|
||||
|
||||
|
||||
def get_override_modules(module, module_name, suffix):
|
||||
# Find override plugin classes
|
||||
def predicate(obj):
|
||||
if not inspect.isclass(obj):
|
||||
return False
|
||||
mro = inspect.getmro(obj)
|
||||
return (
|
||||
obj.__module__.startswith(module_name)
|
||||
and getattr(obj, '_plugin_name', None) is not None
|
||||
and mro[mro.index(obj) + 1].__name__.endswith(suffix)
|
||||
)
|
||||
return inspect.getmembers(module, predicate)
|
||||
|
||||
|
||||
def configure_ie_override_class(klass, super_class, plugin_name):
|
||||
ie_key = getattr(super_class, 'ie_key', None)
|
||||
if not ie_key:
|
||||
warnings.warn(f'Override plugin {klass} is not an extractor')
|
||||
return False
|
||||
klass.ie_key = ie_key
|
||||
klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _PluginTypeConfig:
|
||||
destination: ContextVar
|
||||
plugin_destination: ContextVar
|
||||
# Function to configure the override class. Return False to skip the class
|
||||
# Takes (klass, super_class, plugin_name) as arguments
|
||||
configure_override_func: callable = lambda *args: None
|
||||
|
||||
|
||||
_plugin_type_lookup = {
|
||||
PluginType.POSTPROCESSORS: _PluginTypeConfig(
|
||||
destination=postprocessors,
|
||||
plugin_destination=plugin_pps,
|
||||
configure_override_func=None,
|
||||
),
|
||||
PluginType.EXTRACTORS: _PluginTypeConfig(
|
||||
destination=extractors,
|
||||
plugin_destination=plugin_ies,
|
||||
configure_override_func=configure_ie_override_class,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def load_plugins(plugin_type: PluginType):
|
||||
destination, plugin_destination = _plugin_type_lookup[plugin_type]
|
||||
plugin_config = _plugin_type_lookup[plugin_type]
|
||||
name, suffix = plugin_type.value
|
||||
classes = {}
|
||||
regular_classes = {}
|
||||
override_classes = {}
|
||||
if os.environ.get('YTDLP_NO_PLUGINS'):
|
||||
return classes
|
||||
return regular_classes
|
||||
|
||||
for finder, module_name, _ in iter_modules(name):
|
||||
if any(x.startswith('_') for x in module_name.split('.')):
|
||||
|
@ -201,7 +250,8 @@ def load_plugins(plugin_type: PluginType):
|
|||
f'Error while importing module {module_name!r}\n{traceback.format_exc(limit=-1)}',
|
||||
)
|
||||
continue
|
||||
classes.update(load_module(module, module_name, suffix))
|
||||
regular_classes.update(get_regular_modules(module, module_name, suffix))
|
||||
override_classes.update(get_override_modules(module, module_name, suffix))
|
||||
|
||||
# Compat: old plugin system using __init__.py
|
||||
# Note: plugins imported this way do not show up in directories()
|
||||
|
@ -215,41 +265,38 @@ def load_plugins(plugin_type: PluginType):
|
|||
plugins = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = plugins
|
||||
spec.loader.exec_module(plugins)
|
||||
classes.update(load_module(plugins, spec.name, suffix))
|
||||
regular_classes.update(get_regular_modules(plugins, spec.name, suffix))
|
||||
|
||||
# regular_plugins = {}
|
||||
# __init_subclass__ was removed so we manually add overrides
|
||||
# for name, klass in classes.items():
|
||||
# plugin_name = getattr(klass, '_plugin_name', None)
|
||||
# if not plugin_name:
|
||||
# regular_plugins[name] = klass
|
||||
# continue
|
||||
# Configure override classes
|
||||
for name, klass in override_classes.items():
|
||||
plugin_name = getattr(klass, '_plugin_name', None)
|
||||
if not plugin_name:
|
||||
# these should always have plugin_name
|
||||
continue
|
||||
|
||||
# FIXME: Most likely something wrong here
|
||||
# This does not work as plugin overrides are not available here. They are not imported in plugin_ies.
|
||||
mro = inspect.getmro(klass)
|
||||
super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
|
||||
klass.PLUGIN_NAME = plugin_name
|
||||
|
||||
# mro = inspect.getmro(klass)
|
||||
# super_class = klass.__wrapped__ = mro[mro.index(klass) + 1]
|
||||
# klass.PLUGIN_NAME, klass.ie_key = plugin_name, super_class.ie_key
|
||||
# klass.IE_NAME = f'{super_class.IE_NAME}+{plugin_name}'
|
||||
# while getattr(super_class, '__wrapped__', None):
|
||||
# super_class = super_class.__wrapped__
|
||||
# setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
|
||||
# plugin_overrides.get()[super_class].append(klass)
|
||||
if plugin_config.configure_override_func(klass, super_class, plugin_name) is False:
|
||||
continue
|
||||
|
||||
# Add the classes into the global plugin lookup
|
||||
plugin_destination.set(classes)
|
||||
# # We want to prepend to the main lookup
|
||||
destination.set(merge_dicts(destination.get(), classes))
|
||||
while getattr(super_class, '__wrapped__', None):
|
||||
super_class = super_class.__wrapped__
|
||||
setattr(sys.modules[super_class.__module__], super_class.__name__, klass)
|
||||
plugin_overrides.get()[super_class].append(klass)
|
||||
|
||||
return classes
|
||||
# Add the classes into the global plugin lookup for that type
|
||||
plugin_config.plugin_destination.set(regular_classes)
|
||||
# We want to prepend to the main lookup for that type
|
||||
plugin_config.destination.set(merge_dicts(plugin_config.destination.get(), regular_classes))
|
||||
|
||||
return regular_classes
|
||||
|
||||
|
||||
def load_all_plugin_types():
|
||||
# for plugin_type in PluginType:
|
||||
# load_plugins(plugin_type)
|
||||
load_plugins(PluginType.EXTRACTORS)
|
||||
|
||||
for plugin_type in PluginType:
|
||||
load_plugins(plugin_type)
|
||||
|
||||
sys.meta_path.insert(0, PluginFinder(f'{PACKAGE_NAME}.extractor', f'{PACKAGE_NAME}.postprocessor'))
|
||||
|
||||
|
|
Loading…
Reference in a new issue