suggestion to merge logic into core

Author: coletdjnz
Date: 2024-11-30 11:01:15 +13:00
parent 20832d0cdc
commit 8e45684816
8 changed files with 80 additions and 22 deletions
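The diff below promotes the `keep_header_casing` handling that previously lived in the Requests handler into the core `RequestHandler`, exposing it to every handler as a request extension. A rough usage sketch based on the new tests (the URL and header name are illustrative, not part of the commit):

    from yt_dlp.networking import Request

    # Opt in to sending header names exactly as given ('X-test-heaDer')
    # rather than the normalized form ('X-Test-Header').
    req = Request(
        'http://example.com/headers',  # placeholder URL
        headers={'X-test-heaDer': 'test'},
        extensions={'keep_header_casing': True},
    )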

test/test_networking.py

@@ -720,6 +720,15 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
                 rh, Request(
                     f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close()
 
+    @pytest.mark.skip_handler('Urllib', 'urllib handler does not support keep_header_casing')
+    def test_keep_header_casing(self, handler):
+        with handler() as rh:
+            res = validate_and_send(
+                rh, Request(
+                    f'http://127.0.0.1:{self.http_port}/headers', headers={'X-test-heaDer': 'test'}, extensions={'keep_header_casing': True})).read().decode()
+
+            assert 'X-test-heaDer: test' in res
+
 
 @pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
 class TestClientCertificate:
@@ -1289,6 +1298,7 @@ class TestRequestHandlerValidation:
             ({'legacy_ssl': False}, False),
             ({'legacy_ssl': True}, False),
             ({'legacy_ssl': 'notabool'}, AssertionError),
+            ({'keep_header_casing': True}, UnsupportedRequest),
         ]),
         ('Requests', 'http', [
             ({'cookiejar': 'notacookiejar'}, AssertionError),
@@ -1299,6 +1309,9 @@ class TestRequestHandlerValidation:
             ({'legacy_ssl': False}, False),
             ({'legacy_ssl': True}, False),
             ({'legacy_ssl': 'notabool'}, AssertionError),
+            ({'keep_header_casing': False}, False),
+            ({'keep_header_casing': True}, False),
+            ({'keep_header_casing': 'notabool'}, AssertionError),
         ]),
         ('CurlCFFI', 'http', [
             ({'cookiejar': 'notacookiejar'}, AssertionError),

test/test_websockets.py

@@ -44,7 +44,7 @@ def websocket_handler(websocket):
                 return websocket.send('2')
         elif isinstance(message, str):
             if message == 'headers':
-                return websocket.send(json.dumps(dict(websocket.request.headers)))
+                return websocket.send(json.dumps(dict(websocket.request.headers.raw_items())))
             elif message == 'path':
                 return websocket.send(websocket.request.path)
             elif message == 'source_address':
@@ -266,18 +266,18 @@ class TestWebsSocketRequestHandlerConformance:
         with handler(cookiejar=cookiejar) as rh:
             ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
-            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
             ws.close()
 
         with handler() as rh:
             ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
-            assert 'cookie' not in json.loads(ws.recv())
+            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
             ws.close()
 
             ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
             ws.send('headers')
-            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
             ws.close()
 
     @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
@@ -287,7 +287,7 @@ class TestWebsSocketRequestHandlerConformance:
             ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
             ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
             ws.send('headers')
-            assert 'cookie' not in json.loads(ws.recv())
+            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
             ws.close()
 
     @pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
@@ -298,12 +298,12 @@ class TestWebsSocketRequestHandlerConformance:
             ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
             ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
-            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+            assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
             ws.close()
             cookiejar.clear_session_cookies()
             ws = ws_validate_and_send(rh, Request(self.ws_base_url))
             ws.send('headers')
-            assert 'cookie' not in json.loads(ws.recv())
+            assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
             ws.close()
 
     def test_source_address(self, handler):
@@ -341,6 +341,14 @@ class TestWebsSocketRequestHandlerConformance:
             assert headers['test3'] == 'test3'
             ws.close()
 
+    def test_keep_header_casing(self, handler):
+        with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh:
+            ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True}))
+            ws.send('headers')
+            headers = json.loads(ws.recv())
+            assert 'x-TeSt1' in headers
+            assert 'x-TeSt2' in headers
+
     @pytest.mark.parametrize('client_cert', (
         {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
         {

yt_dlp/networking/_curlcffi.py

@@ -149,6 +149,7 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
         # CurlCFFIRH ignores legacy ssl options currently.
         # Impersonation generally uses a looser SSL configuration than urllib/requests.
         extensions.pop('legacy_ssl', None)
+        extensions.pop('keep_header_casing', None)
 
     def _send(self, request: Request) -> Response:
         target = self._get_request_target(request)
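Worth noting: `CurlCFFIRH` pops `keep_header_casing` without reading it here because, per the extension docs in common.py below, popping an extension in `_check_extensions` is what declares it supported; the actual casing logic for impersonate handlers runs in `ImpersonateRequestHandler._get_impersonate_headers` (last file in this diff).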

yt_dlp/networking/_requests.py

@@ -313,13 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
         session.trust_env = False  # no need, we already load proxies from env
         return session
 
+    def _prepare_headers(self, _, headers):
+        add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+
     def _send(self, request):
-        headers = self._merge_headers(request.headers)
-        add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
-
-        if request.extensions.get('keep_header_casing'):
-            headers = headers.sensitive()
+        headers = self._get_headers(request)
 
         max_redirects_exceeded = False
 
         session = self._get_instance(

yt_dlp/networking/_urllib.py

@@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
         opener.addheaders = []
         return opener
 
-    def _send(self, request):
-        headers = self._merge_headers(request.headers)
+    def _prepare_headers(self, _, headers):
         add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
 
+    def _send(self, request):
+        headers = self._get_headers(request)
+
         urllib_req = urllib.request.Request(
             url=request.url,
             data=request.data,
-            headers=dict(headers),
+            headers=headers,
             method=request.method,
         )

yt_dlp/networking/_websockets.py

@@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler):
         extensions.pop('timeout', None)
         extensions.pop('cookiejar', None)
         extensions.pop('legacy_ssl', None)
+        extensions.pop('keep_header_casing', None)
 
     def close(self):
         # Remove the logging handler that contains a reference to our logger
@@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler):
         for name, handler in self.__logging_handlers.items():
             logging.getLogger(name).removeHandler(handler)
 
-    def _send(self, request):
-        timeout = self._calculate_timeout(request)
-        headers = self._merge_headers(request.headers)
+    def _prepare_headers(self, request, headers):
         if 'cookie' not in headers:
             cookiejar = self._get_cookiejar(request)
             cookie_header = cookiejar.get_cookie_header(request.url)
             if cookie_header:
                 headers['cookie'] = cookie_header
 
+    def _send(self, request):
+        timeout = self._calculate_timeout(request)
+        headers = self._get_headers(request)
         wsuri = parse_uri(request.url)
         create_conn_kwargs = {
             'source_address': (self.source_address, 0) if self.source_address else None,

yt_dlp/networking/common.py

@@ -206,6 +206,7 @@ class RequestHandler(abc.ABC):
     - `cookiejar`: Cookiejar to use for this request.
     - `timeout`: socket timeout to use for this request.
     - `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
+    - `keep_header_casing`: Keep the casing of headers when sending the request.
     To enable these, add extensions.pop('<extension>', None) to _check_extensions
 
     Apart from the url protocol, proxies dict may contain the following keys:
@@ -259,6 +260,23 @@ class RequestHandler(abc.ABC):
     def _merge_headers(self, request_headers):
         return HTTPHeaderDict(self.headers, request_headers)
 
+    def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None:  # noqa: B027
+        """Additional operations to prepare headers before building. To be extended by subclasses.
+        @param request: Request object
+        @param headers: Merged headers to prepare
+        """
+
+    def _get_headers(self, request: Request) -> dict[str, str]:
+        """
+        Get headers for external use.
+        Subclasses may define a _prepare_headers method to modify headers after merge but before building.
+        """
+        headers = self._merge_headers(request.headers)
+        self._prepare_headers(request, headers)
+        if request.extensions.get('keep_header_casing'):
+            return headers.sensitive()
+        return dict(headers)
+
     def _calculate_timeout(self, request):
         return float(request.extensions.get('timeout') or self.timeout)
@@ -317,6 +335,7 @@ class RequestHandler(abc.ABC):
         assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
         assert isinstance(extensions.get('timeout'), (float, int, NoneType))
         assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
+        assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
 
     def _validate(self, request):
         self._check_url_scheme(request)
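The new `_get_headers` is a small template method: merge the instance and request headers, let the subclass hook adjust them in place, then build either a case-normalized dict or, when `keep_header_casing` is set, a case-preserving one via `HTTPHeaderDict.sensitive()`. A minimal sketch of how a subclass plugs into it, using a hypothetical handler and header name (not part of the commit):

    from yt_dlp.networking.common import RequestHandler


    class ExampleRH(RequestHandler):  # hypothetical handler, for illustration
        _SUPPORTED_URL_SCHEMES = ('http',)

        def _check_extensions(self, extensions):
            super()._check_extensions(extensions)
            # Popping the extension declares support for it.
            extensions.pop('keep_header_casing', None)

        def _prepare_headers(self, request, headers):
            # Called by _get_headers() after merging, before the final
            # dict is built; mutate `headers` in place.
            if 'x-example' not in headers:  # hypothetical header
                headers['x-example'] = 'demo'

        def _send(self, request):
            # Plain dict; case-preserving when the request sets
            # extensions={'keep_header_casing': True}.
            headers = self._get_headers(request)
            ...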

yt_dlp/networking/impersonate.py

@@ -5,11 +5,11 @@ from abc import ABC
 from dataclasses import dataclass
 from typing import Any
 
-from .common import RequestHandler, register_preference
+from .common import RequestHandler, register_preference, Request
 from .exceptions import UnsupportedRequest
 from ..compat.types import NoneType
 from ..utils import classproperty, join_nonempty
-from ..utils.networking import std_headers
+from ..utils.networking import std_headers, HTTPHeaderDict
 
 
 @dataclass(order=True, frozen=True)
@@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
         """Get the requested target for the request"""
         return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
 
-    def _get_impersonate_headers(self, request):
+    def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None:  # noqa: B027
+        """Additional operations to prepare headers before building. To be extended by subclasses.
+        @param request: Request object
+        @param headers: Merged headers to prepare
+        """
+
+    def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
+        """
+        Get headers for external impersonation use.
+        Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building.
+        """
         headers = self._merge_headers(request.headers)
         if self._get_request_target(request) is not None:
             # remove all headers present in std_headers
@@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
             for k, v in std_headers.items():
                 if headers.get(k) == v:
                     headers.pop(k)
-        return headers
+
+        self._prepare_impersonate_headers(request, headers)
+        if request.extensions.get('keep_header_casing'):
+            return headers.sensitive()
+        return dict(headers)
 
 
 @register_preference(ImpersonateRequestHandler)
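The support-declaration pattern is unchanged by this refactor: a handler that pops `keep_header_casing` in `_check_extensions` (as `CurlCFFIRH` and `WebsocketsRH` do above) accepts the extension, while one that leaves it in place, such as `UrllibRH`, keeps rejecting such requests with `UnsupportedRequest` — exactly what the new validation tests pin down.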