mirror of
https://github.com/yt-dlp/yt-dlp
synced 2024-12-26 21:59:08 +01:00
suggestion to merge logic into core
This commit is contained in:
parent
20832d0cdc
commit
8e45684816
8 changed files with 80 additions and 22 deletions
|
@ -720,6 +720,15 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
|
|||
rh, Request(
|
||||
f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close()
|
||||
|
||||
@pytest.mark.skip_handler('Urllib', 'urllib handler does not support keep_header_casing')
|
||||
def test_keep_header_casing(self, handler):
|
||||
with handler() as rh:
|
||||
res = validate_and_send(
|
||||
rh, Request(
|
||||
f'http://127.0.0.1:{self.http_port}/headers', headers={'X-test-heaDer': 'test'}, extensions={'keep_header_casing': True})).read().decode()
|
||||
|
||||
assert 'X-test-heaDer: test' in res
|
||||
|
||||
|
||||
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
|
||||
class TestClientCertificate:
|
||||
|
@ -1289,6 +1298,7 @@ class TestRequestHandlerValidation:
|
|||
({'legacy_ssl': False}, False),
|
||||
({'legacy_ssl': True}, False),
|
||||
({'legacy_ssl': 'notabool'}, AssertionError),
|
||||
({'keep_header_casing': True}, UnsupportedRequest),
|
||||
]),
|
||||
('Requests', 'http', [
|
||||
({'cookiejar': 'notacookiejar'}, AssertionError),
|
||||
|
@ -1299,6 +1309,9 @@ class TestRequestHandlerValidation:
|
|||
({'legacy_ssl': False}, False),
|
||||
({'legacy_ssl': True}, False),
|
||||
({'legacy_ssl': 'notabool'}, AssertionError),
|
||||
({'keep_header_casing': False}, False),
|
||||
({'keep_header_casing': True}, False),
|
||||
({'keep_header_casing': 'notabool'}, AssertionError),
|
||||
]),
|
||||
('CurlCFFI', 'http', [
|
||||
({'cookiejar': 'notacookiejar'}, AssertionError),
|
||||
|
|
|
@ -44,7 +44,7 @@ def websocket_handler(websocket):
|
|||
return websocket.send('2')
|
||||
elif isinstance(message, str):
|
||||
if message == 'headers':
|
||||
return websocket.send(json.dumps(dict(websocket.request.headers)))
|
||||
return websocket.send(json.dumps(dict(websocket.request.headers.raw_items())))
|
||||
elif message == 'path':
|
||||
return websocket.send(websocket.request.path)
|
||||
elif message == 'source_address':
|
||||
|
@ -266,18 +266,18 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
with handler(cookiejar=cookiejar) as rh:
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||
assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
|
||||
ws.close()
|
||||
|
||||
with handler() as rh:
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert 'cookie' not in json.loads(ws.recv())
|
||||
assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
|
||||
ws.close()
|
||||
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
||||
ws.send('headers')
|
||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||
assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
|
||||
ws.close()
|
||||
|
||||
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
|
||||
|
@ -287,7 +287,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
|
||||
ws.send('headers')
|
||||
assert 'cookie' not in json.loads(ws.recv())
|
||||
assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
|
||||
ws.close()
|
||||
|
||||
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
|
||||
|
@ -298,12 +298,12 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||
assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
|
||||
ws.close()
|
||||
cookiejar.clear_session_cookies()
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert 'cookie' not in json.loads(ws.recv())
|
||||
assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
|
||||
ws.close()
|
||||
|
||||
def test_source_address(self, handler):
|
||||
|
@ -341,6 +341,14 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
assert headers['test3'] == 'test3'
|
||||
ws.close()
|
||||
|
||||
def test_keep_header_casing(self, handler):
|
||||
with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh:
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True}))
|
||||
ws.send('headers')
|
||||
headers = json.loads(ws.recv())
|
||||
assert 'x-TeSt1' in headers
|
||||
assert 'x-TeSt2' in headers
|
||||
|
||||
@pytest.mark.parametrize('client_cert', (
|
||||
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
|
||||
{
|
||||
|
|
|
@ -149,6 +149,7 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
|
|||
# CurlCFFIRH ignores legacy ssl options currently.
|
||||
# Impersonation generally uses a looser SSL configuration than urllib/requests.
|
||||
extensions.pop('legacy_ssl', None)
|
||||
extensions.pop('keep_header_casing', None)
|
||||
|
||||
def send(self, request: Request) -> Response:
|
||||
target = self._get_request_target(request)
|
||||
|
|
|
@ -313,13 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
|
|||
session.trust_env = False # no need, we already load proxies from env
|
||||
return session
|
||||
|
||||
def _prepare_headers(self, _, headers):
|
||||
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
|
||||
|
||||
def _send(self, request):
|
||||
|
||||
headers = self._merge_headers(request.headers)
|
||||
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
|
||||
if request.extensions.get('keep_header_casing'):
|
||||
headers = headers.sensitive()
|
||||
|
||||
headers = self._get_headers(request)
|
||||
max_redirects_exceeded = False
|
||||
|
||||
session = self._get_instance(
|
||||
|
|
|
@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
|
|||
opener.addheaders = []
|
||||
return opener
|
||||
|
||||
def _send(self, request):
|
||||
headers = self._merge_headers(request.headers)
|
||||
def _prepare_headers(self, _, headers):
|
||||
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
|
||||
|
||||
def _send(self, request):
|
||||
headers = self._get_headers(request)
|
||||
urllib_req = urllib.request.Request(
|
||||
url=request.url,
|
||||
data=request.data,
|
||||
headers=dict(headers),
|
||||
headers=headers,
|
||||
method=request.method,
|
||||
)
|
||||
|
||||
|
|
|
@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler):
|
|||
extensions.pop('timeout', None)
|
||||
extensions.pop('cookiejar', None)
|
||||
extensions.pop('legacy_ssl', None)
|
||||
extensions.pop('keep_header_casing', None)
|
||||
|
||||
def close(self):
|
||||
# Remove the logging handler that contains a reference to our logger
|
||||
|
@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler):
|
|||
for name, handler in self.__logging_handlers.items():
|
||||
logging.getLogger(name).removeHandler(handler)
|
||||
|
||||
def _send(self, request):
|
||||
timeout = self._calculate_timeout(request)
|
||||
headers = self._merge_headers(request.headers)
|
||||
def _prepare_headers(self, request, headers):
|
||||
if 'cookie' not in headers:
|
||||
cookiejar = self._get_cookiejar(request)
|
||||
cookie_header = cookiejar.get_cookie_header(request.url)
|
||||
if cookie_header:
|
||||
headers['cookie'] = cookie_header
|
||||
|
||||
def _send(self, request):
|
||||
timeout = self._calculate_timeout(request)
|
||||
headers = self._get_headers(request)
|
||||
wsuri = parse_uri(request.url)
|
||||
create_conn_kwargs = {
|
||||
'source_address': (self.source_address, 0) if self.source_address else None,
|
||||
|
|
|
@ -206,6 +206,7 @@ class RequestHandler(abc.ABC):
|
|||
- `cookiejar`: Cookiejar to use for this request.
|
||||
- `timeout`: socket timeout to use for this request.
|
||||
- `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
|
||||
- `keep_header_casing`: Keep the casing of headers when sending the request.
|
||||
To enable these, add extensions.pop('<extension>', None) to _check_extensions
|
||||
|
||||
Apart from the url protocol, proxies dict may contain the following keys:
|
||||
|
@ -259,6 +260,23 @@ class RequestHandler(abc.ABC):
|
|||
def _merge_headers(self, request_headers):
|
||||
return HTTPHeaderDict(self.headers, request_headers)
|
||||
|
||||
def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
|
||||
"""Additional operations to prepare headers before building. To be extended by subclasses.
|
||||
@param request: Request object
|
||||
@param headers: Merged headers to prepare
|
||||
"""
|
||||
|
||||
def _get_headers(self, request: Request) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for external use.
|
||||
Subclasses may define a _prepare_headers method to modify headers after merge but before building.
|
||||
"""
|
||||
headers = self._merge_headers(request.headers)
|
||||
self._prepare_headers(request, headers)
|
||||
if request.extensions.get('keep_header_casing'):
|
||||
return headers.sensitive()
|
||||
return dict(headers)
|
||||
|
||||
def _calculate_timeout(self, request):
|
||||
return float(request.extensions.get('timeout') or self.timeout)
|
||||
|
||||
|
@ -317,6 +335,7 @@ class RequestHandler(abc.ABC):
|
|||
assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
|
||||
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
|
||||
assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
|
||||
assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
|
||||
|
||||
def _validate(self, request):
|
||||
self._check_url_scheme(request)
|
||||
|
|
|
@ -5,11 +5,11 @@ from abc import ABC
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .common import RequestHandler, register_preference
|
||||
from .common import RequestHandler, register_preference, Request
|
||||
from .exceptions import UnsupportedRequest
|
||||
from ..compat.types import NoneType
|
||||
from ..utils import classproperty, join_nonempty
|
||||
from ..utils.networking import std_headers
|
||||
from ..utils.networking import std_headers, HTTPHeaderDict
|
||||
|
||||
|
||||
@dataclass(order=True, frozen=True)
|
||||
|
@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
|
|||
"""Get the requested target for the request"""
|
||||
return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
|
||||
|
||||
def _get_impersonate_headers(self, request):
|
||||
def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
|
||||
"""Additional operations to prepare headers before building. To be extended by subclasses.
|
||||
@param request: Request object
|
||||
@param headers: Merged headers to prepare
|
||||
"""
|
||||
|
||||
def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
|
||||
"""
|
||||
Get headers for external impersonation use.
|
||||
Subclasses may define a _prepare_headers method to modify headers after merge but before building.
|
||||
"""
|
||||
headers = self._merge_headers(request.headers)
|
||||
if self._get_request_target(request) is not None:
|
||||
# remove all headers present in std_headers
|
||||
|
@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
|
|||
for k, v in std_headers.items():
|
||||
if headers.get(k) == v:
|
||||
headers.pop(k)
|
||||
return headers
|
||||
|
||||
self._prepare_impersonate_headers(request, headers)
|
||||
if request.extensions.get('keep_header_casing'):
|
||||
return headers.sensitive()
|
||||
return dict(headers)
|
||||
|
||||
|
||||
@register_preference(ImpersonateRequestHandler)
|
||||
|
|
Loading…
Reference in a new issue