From 8e45684816a0743e282a8267dcd30b0fdbc60f7b Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Sat, 30 Nov 2024 11:01:15 +1300 Subject: [PATCH] suggestion to merge logic into core --- test/test_networking.py | 13 +++++++++++++ test/test_websockets.py | 22 +++++++++++++++------- yt_dlp/networking/_curlcffi.py | 1 + yt_dlp/networking/_requests.py | 9 ++++----- yt_dlp/networking/_urllib.py | 8 +++++--- yt_dlp/networking/_websockets.py | 8 +++++--- yt_dlp/networking/common.py | 19 +++++++++++++++++++ yt_dlp/networking/impersonate.py | 22 ++++++++++++++++++---- 8 files changed, 80 insertions(+), 22 deletions(-) diff --git a/test/test_networking.py b/test/test_networking.py index d96624af18..f4bc4f08c3 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -720,6 +720,15 @@ class TestHTTPRequestHandler(TestRequestHandlerBase): rh, Request( f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close() + @pytest.mark.skip_handler('Urllib', 'urllib handler does not support keep_header_casing') + def test_keep_header_casing(self, handler): + with handler() as rh: + res = validate_and_send( + rh, Request( + f'http://127.0.0.1:{self.http_port}/headers', headers={'X-test-heaDer': 'test'}, extensions={'keep_header_casing': True})).read().decode() + + assert 'X-test-heaDer: test' in res + @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True) class TestClientCertificate: @@ -1289,6 +1298,7 @@ class TestRequestHandlerValidation: ({'legacy_ssl': False}, False), ({'legacy_ssl': True}, False), ({'legacy_ssl': 'notabool'}, AssertionError), + ({'keep_header_casing': True}, UnsupportedRequest), ]), ('Requests', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), @@ -1299,6 +1309,9 @@ class TestRequestHandlerValidation: ({'legacy_ssl': False}, False), ({'legacy_ssl': True}, False), ({'legacy_ssl': 'notabool'}, AssertionError), + ({'keep_header_casing': False}, False), + ({'keep_header_casing': True}, 
False), + ({'keep_header_casing': 'notabool'}, AssertionError), ]), ('CurlCFFI', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), diff --git a/test/test_websockets.py b/test/test_websockets.py index 06112cc0b8..dead5fe5c5 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -44,7 +44,7 @@ def websocket_handler(websocket): return websocket.send('2') elif isinstance(message, str): if message == 'headers': - return websocket.send(json.dumps(dict(websocket.request.headers))) + return websocket.send(json.dumps(dict(websocket.request.headers.raw_items()))) elif message == 'path': return websocket.send(websocket.request.path) elif message == 'source_address': @@ -266,18 +266,18 @@ class TestWebsSocketRequestHandlerConformance: with handler(cookiejar=cookiejar) as rh: ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() with handler() as rh: ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets') @@ -287,7 +287,7 @@ class TestWebsSocketRequestHandlerConformance: ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()})) ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()})) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() 
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets') @@ -298,12 +298,12 @@ class TestWebsSocketRequestHandlerConformance: ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie')) ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp' ws.close() cookiejar.clear_session_cookies() ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') - assert 'cookie' not in json.loads(ws.recv()) + assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv())) ws.close() def test_source_address(self, handler): @@ -341,6 +341,14 @@ class TestWebsSocketRequestHandlerConformance: assert headers['test3'] == 'test3' ws.close() + def test_keep_header_casing(self, handler): + with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh: + ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True})) + ws.send('headers') + headers = json.loads(ws.recv()) + assert 'x-TeSt1' in headers + assert 'x-TeSt2' in headers + @pytest.mark.parametrize('client_cert', ( {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}, { diff --git a/yt_dlp/networking/_curlcffi.py b/yt_dlp/networking/_curlcffi.py index 0643348e7e..c81bd60c83 100644 --- a/yt_dlp/networking/_curlcffi.py +++ b/yt_dlp/networking/_curlcffi.py @@ -149,6 +149,7 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin): # CurlCFFIRH ignores legacy ssl options currently. # Impersonation generally uses a looser SSL configuration than urllib/requests. 
extensions.pop('legacy_ssl', None) + extensions.pop('keep_header_casing', None) def send(self, request: Request) -> Response: target = self._get_request_target(request) diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 6068e18655..23775845d6 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -313,13 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): session.trust_env = False # no need, we already load proxies from env return session + def _prepare_headers(self, _, headers): + add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + def _send(self, request): - headers = self._merge_headers(request.headers) - add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) - if request.extensions.get('keep_header_casing'): - headers = headers.sensitive() - + headers = self._get_headers(request) max_redirects_exceeded = False session = self._get_instance( diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py index 510bb2a691..a188b35f57 100644 --- a/yt_dlp/networking/_urllib.py +++ b/yt_dlp/networking/_urllib.py @@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin): opener.addheaders = [] return opener - def _send(self, request): - headers = self._merge_headers(request.headers) + def _prepare_headers(self, _, headers): add_accept_encoding_header(headers, SUPPORTED_ENCODINGS) + + def _send(self, request): + headers = self._get_headers(request) urllib_req = urllib.request.Request( url=request.url, data=request.data, - headers=dict(headers), + headers=headers, method=request.method, ) diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index ec55567dae..7e5ab46004 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler): extensions.pop('timeout', None) extensions.pop('cookiejar', None) extensions.pop('legacy_ssl', None) + 
extensions.pop('keep_header_casing', None) def close(self): # Remove the logging handler that contains a reference to our logger @@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler): for name, handler in self.__logging_handlers.items(): logging.getLogger(name).removeHandler(handler) - def _send(self, request): - timeout = self._calculate_timeout(request) - headers = self._merge_headers(request.headers) + def _prepare_headers(self, request, headers): if 'cookie' not in headers: cookiejar = self._get_cookiejar(request) cookie_header = cookiejar.get_cookie_header(request.url) if cookie_header: headers['cookie'] = cookie_header + def _send(self, request): + timeout = self._calculate_timeout(request) + headers = self._get_headers(request) wsuri = parse_uri(request.url) create_conn_kwargs = { 'source_address': (self.source_address, 0) if self.source_address else None, diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py index e8951c7e7d..ddceaa9a97 100644 --- a/yt_dlp/networking/common.py +++ b/yt_dlp/networking/common.py @@ -206,6 +206,7 @@ class RequestHandler(abc.ABC): - `cookiejar`: Cookiejar to use for this request. - `timeout`: socket timeout to use for this request. - `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support. + - `keep_header_casing`: Keep the casing of headers when sending the request. To enable these, add extensions.pop('', None) to _check_extensions Apart from the url protocol, proxies dict may contain the following keys: @@ -259,6 +260,23 @@ class RequestHandler(abc.ABC): def _merge_headers(self, request_headers): return HTTPHeaderDict(self.headers, request_headers) + def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027 + """Additional operations to prepare headers before building. To be extended by subclasses. 
+ @param request: Request object + @param headers: Merged headers to prepare + """ + + def _get_headers(self, request: Request) -> dict[str, str]: + """ + Get headers for external use. + Subclasses may define a _prepare_headers method to modify headers after merge but before building. + """ + headers = self._merge_headers(request.headers) + self._prepare_headers(request, headers) + if request.extensions.get('keep_header_casing'): + return headers.sensitive() + return dict(headers) + def _calculate_timeout(self, request): return float(request.extensions.get('timeout') or self.timeout) @@ -317,6 +335,7 @@ class RequestHandler(abc.ABC): assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType)) assert isinstance(extensions.get('timeout'), (float, int, NoneType)) assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType)) + assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType)) def _validate(self, request): self._check_url_scheme(request) diff --git a/yt_dlp/networking/impersonate.py b/yt_dlp/networking/impersonate.py index 0626b3b491..aad87eb4c0 100644 --- a/yt_dlp/networking/impersonate.py +++ b/yt_dlp/networking/impersonate.py @@ -5,11 +5,11 @@ from abc import ABC from dataclasses import dataclass from typing import Any -from .common import RequestHandler, register_preference +from .common import RequestHandler, register_preference, Request from .exceptions import UnsupportedRequest from ..compat.types import NoneType from ..utils import classproperty, join_nonempty -from ..utils.networking import std_headers +from ..utils.networking import std_headers, HTTPHeaderDict @dataclass(order=True, frozen=True) @@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC): """Get the requested target for the request""" return self._resolve_target(request.extensions.get('impersonate') or self.impersonate) - def _get_impersonate_headers(self, request): + def _prepare_impersonate_headers(self, request: Request, headers: 
HTTPHeaderDict) -> None: # noqa: B027 + """Additional operations to prepare headers before building. To be extended by subclasses. + @param request: Request object + @param headers: Merged headers to prepare + """ + + def _get_impersonate_headers(self, request: Request) -> dict[str, str]: + """ + Get headers for external impersonation use. + Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building. + """ headers = self._merge_headers(request.headers) if self._get_request_target(request) is not None: # remove all headers present in std_headers @@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC): for k, v in std_headers.items(): if headers.get(k) == v: headers.pop(k) + + self._prepare_impersonate_headers(request, headers) + if request.extensions.get('keep_header_casing'): + return headers.sensitive() + return dict(headers) @register_preference(ImpersonateRequestHandler)