[networking] Add legacy_ssl request extension (#10448)

Supported by Urllib, Requests and Websockets request handlers. Ignored by CurlCFFI.

Also added couple cookie-related tests.

Authored by: coletdjnz
This commit is contained in:
coletdjnz 2024-07-14 11:22:43 +12:00 committed by GitHub
parent 8b8b442cb0
commit 150ecc45d9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 162 additions and 9 deletions

View file

@ -265,6 +265,11 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(payload) self.wfile.write(payload)
self.finish() self.finish()
elif self.path == '/get_cookie':
self.send_response(200)
self.send_header('Set-Cookie', 'test=ytdlp; path=/')
self.end_headers()
self.finish()
else: else:
self._status(404) self._status(404)
@ -338,6 +343,52 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers')) validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
assert not issubclass(exc_info.type, CertificateVerifyError) assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.skip_handler('CurlCFFI', 'legacy_ssl ignored by CurlCFFI')
def test_legacy_ssl_extension(self, handler):
# HTTPS server with old ciphers
# XXX: is there a better way to test this than to create a new server?
https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
sslctx.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
sslctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
https_httpd.socket = sslctx.wrap_socket(https_httpd.socket, server_side=True)
https_port = http_server_port(https_httpd)
https_server_thread = threading.Thread(target=https_httpd.serve_forever)
https_server_thread.daemon = True
https_server_thread.start()
with handler(verify=False) as rh:
res = validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers', extensions={'legacy_ssl': True}))
assert res.status == 200
res.close()
# Ensure only applies to request extension
with pytest.raises(SSLError):
validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
@pytest.mark.skip_handler('CurlCFFI', 'legacy_ssl ignored by CurlCFFI')
def test_legacy_ssl_support(self, handler):
# HTTPS server with old ciphers
# XXX: is there a better way to test this than to create a new server?
https_httpd = http.server.ThreadingHTTPServer(
('127.0.0.1', 0), HTTPTestRequestHandler)
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
sslctx.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
sslctx.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
https_httpd.socket = sslctx.wrap_socket(https_httpd.socket, server_side=True)
https_port = http_server_port(https_httpd)
https_server_thread = threading.Thread(target=https_httpd.serve_forever)
https_server_thread.daemon = True
https_server_thread.start()
with handler(verify=False, legacy_ssl_support=True) as rh:
res = validate_and_send(rh, Request(f'https://127.0.0.1:{https_port}/headers'))
assert res.status == 200
res.close()
def test_percent_encode(self, handler): def test_percent_encode(self, handler):
with handler() as rh: with handler() as rh:
# Unicode characters should be encoded with uppercase percent-encoding # Unicode characters should be encoded with uppercase percent-encoding
@ -490,6 +541,24 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read() rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': cookiejar})).read()
assert b'cookie: test=ytdlp' in data.lower() assert b'cookie: test=ytdlp' in data.lower()
def test_cookie_sync_only_cookiejar(self, handler):
# Ensure that cookies are ONLY being handled by the cookiejar
with handler() as rh:
validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
data = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers', extensions={'cookiejar': YoutubeDLCookieJar()})).read()
assert b'cookie: test=ytdlp' not in data.lower()
def test_cookie_sync_delete_cookie(self, handler):
# Ensure that cookies are ONLY being handled by the cookiejar
cookiejar = YoutubeDLCookieJar()
with handler(cookiejar=cookiejar) as rh:
validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/get_cookie'))
data = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
assert b'cookie: test=ytdlp' in data.lower()
cookiejar.clear_session_cookies()
data = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}/headers')).read()
assert b'cookie: test=ytdlp' not in data.lower()
def test_headers(self, handler): def test_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
@ -1199,6 +1268,9 @@ class TestRequestHandlerValidation:
({'timeout': 1}, False), ({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError), ({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest), ({'unsupported': 'value'}, UnsupportedRequest),
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
]), ]),
('Requests', 'http', [ ('Requests', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': 'notacookiejar'}, AssertionError),
@ -1206,6 +1278,9 @@ class TestRequestHandlerValidation:
({'timeout': 1}, False), ({'timeout': 1}, False),
({'timeout': 'notatimeout'}, AssertionError), ({'timeout': 'notatimeout'}, AssertionError),
({'unsupported': 'value'}, UnsupportedRequest), ({'unsupported': 'value'}, UnsupportedRequest),
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
]), ]),
('CurlCFFI', 'http', [ ('CurlCFFI', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': 'notacookiejar'}, AssertionError),
@ -1219,6 +1294,9 @@ class TestRequestHandlerValidation:
({'impersonate': ImpersonateTarget(None, None, None, None)}, False), ({'impersonate': ImpersonateTarget(None, None, None, None)}, False),
({'impersonate': ImpersonateTarget()}, False), ({'impersonate': ImpersonateTarget()}, False),
({'impersonate': 'chrome'}, AssertionError), ({'impersonate': 'chrome'}, AssertionError),
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
]), ]),
(NoCheckRH, 'http', [ (NoCheckRH, 'http', [
({'cookiejar': 'notacookiejar'}, False), ({'cookiejar': 'notacookiejar'}, False),
@ -1227,6 +1305,9 @@ class TestRequestHandlerValidation:
('Websockets', 'ws', [ ('Websockets', 'ws', [
({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': YoutubeDLCookieJar()}, False),
({'timeout': 2}, False), ({'timeout': 2}, False),
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
]), ]),
] ]

View file

@ -61,6 +61,10 @@ def process_request(self, request):
return websockets.http11.Response( return websockets.http11.Response(
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'') status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
return self.protocol.reject(status.value, status.phrase) return self.protocol.reject(status.value, status.phrase)
elif request.path.startswith('/get_cookie'):
response = self.protocol.accept(request)
response.headers['Set-Cookie'] = 'test=ytdlp'
return response
return self.protocol.accept(request) return self.protocol.accept(request)
@ -102,6 +106,15 @@ def create_mtls_wss_websocket_server():
return create_websocket_server(ssl_context=sslctx) return create_websocket_server(ssl_context=sslctx)
def create_legacy_wss_websocket_server():
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.maximum_version = ssl.TLSVersion.TLSv1_2
sslctx.set_ciphers('SHA1:AESCCM:aDSS:eNULL:aNULL')
sslctx.load_cert_chain(certfn, None)
return create_websocket_server(ssl_context=sslctx)
def ws_validate_and_send(rh, req): def ws_validate_and_send(rh, req):
rh.validate(req) rh.validate(req)
max_tries = 3 max_tries = 3
@ -132,6 +145,9 @@ class TestWebsSocketRequestHandlerConformance:
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server() cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}' cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
cls.legacy_wss_thread, cls.legacy_wss_port = create_legacy_wss_websocket_server()
cls.legacy_wss_host = f'wss://127.0.0.1:{cls.legacy_wss_port}'
def test_basic_websockets(self, handler): def test_basic_websockets(self, handler):
with handler() as rh: with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
@ -166,6 +182,22 @@ class TestWebsSocketRequestHandlerConformance:
ws_validate_and_send(rh, Request(self.bad_wss_host)) ws_validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError) assert not issubclass(exc_info.type, CertificateVerifyError)
def test_legacy_ssl_extension(self, handler):
with handler(verify=False) as rh:
ws = ws_validate_and_send(rh, Request(self.legacy_wss_host, extensions={'legacy_ssl': True}))
assert ws.status == 101
ws.close()
# Ensure only applies to request extension
with pytest.raises(SSLError):
ws_validate_and_send(rh, Request(self.legacy_wss_host))
def test_legacy_ssl_support(self, handler):
with handler(verify=False, legacy_ssl_support=True) as rh:
ws = ws_validate_and_send(rh, Request(self.legacy_wss_host))
assert ws.status == 101
ws.close()
@pytest.mark.parametrize('path,expected', [ @pytest.mark.parametrize('path,expected', [
# Unicode characters should be encoded with uppercase percent-encoding # Unicode characters should be encoded with uppercase percent-encoding
('/中文', '/%E4%B8%AD%E6%96%87'), ('/中文', '/%E4%B8%AD%E6%96%87'),
@ -248,6 +280,32 @@ class TestWebsSocketRequestHandlerConformance:
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close() ws.close()
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
def test_cookie_sync_only_cookiejar(self, handler):
# Ensure that cookies are ONLY being handled by the cookiejar
with handler() as rh:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
ws.send('headers')
assert 'cookie' not in json.loads(ws.recv())
ws.close()
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
def test_cookie_sync_delete_cookie(self, handler):
# Ensure that cookies are ONLY being handled by the cookiejar
cookiejar = YoutubeDLCookieJar()
with handler(verbose=True, cookiejar=cookiejar) as rh:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close()
cookiejar.clear_session_cookies()
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
assert 'cookie' not in json.loads(ws.recv())
ws.close()
def test_source_address(self, handler): def test_source_address(self, handler):
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address) verify_address_availability(source_address)

View file

@ -146,6 +146,9 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
extensions.pop('impersonate', None) extensions.pop('impersonate', None)
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
extensions.pop('timeout', None) extensions.pop('timeout', None)
# CurlCFFIRH ignores legacy ssl options currently.
# Impersonation generally uses a looser SSL configuration than urllib/requests.
extensions.pop('legacy_ssl', None)
def send(self, request: Request) -> Response: def send(self, request: Request) -> Response:
target = self._get_request_target(request) target = self._get_request_target(request)

View file

@ -295,11 +295,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
super()._check_extensions(extensions) super()._check_extensions(extensions)
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
extensions.pop('timeout', None) extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
def _create_instance(self, cookiejar): def _create_instance(self, cookiejar, legacy_ssl_support=None):
session = RequestsSession() session = RequestsSession()
http_adapter = RequestsHTTPAdapter( http_adapter = RequestsHTTPAdapter(
ssl_context=self._make_sslcontext(), ssl_context=self._make_sslcontext(legacy_ssl_support=legacy_ssl_support),
source_address=self.source_address, source_address=self.source_address,
max_retries=urllib3.util.retry.Retry(False), max_retries=urllib3.util.retry.Retry(False),
) )
@ -318,7 +319,10 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
max_redirects_exceeded = False max_redirects_exceeded = False
session = self._get_instance(cookiejar=self._get_cookiejar(request)) session = self._get_instance(
cookiejar=self._get_cookiejar(request),
legacy_ssl_support=request.extensions.get('legacy_ssl'),
)
try: try:
requests_res = session.request( requests_res = session.request(

View file

@ -348,14 +348,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
super()._check_extensions(extensions) super()._check_extensions(extensions)
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
extensions.pop('timeout', None) extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
def _create_instance(self, proxies, cookiejar): def _create_instance(self, proxies, cookiejar, legacy_ssl_support=None):
opener = urllib.request.OpenerDirector() opener = urllib.request.OpenerDirector()
handlers = [ handlers = [
ProxyHandler(proxies), ProxyHandler(proxies),
HTTPHandler( HTTPHandler(
debuglevel=int(bool(self.verbose)), debuglevel=int(bool(self.verbose)),
context=self._make_sslcontext(), context=self._make_sslcontext(legacy_ssl_support=legacy_ssl_support),
source_address=self.source_address), source_address=self.source_address),
HTTPCookieProcessor(cookiejar), HTTPCookieProcessor(cookiejar),
DataHandler(), DataHandler(),
@ -391,6 +392,7 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
opener = self._get_instance( opener = self._get_instance(
proxies=self._get_proxies(request), proxies=self._get_proxies(request),
cookiejar=self._get_cookiejar(request), cookiejar=self._get_cookiejar(request),
legacy_ssl_support=request.extensions.get('legacy_ssl'),
) )
try: try:
res = opener.open(urllib_req, timeout=self._calculate_timeout(request)) res = opener.open(urllib_req, timeout=self._calculate_timeout(request))

View file

@ -118,6 +118,7 @@ class WebsocketsRH(WebSocketRequestHandler):
super()._check_extensions(extensions) super()._check_extensions(extensions)
extensions.pop('timeout', None) extensions.pop('timeout', None)
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
extensions.pop('legacy_ssl', None)
def close(self): def close(self):
# Remove the logging handler that contains a reference to our logger # Remove the logging handler that contains a reference to our logger
@ -154,13 +155,14 @@ class WebsocketsRH(WebSocketRequestHandler):
address=(wsuri.host, wsuri.port), address=(wsuri.host, wsuri.port),
**create_conn_kwargs, **create_conn_kwargs,
) )
ssl_ctx = self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl'))
conn = websockets.sync.client.connect( conn = websockets.sync.client.connect(
sock=sock, sock=sock,
uri=request.url, uri=request.url,
additional_headers=headers, additional_headers=headers,
open_timeout=timeout, open_timeout=timeout,
user_agent_header=None, user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None, ssl_context=ssl_ctx if wsuri.secure else None,
close_timeout=0, # not ideal, but prevents yt-dlp hanging close_timeout=0, # not ideal, but prevents yt-dlp hanging
) )
return WebsocketsResponseAdapter(conn, url=request.url) return WebsocketsResponseAdapter(conn, url=request.url)

View file

@ -205,6 +205,7 @@ class RequestHandler(abc.ABC):
The following extensions are defined for RequestHandler: The following extensions are defined for RequestHandler:
- `cookiejar`: Cookiejar to use for this request. - `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request. - `timeout`: socket timeout to use for this request.
- `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
To enable these, add extensions.pop('<extension>', None) to _check_extensions To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys: Apart from the url protocol, proxies dict may contain the following keys:
@ -247,10 +248,10 @@ class RequestHandler(abc.ABC):
self.legacy_ssl_support = legacy_ssl_support self.legacy_ssl_support = legacy_ssl_support
super().__init__() super().__init__()
def _make_sslcontext(self): def _make_sslcontext(self, legacy_ssl_support=None):
return make_ssl_context( return make_ssl_context(
verify=self.verify, verify=self.verify,
legacy_support=self.legacy_ssl_support, legacy_support=legacy_ssl_support if legacy_ssl_support is not None else self.legacy_ssl_support,
use_certifi=not self.prefer_system_certs, use_certifi=not self.prefer_system_certs,
**self._client_cert, **self._client_cert,
) )
@ -262,7 +263,8 @@ class RequestHandler(abc.ABC):
return float(request.extensions.get('timeout') or self.timeout) return float(request.extensions.get('timeout') or self.timeout)
def _get_cookiejar(self, request): def _get_cookiejar(self, request):
return request.extensions.get('cookiejar') or self.cookiejar cookiejar = request.extensions.get('cookiejar')
return self.cookiejar if cookiejar is None else cookiejar
def _get_proxies(self, request): def _get_proxies(self, request):
return (request.proxies or self.proxies).copy() return (request.proxies or self.proxies).copy()
@ -314,6 +316,7 @@ class RequestHandler(abc.ABC):
"""Check extensions for unsupported extensions. Subclasses should extend this.""" """Check extensions for unsupported extensions. Subclasses should extend this."""
assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType)) assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType)) assert isinstance(extensions.get('timeout'), (float, int, NoneType))
assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
def _validate(self, request): def _validate(self, request):
self._check_url_scheme(request) self._check_url_scheme(request)