mirror of
https://github.com/yt-dlp/yt-dlp
synced 2025-01-15 03:41:33 +01:00
53b4d44f55
Fixes https://github.com/yt-dlp/yt-dlp/issues/9659 Authored by: coletdjnz
439 lines
18 KiB
Python
439 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Allow direct execution
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from test.helper import verify_address_availability
|
|
from yt_dlp.networking.common import Features, DEFAULT_TIMEOUT
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import http.client
|
|
import http.cookiejar
|
|
import http.server
|
|
import json
|
|
import random
|
|
import ssl
|
|
import threading
|
|
|
|
from yt_dlp import socks, traverse_obj
|
|
from yt_dlp.cookies import YoutubeDLCookieJar
|
|
from yt_dlp.dependencies import websockets
|
|
from yt_dlp.networking import Request
|
|
from yt_dlp.networking.exceptions import (
|
|
CertificateVerifyError,
|
|
HTTPError,
|
|
ProxyError,
|
|
RequestError,
|
|
SSLError,
|
|
TransportError,
|
|
)
|
|
from yt_dlp.utils.networking import HTTPHeaderDict
|
|
|
|
# Absolute path of the directory containing this test module; the test
# certificate and the mTLS certificate directory are resolved relative to it.
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
def websocket_handler(websocket):
    """Answer the first message received on *websocket*, then return.

    A few command messages get a special reply; anything else is echoed
    back unchanged:

    * ``b'bytes'``         -> ``'2'``
    * ``'str'``            -> ``'1'``
    * ``'headers'``        -> the handshake request headers as a JSON object
    * ``'path'``           -> the handshake request path
    * ``'source_address'`` -> the first element of ``websocket.remote_address``

    :param websocket: the server-side connection; iterating it yields messages.
    :return: whatever ``websocket.send`` returns, or ``None`` if the
        connection yields no message at all.
    """
    for incoming in websocket:
        if isinstance(incoming, bytes):
            reply = '2' if incoming == b'bytes' else incoming
        elif not isinstance(incoming, str):
            reply = incoming
        elif incoming == 'headers':
            reply = json.dumps(dict(websocket.request.headers))
        elif incoming == 'path':
            reply = websocket.request.path
        elif incoming == 'source_address':
            reply = websocket.remote_address[0]
        elif incoming == 'str':
            reply = '1'
        else:
            reply = incoming
        # Only the first message is ever answered.
        return websocket.send(reply)
|
|
|
|
|
|
def process_request(self, request):
    """Handshake hook that lets a test ask the server for an arbitrary HTTP status.

    A request path of the form ``/gen_<code>`` makes the server answer the
    opening handshake with HTTP status ``<code>`` instead of upgrading the
    connection. Any other path is accepted as a normal websocket handshake.

    :param self: the server connection performing the handshake (provides ``protocol``).
    :param request: the handshake request (provides ``path``).
    :return: the handshake response to send.
    :raises ValueError: if ``<code>`` is not an integer or is not a member of
        ``http.HTTPStatus``.
    """
    if not request.path.startswith('/gen_'):
        return self.protocol.accept(request)

    # Everything after the 5-character '/gen_' prefix is the status code.
    status = http.HTTPStatus(int(request.path[5:]))

    # Every 3xx status is a redirect and is given a Location header so that
    # clients can be checked for *not* following it. The range check used to
    # be ``300 <= status.value <= 300``, which matched status 300 only.
    if 300 <= status.value < 400:
        return websockets.http11.Response(
            status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')

    return self.protocol.reject(status.value, status.phrase)
|
|
|
|
|
|
def create_websocket_server(**ws_kwargs):
    """Start a websocket test server on 127.0.0.1 and serve it from a daemon thread.

    Connections are handled by ``websocket_handler`` and handshakes pass
    through ``process_request``. Port ``0`` is requested and the actual
    bound port is read back from the listening socket.

    :param ws_kwargs: extra keyword arguments forwarded to
        ``websockets.sync.server.serve`` (e.g. ``ssl_context``).
    :return: ``(server_thread, port)``.
    """
    import websockets.sync.server

    server = websockets.sync.server.serve(
        websocket_handler,
        '127.0.0.1',
        0,
        process_request=process_request,
        open_timeout=2,
        **ws_kwargs,
    )
    bound_port = server.socket.getsockname()[1]

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()
    return server_thread, bound_port
|
|
|
|
|
|
def create_ws_websocket_server():
    """Start a plain (non-TLS) websocket test server.

    :return: ``(server_thread, port)`` from ``create_websocket_server``.
    """
    server_thread, port = create_websocket_server()
    return server_thread, port
|
|
|
|
|
|
def create_wss_websocket_server():
    """Start a TLS websocket test server that presents ``testcert.pem``.

    :return: ``(server_thread, port)`` from ``create_websocket_server``.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
|
|
|
|
|
|
# Directory holding the CA certificate and the client certificates/keys used
# by the mutual-TLS server and the mTLS test cases.
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
|
|
|
|
|
|
def create_mtls_wss_websocket_server():
    """Start a TLS websocket test server that requires a client certificate.

    The server presents ``testcert.pem`` and verifies client certificates
    against ``ca.crt`` from ``MTLS_CERT_DIR``.

    :return: ``(server_thread, port)`` from ``create_websocket_server``.
    """
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # Reject clients that do not present a certificate signed by our test CA.
    tls_context.verify_mode = ssl.CERT_REQUIRED
    tls_context.load_verify_locations(cafile=os.path.join(MTLS_CERT_DIR, 'ca.crt'))
    tls_context.load_cert_chain(os.path.join(TEST_DIR, 'testcert.pem'), None)
    return create_websocket_server(ssl_context=tls_context)
|
|
|
|
|
|
def ws_validate_and_send(rh, req):
    """Validate *req* with *rh*, then send it, retrying flaky handshakes.

    The request is validated once. Sending is attempted up to three times,
    but only a ``TransportError`` whose message contains
    ``'connection closed during handshake'`` triggers a retry.

    :param rh: request handler providing ``validate`` and ``send``.
    :param req: the request to send.
    :return: whatever ``rh.send(req)`` returns.
    :raises TransportError: on any other transport failure, or when the
        third attempt also fails.
    """
    rh.validate(req)

    retries_left = 3
    while True:
        retries_left -= 1
        try:
            return rh.send(req)
        except TransportError as err:
            # websockets server sometimes hangs on new connections
            if retries_left > 0 and 'connection closed during handshake' in str(err):
                continue
            raise
|
|
|
|
|
|
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsSocketRequestHandlerConformance:
    """Behavioural tests for websocket request handlers, run against local test servers."""

    @classmethod
    def setup_class(cls):
        """Start the plain, TLS, broken-TLS and mutual-TLS servers shared by all tests."""
        cls.ws_thread, cls.ws_port = create_ws_websocket_server()
        cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'

        cls.wss_thread, cls.wss_port = create_wss_websocket_server()
        cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'

        # TLS server whose context has no certificate loaded
        cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(
            ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
        cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'

        cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
        cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'

    def test_basic_websockets(self, handler):
        """A connection upgrades with status 101 and echoes a text message."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            assert 'upgrade' in ws.headers
            assert ws.status == 101
            ws.send('foo')
            assert ws.recv() == 'foo'
            ws.close()

    # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
    @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
    def test_send_types(self, handler, msg, opcode):
        """The test server replies ``1`` to the ``'str'`` message and ``2`` to ``b'bytes'``."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send(msg)
            assert int(ws.recv()) == opcode
            ws.close()

    def test_verify_cert(self, handler):
        """The test certificate is rejected by default and accepted with ``verify=False``."""
        with handler() as rh:
            with pytest.raises(CertificateVerifyError):
                ws_validate_and_send(rh, Request(self.wss_base_url))

        with handler(verify=False) as rh:
            ws = ws_validate_and_send(rh, Request(self.wss_base_url))
            assert ws.status == 101
            ws.close()

    def test_ssl_error(self, handler):
        """A TLS handshake failure raises ``SSLError``, not ``CertificateVerifyError``."""
        with handler(verify=False) as rh:
            with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
                ws_validate_and_send(rh, Request(self.bad_wss_host))
            assert not issubclass(exc_info.type, CertificateVerifyError)

    @pytest.mark.parametrize('path,expected', [
        # Unicode characters should be encoded with uppercase percent-encoding
        ('/中文', '/%E4%B8%AD%E6%96%87'),
        # don't normalize existing percent encodings
        ('/%c7%9f', '/%c7%9f'),
    ])
    def test_percent_encode(self, handler, path, expected):
        """The request path reaches the server percent-encoded as *expected*."""
        with handler() as rh:
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
            ws.send('path')
            assert ws.recv() == expected
            assert ws.status == 101
            ws.close()

    def test_remove_dot_segments(self, handler):
        """Dot segments are removed from the path before the request is sent."""
        with handler() as rh:
            # This isn't a comprehensive test,
            # but it should be enough to check whether the handler is removing dot segments
            ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
            assert ws.status == 101
            ws.send('path')
            assert ws.recv() == '/test'
            ws.close()

    # We are restricted to known HTTP status codes in http.HTTPStatus
    # Redirects are not supported for websockets
    @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
    def test_raise_http_error(self, handler, status):
        """A non-101 handshake response raises ``HTTPError`` carrying that status."""
        with handler() as rh:
            with pytest.raises(HTTPError) as exc_info:
                ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
            assert exc_info.value.status == status

    @pytest.mark.parametrize('params,extensions', [
        ({'timeout': sys.float_info.min}, {}),
        ({}, {'timeout': sys.float_info.min}),
    ])
    def test_read_timeout(self, handler, params, extensions):
        """A minimal timeout, set on the handler or per request, raises ``TransportError``."""
        with handler(**params) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))

    def test_connect_timeout(self, handler):
        """A failing connection attempt gives up well before ``DEFAULT_TIMEOUT``."""
        # nothing should be listening on this port
        connect_timeout_url = 'ws://10.255.255.255'

        # The elapsed-time assertions sit outside the ``pytest.raises`` blocks:
        # such a block exits as soon as the expected exception is raised, so a
        # statement placed after the raising call inside it would never run.
        with handler(timeout=0.01) as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, Request(connect_timeout_url))
        assert time.time() - now < DEFAULT_TIMEOUT

        # Per request timeout, should override handler timeout
        request = Request(connect_timeout_url, extensions={'timeout': 0.01})
        with handler() as rh, pytest.raises(TransportError):
            now = time.time()
            ws_validate_and_send(rh, request)
        assert time.time() - now < DEFAULT_TIMEOUT

    def test_cookies(self, handler):
        """Cookies from a handler-level or per-request cookiejar are sent in the handshake."""
        cookiejar = YoutubeDLCookieJar()
        cookiejar.set_cookie(http.cookiejar.Cookie(
            version=0, name='test', value='ytdlp', port=None, port_specified=False,
            domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
            path_specified=True, secure=False, expires=None, discard=False, comment=None,
            comment_url=None, rest={}))

        with handler(cookiejar=cookiejar) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

        with handler() as rh:
            # No cookiejar configured: no cookie header is sent
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            assert 'cookie' not in json.loads(ws.recv())
            ws.close()

            # Cookiejar supplied for this request only
            ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
            ws.send('headers')
            assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
            ws.close()

    def test_source_address(self, handler):
        """The server sees the connection coming from the configured source address."""
        source_address = f'127.0.0.{random.randint(5, 255)}'
        verify_address_availability(source_address)
        with handler(source_address=source_address) as rh:
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('source_address')
            assert source_address == ws.recv()
            ws.close()

    def test_response_url(self, handler):
        """The response exposes the URL that was requested."""
        with handler() as rh:
            url = f'{self.ws_base_url}/something'
            ws = ws_validate_and_send(rh, Request(url))
            assert ws.url == url
            ws.close()

    def test_request_headers(self, handler):
        """Handler-level headers are sent, and per-request headers are merged over them."""
        with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
            # Global Headers
            ws = ws_validate_and_send(rh, Request(self.ws_base_url))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            ws.close()

            # Per request headers, merged with global
            ws = ws_validate_and_send(rh, Request(
                self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
            ws.send('headers')
            headers = HTTPHeaderDict(json.loads(ws.recv()))
            assert headers['test1'] == 'test'
            assert headers['test2'] == 'changed'
            assert headers['test3'] == 'test3'
            ws.close()

    @pytest.mark.parametrize('client_cert', (
        {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
            'client_certificate_password': 'foobar',
        },
        {
            'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
            'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
            'client_certificate_password': 'foobar',
        },
    ))
    def test_mtls(self, handler, client_cert):
        """Each supported client-certificate configuration can connect to the mTLS server."""
        with handler(
            # Disable client-side validation of unacceptable self-signed testcert.pem
            # The test is of a check on the server side, so unaffected
            verify=False,
            client_cert=client_cert,
        ) as rh:
            ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()

    def test_request_disable_proxy(self, handler):
        """Setting the proxy to ``None`` on the request disables the handler's proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                # When a proxy is explicitly set to None for the request
                # (keyed by 'ws', the scheme the handler proxy above is configured for)
                ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'ws': None}))
                # Then no proxy should be used
                assert ws.status == 101
                ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.NO_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support NO_PROXY')
    def test_noproxy(self, handler):
        """A ``no`` proxy entry matching the request host bypasses the configured proxy."""
        for proxy_proto in handler._SUPPORTED_PROXY_SCHEMES or ['ws']:
            # Given the handler is configured with a proxy
            with handler(proxies={'ws': f'{proxy_proto}://10.255.255.255'}, timeout=5) as rh:
                for no_proxy in (f'127.0.0.1:{self.ws_port}', '127.0.0.1', 'localhost'):
                    # When request no proxy includes the request url host
                    ws = ws_validate_and_send(rh, Request(self.ws_base_url, proxies={'no': no_proxy}))
                    # Then the proxy should not be used
                    assert ws.status == 101
                    ws.close()

    @pytest.mark.skip_handlers_if(
        lambda _, handler: Features.ALL_PROXY not in handler._SUPPORTED_FEATURES, 'handler does not support ALL_PROXY')
    def test_allproxy(self, handler):
        """An ``all`` proxy, set on the handler or on the request, is used for ws requests."""
        supported_proto = traverse_obj(handler._SUPPORTED_PROXY_SCHEMES, 0, default='ws')
        # This is a bit of a hacky test, but it should be enough to check whether the handler is using the proxy.
        # 0.1s might not be enough of a timeout if proxy is not used in all cases, but should still get failures.
        with handler(proxies={'all': f'{supported_proto}://10.255.255.255'}, timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(rh, Request(self.ws_base_url)).close()

        with handler(timeout=0.1) as rh:
            with pytest.raises(TransportError):
                ws_validate_and_send(
                    rh, Request(self.ws_base_url, proxies={'all': f'{supported_proto}://10.255.255.255'})).close()
|
|
|
|
|
|
def create_fake_ws_connection(raised):
    """Build a stand-in client connection whose ``send`` and ``recv`` always fail.

    :param raised: zero-argument callable returning the exception to raise
        from ``send`` and ``recv``.
    :return: an instance of a ``ClientConnection`` subclass carrying a canned
        101 response; its ``close`` does nothing.
    """
    import websockets.sync.client

    class FakeResponse:
        body = b''
        headers = {}
        status_code = 101
        reason_phrase = 'test'

    def _always_fail(self, *args, **kwargs):
        raise raised()

    class FakeWsConnection(websockets.sync.client.ClientConnection):
        def __init__(self, *args, **kwargs):
            # The parent initialiser is intentionally not called; only the
            # canned response is set on the instance.
            self.response = FakeResponse()

        send = _always_fail
        recv = _always_fail

        def close(self, *args, **kwargs):
            return None

    return FakeWsConnection()
|
|
|
|
|
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
class TestWebsocketsRequestHandler:
    """Checks how the Websockets handler translates lower-level exceptions."""

    @pytest.mark.parametrize('raised,expected', [
        # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
        (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
        # Requires a response object. Should be covered by HTTP error tests.
        # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
        (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
        # These are subclasses of InvalidHandshake
        (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
        (lambda: websockets.exceptions.NegotiationError(), TransportError),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError),
        (lambda: TimeoutError(), TransportError),
        # These may be raised by our create_connection implementation, which should also be caught
        (lambda: OSError(), TransportError),
        (lambda: ssl.SSLError(), SSLError),
        (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
        (lambda: socks.ProxyError(), ProxyError),
    ])
    def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
        """An exception raised while connecting surfaces as exactly *expected*."""
        import websockets.sync.client

        import yt_dlp.networking._websockets

        def raise_on_connect(*args, **kwargs):
            raise raised()

        with handler() as rh:
            # Stub out socket creation and make the library's connect() fail.
            monkeypatch.setattr(websockets.sync.client, 'connect', raise_on_connect)
            monkeypatch.setattr(
                yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)

            with pytest.raises(expected) as exc_info:
                rh.send(Request('ws://fake-url'))
            assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: TypeError(), RequestError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's ``send`` surfaces as exactly *expected*."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')

        with pytest.raises(expected, match=match) as exc_info:
            adapter.send('test')
        assert exc_info.type is expected

    @pytest.mark.parametrize('raised,expected,match', [
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
        (lambda: RuntimeError(), TransportError, None),
        (lambda: TimeoutError(), TransportError, None),
        (lambda: socks.ProxyError(), ProxyError, None),
        # Catch-all
        (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
    ])
    def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
        """An exception raised by the connection's ``recv`` surfaces as exactly *expected*."""
        from yt_dlp.networking._websockets import WebsocketsResponseAdapter

        failing_connection = create_fake_ws_connection(raised)
        adapter = WebsocketsResponseAdapter(failing_connection, url='ws://fake-url')

        with pytest.raises(expected, match=match) as exc_info:
            adapter.recv()
        assert exc_info.type is expected
|