From 27bda89511666223b5a61d9a25f65619584b0282 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Tue, 26 Nov 2024 20:05:48 +0100 Subject: [PATCH] Implement `keep_header_casing` extension --- yt_dlp/networking/_requests.py | 4 + yt_dlp/utils/networking.py | 148 +++++++++++++++++++++++++++++---- 2 files changed, 136 insertions(+), 16 deletions(-) diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 7de95ab3b..7af7d475d 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -296,6 +296,7 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): extensions.pop('cookiejar', None) extensions.pop('timeout', None) extensions.pop('legacy_ssl', None) + extensions.pop('keep_header_casing', None) def _create_instance(self, cookiejar, legacy_ssl_support=None): session = RequestsSession() @@ -324,6 +325,9 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): legacy_ssl_support=request.extensions.get('legacy_ssl'), ) + if request.extensions.get('keep_header_casing'): + headers = headers.sensitive() + try: requests_res = session.request( method=request.method, diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py index 933b164be..542abace8 100644 --- a/yt_dlp/utils/networking.py +++ b/yt_dlp/utils/networking.py @@ -1,9 +1,16 @@ +from __future__ import annotations + import collections +import collections.abc import random +import typing import urllib.parse import urllib.request -from ._utils import remove_start +if typing.TYPE_CHECKING: + T = typing.TypeVar('T') + +from ._utils import NO_DEFAULT, remove_start def random_user_agent(): @@ -51,32 +58,141 @@ def random_user_agent(): return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS) -class HTTPHeaderDict(collections.UserDict, dict): +class HTTPHeaderDict(dict): """ Store and access keys case-insensitively. The constructor can take multiple dicts, in which keys in the latter are prioritised. + + Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`. """ + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: + obj = dict.__new__(cls, *args, **kwargs) + obj.__sensitive_map = {} + return obj - def __init__(self, *args, **kwargs): + def __init__(self, /, *args, **kwargs): super().__init__() - for dct in args: - if dct is not None: - self.update(dct) - self.update(kwargs) + self.__sensitive_map = {} - def __setitem__(self, key, value): - if isinstance(value, bytes): - value = value.decode('latin-1') - super().__setitem__(key.title(), str(value).strip()) + for dct in filter(None, args): + self.update(dct) + if kwargs: + self.update(kwargs) - def __getitem__(self, key): + def sensitive(self, /) -> dict[str, str]: + return { + self.__sensitive_map[key]: value + for key, value in self.items() + } + + def __contains__(self, key: str, /) -> bool: + return super().__contains__(key.title() if isinstance(key, str) else key) + + def __delitem__(self, key: str, /) -> None: + key = key.title() + del self.__sensitive_map[key] + super().__delitem__(key) + + def __getitem__(self, key, /) -> str: return super().__getitem__(key.title()) - def __delitem__(self, key): - super().__delitem__(key.title()) + def __ior__(self, other, /): + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + self.update(other) + return + return NotImplemented - def __contains__(self, key): - return super().__contains__(key.title() if isinstance(key, str) else key) + def __or__(self, other, /) -> typing.Self: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + return type(self)(self.sensitive(), other) + return NotImplemented + + def __ror__(self, other, /) -> typing.Self: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, dict): + return type(self)(other, self.sensitive()) + return NotImplemented + + def __setitem__(self, key: str, value, /) -> None: + if isinstance(value, bytes): + value = value.decode('latin-1') + key_title = key.title() + self.__sensitive_map[key_title] = key + super().__setitem__(key_title, str(value).strip()) + + def clear(self, /) -> None: + self.__sensitive_map.clear() + super().clear() + + def copy(self, /) -> typing.Self: + return type(self)(self.sensitive()) + + @typing.overload + def get(self, key: str, /) -> str | None: ... + + @typing.overload + def get(self, key: str, /, default: T) -> str | T: ... + + def get(self, key, /, default=NO_DEFAULT): + key = key.title() + if default is NO_DEFAULT: + return super().get(key) + return super().get(key, default) + + @typing.overload + def pop(self, key: str, /) -> str: ... + + @typing.overload + def pop(self, key: str, /, default: T) -> str | T: ... + + def pop(self, key, /, default=NO_DEFAULT): + key = key.title() + if default is NO_DEFAULT: + self.__sensitive_map.pop(key) + return super().pop(key) + self.__sensitive_map.pop(key, default) + return super().pop(key, default) + + def popitem(self) -> tuple[str, str]: + self.__sensitive_map.popitem() + return super().popitem() + + @typing.overload + def setdefault(self, key: str, /) -> str: ... + + @typing.overload + def setdefault(self, key: str, /, default) -> str: ... + + def setdefault(self, key, /, default=None) -> str: + key = key.title() + if key in self.__sensitive_map: + return super().__getitem__(key) + + self[key] = default or '' + return self[key] + + def update(self, other, /, **kwargs) -> None: + if isinstance(other, type(self)): + other = other.sensitive() + if isinstance(other, collections.abc.Mapping): + for key, value in other.items(): + self[key] = value + + elif hasattr(other, 'keys'): + for key in other.keys(): # noqa: SIM118 + self[key] = other[key] + + else: + for key, value in other: + self[key] = value + + for key, value in kwargs.items(): + self[key] = value std_headers = HTTPHeaderDict({