diff --git a/requirements.txt b/requirements.txt index 5b6270a7da..d983fa03ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ brotlicffi; implementation_name!='cpython' certifi requests>=2.31.0,<3 urllib3>=1.26.17,<3 +websockets>=12.0 diff --git a/test/conftest.py b/test/conftest.py index 15549d30b9..2fbc269e1f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -19,3 +19,8 @@ def handler(request): pytest.skip(f'{RH_KEY} request handler is not available') return functools.partial(handler, logger=FakeLogger) + + +def validate_and_send(rh, req): + rh.validate(req) + return rh.send(req) diff --git a/test/test_networking.py b/test/test_networking.py index 4466fc0485..64af6e459a 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -52,6 +52,8 @@ from yt_dlp.networking.exceptions import ( from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils.networking import HTTPHeaderDict +from test.conftest import validate_and_send + TEST_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -275,11 +277,6 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode()) -def validate_and_send(rh, req): - rh.validate(req) - return rh.send(req) - - class TestRequestHandlerBase: @classmethod def setup_class(cls): @@ -872,8 +869,9 @@ class TestRequestsRequestHandler(TestRequestHandlerBase): ]) @pytest.mark.parametrize('handler', ['Requests'], indirect=True) def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): - from urllib3.response import HTTPResponse as Urllib3Response from requests.models import Response as RequestsResponse + from urllib3.response import HTTPResponse as Urllib3Response + from yt_dlp.networking._requests import RequestsResponseAdapter requests_res = RequestsResponse() requests_res.raw = Urllib3Response(body=b'', status=200) @@ -929,13 +927,17 @@ class TestRequestHandlerValidation: ('http', False, {}), ('https', False, {}), ]), + ('Websockets', [ + ('ws', False, {}), + ('wss', False, {}), + ]), (NoCheckRH, [('http', False, {})]), (ValidationRH, [('http', UnsupportedRequest, {})]) ] PROXY_SCHEME_TESTS = [ # scheme, expected to fail - ('Urllib', [ + ('Urllib', 'http', [ ('http', False), ('https', UnsupportedRequest), ('socks4', False), @@ -944,7 +946,7 @@ class TestRequestHandlerValidation: ('socks5h', False), ('socks', UnsupportedRequest), ]), - ('Requests', [ + ('Requests', 'http', [ ('http', False), ('https', False), ('socks4', False), @@ -952,8 +954,11 @@ class TestRequestHandlerValidation: ('socks5', False), ('socks5h', False), ]), - (NoCheckRH, [('http', False)]), - (HTTPSupportedRH, [('http', UnsupportedRequest)]), + (NoCheckRH, 'http', [('http', False)]), + (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), + ('Websockets', 'ws', [('http', UnsupportedRequest)]), + (NoCheckRH, 'http', [('http', False)]), + (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), ] PROXY_KEY_TESTS = [ @@ -972,7 +977,7 @@ class TestRequestHandlerValidation: ] EXTENSION_TESTS = [ - ('Urllib', [ + ('Urllib', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': CookieJar()}, AssertionError), @@ -980,17 +985,21 @@ class TestRequestHandlerValidation: ({'timeout': 'notatimeout'}, AssertionError), ({'unsupported': 'value'}, UnsupportedRequest), ]), - ('Requests', [ + ('Requests', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': YoutubeDLCookieJar()}, False), ({'timeout': 1}, False), ({'timeout': 'notatimeout'}, AssertionError), ({'unsupported': 'value'}, UnsupportedRequest), ]), - (NoCheckRH, [ + (NoCheckRH, 'http', [ ({'cookiejar': 'notacookiejar'}, False), ({'somerandom': 'test'}, False), # but any extension is allowed through ]), + ('Websockets', 'ws', [ + ({'cookiejar': YoutubeDLCookieJar()}, False), + ({'timeout': 2}, False), + ]), ] @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @@ -1016,14 +1025,14 @@ class TestRequestHandlerValidation: run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'})) run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'}) - @pytest.mark.parametrize('handler,scheme,fail', [ - (handler_tests[0], scheme, fail) + @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [ + (handler_tests[0], handler_tests[1], scheme, fail) for handler_tests in PROXY_SCHEME_TESTS - for scheme, fail in handler_tests[1] + for scheme, fail in handler_tests[2] ], indirect=['handler']) - def test_proxy_scheme(self, handler, scheme, fail): - run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'})) - run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'}) + def test_proxy_scheme(self, handler, req_scheme, scheme, fail): + run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'})) + run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'}) @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True) def test_empty_proxy(self, handler): @@ -1035,14 +1044,14 @@ class TestRequestHandlerValidation: def test_invalid_proxy_url(self, handler, proxy_url): run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url})) - @pytest.mark.parametrize('handler,extensions,fail', [ - (handler_tests[0], extensions, fail) + @pytest.mark.parametrize('handler,scheme,extensions,fail', [ + (handler_tests[0], handler_tests[1], extensions, fail) for handler_tests in EXTENSION_TESTS - for extensions, fail in handler_tests[1] + for extensions, fail in handler_tests[2] ], indirect=['handler']) - def test_extension(self, handler, extensions, fail): + def test_extension(self, handler, scheme, extensions, fail): run_validation( - handler, fail, Request('http://', extensions=extensions)) + handler, fail, Request(f'{scheme}://', extensions=extensions)) def test_invalid_request_type(self): rh = self.ValidationRH(logger=FakeLogger()) @@ -1075,6 +1084,22 @@ class FakeRHYDL(FakeYDL): self._request_director = self.build_request_director([FakeRH]) +class AllUnsupportedRHYDL(FakeYDL): + + def __init__(self, *args, **kwargs): + + class UnsupportedRH(RequestHandler): + def _send(self, request: Request): + pass + + _SUPPORTED_FEATURES = () + _SUPPORTED_PROXY_SCHEMES = () + _SUPPORTED_URL_SCHEMES = () + + super().__init__(*args, **kwargs) + self._request_director = self.build_request_director([UnsupportedRH]) + + class TestRequestDirector: def test_handler_operations(self): @@ -1234,6 +1259,12 @@ class TestYoutubeDLNetworking: with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'): ydl.urlopen('file://') + @pytest.mark.parametrize('scheme', (['ws', 'wss'])) + def test_websocket_unavailable_error(self, scheme): + with AllUnsupportedRHYDL() as ydl: + with pytest.raises(RequestError, match=r'This request requires WebSocket support'): + ydl.urlopen(f'{scheme}://') + def test_legacy_server_connect_error(self): with FakeRHYDL() as ydl: for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'): diff --git a/test/test_socks.py b/test/test_socks.py index d8ac88dad5..71f783e132 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -210,6 +210,16 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR self.wfile.write(payload.encode()) +class SocksWebSocketTestRequestHandler(SocksTestRequestHandler): + def handle(self): + import websockets.sync.server + protocol = websockets.ServerProtocol() + connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) + connection.handshake() + connection.send(json.dumps(self.socks_info)) + connection.close() + + @contextlib.contextmanager def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs): server = server_thread = None @@ -252,8 +262,22 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext): return json.loads(handler.send(request).read().decode()) +class WebSocketSocksTestProxyContext(SocksProxyTestContext): + REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler + + def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs): + request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs) + handler.validate(request) + ws = handler.send(request) + ws.send('socks_info') + socks_info = ws.recv() + ws.close() + return json.loads(socks_info) + + CTX_MAP = { 'http': HTTPSocksTestProxyContext, + 'ws': WebSocketSocksTestProxyContext, } @@ -263,7 +287,7 @@ def ctx(request): class TestSocks4Proxy: - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4_no_auth(self, handler, ctx): with handler() as rh: with ctx.socks_server(Socks4ProxyHandler) as server_address: @@ -271,7 +295,7 @@ class TestSocks4Proxy: rh, proxies={'all': f'socks4://{server_address}'}) assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4_auth(self, handler, ctx): with handler() as rh: with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address: @@ -281,7 +305,7 @@ class TestSocks4Proxy: rh, proxies={'all': f'socks4://user:@{server_address}'}) assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4a_ipv4_target(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: @@ -289,7 +313,7 @@ class TestSocks4Proxy: assert response['version'] == 4 assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1') - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4a_domain_target(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: @@ -298,7 +322,7 @@ class TestSocks4Proxy: assert response['ipv4_address'] is None assert response['domain_address'] == 'localhost' - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -308,7 +332,7 @@ class TestSocks4Proxy: assert response['client_address'][0] == source_address assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) @pytest.mark.parametrize('reply_code', [ Socks4CD.REQUEST_REJECTED_OR_FAILED, Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD, @@ -320,7 +344,7 @@ class TestSocks4Proxy: with pytest.raises(ProxyError): ctx.socks_info_request(rh) - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv6_socks4_proxy(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks4://{server_address}'}) as rh: @@ -329,7 +353,7 @@ class TestSocks4Proxy: assert response['ipv4_address'] == '127.0.0.1' assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_timeout(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address: with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh: @@ -339,7 +363,7 @@ class TestSocks4Proxy: class TestSocks5Proxy: - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_no_auth(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -347,7 +371,7 @@ class TestSocks5Proxy: assert response['auth_methods'] == [0x0] assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_user_pass(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address: with handler() as rh: @@ -360,7 +384,7 @@ class TestSocks5Proxy: assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS] assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_ipv4_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -368,7 +392,7 @@ class TestSocks5Proxy: assert response['ipv4_address'] == '127.0.0.1' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_domain_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -376,7 +400,7 @@ class TestSocks5Proxy: assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1') assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5h_domain_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: @@ -385,7 +409,7 @@ class TestSocks5Proxy: assert response['domain_address'] == 'localhost' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5h_ip_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: @@ -394,7 +418,7 @@ class TestSocks5Proxy: assert response['domain_address'] is None assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_ipv6_destination(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -402,7 +426,7 @@ class TestSocks5Proxy: assert response['ipv6_address'] == '::1' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv6_socks5_proxy(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -413,7 +437,7 @@ class TestSocks5Proxy: # XXX: is there any feasible way of testing IPv6 source addresses? # Same would go for non-proxy source_address test... - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -422,7 +446,7 @@ class TestSocks5Proxy: assert response['client_address'][0] == source_address assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) @pytest.mark.parametrize('reply_code', [ Socks5Reply.GENERAL_FAILURE, Socks5Reply.CONNECTION_NOT_ALLOWED, @@ -439,7 +463,7 @@ class TestSocks5Proxy: with pytest.raises(ProxyError): ctx.socks_info_request(rh) - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True) def test_timeout(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh: diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 0000000000..39d3c7d722 --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 + +# Allow direct execution +import os +import sys + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import http.client +import http.cookiejar +import http.server +import json +import random +import ssl +import threading + +from yt_dlp import socks +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.dependencies import websockets +from yt_dlp.networking import Request +from yt_dlp.networking.exceptions import ( + CertificateVerifyError, + HTTPError, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from yt_dlp.utils.networking import HTTPHeaderDict + +from test.conftest import validate_and_send + +TEST_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def websocket_handler(websocket): + for message in websocket: + if isinstance(message, bytes): + if message == b'bytes': + return websocket.send('2') + elif isinstance(message, str): + if message == 'headers': + return websocket.send(json.dumps(dict(websocket.request.headers))) + elif message == 'path': + return websocket.send(websocket.request.path) + elif message == 'source_address': + return websocket.send(websocket.remote_address[0]) + elif message == 'str': + return websocket.send('1') + return websocket.send(message) + + +def process_request(self, request): + if request.path.startswith('/gen_'): + status = http.HTTPStatus(int(request.path[5:])) + if 300 <= status.value <= 300: + return websockets.http11.Response( + status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'') + return self.protocol.reject(status.value, status.phrase) + return self.protocol.accept(request) + + +def create_websocket_server(**ws_kwargs): + import websockets.sync.server + wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) + ws_port = wsd.socket.getsockname()[1] + ws_server_thread = threading.Thread(target=wsd.serve_forever) + ws_server_thread.daemon = True + ws_server_thread.start() + return ws_server_thread, ws_port + + +def create_ws_websocket_server(): + return create_websocket_server() + + +def create_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(certfn, None) + return create_websocket_server(ssl_context=sslctx) + + +MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate') + + +def create_mtls_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt') + + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.verify_mode = ssl.CERT_REQUIRED + sslctx.load_verify_locations(cafile=cacertfn) + sslctx.load_cert_chain(certfn, None) + + return create_websocket_server(ssl_context=sslctx) + + +@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') +class TestWebsSocketRequestHandlerConformance: + @classmethod + def setup_class(cls): + cls.ws_thread, cls.ws_port = create_ws_websocket_server() + cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}' + + cls.wss_thread, cls.wss_port = create_wss_websocket_server() + cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}' + + cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)) + cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}' + + cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server() + cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}' + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_basic_websockets(self, handler): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + assert 'upgrade' in ws.headers + assert ws.status == 101 + ws.send('foo') + assert ws.recv() == 'foo' + ws.close() + + # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)]) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_send_types(self, handler, msg, opcode): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send(msg) + assert int(ws.recv()) == opcode + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_verify_cert(self, handler): + with handler() as rh: + with pytest.raises(CertificateVerifyError): + validate_and_send(rh, Request(self.wss_base_url)) + + with handler(verify=False) as rh: + ws = validate_and_send(rh, Request(self.wss_base_url)) + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_ssl_error(self, handler): + with handler(verify=False) as rh: + with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info: + validate_and_send(rh, Request(self.bad_wss_host)) + assert not issubclass(exc_info.type, CertificateVerifyError) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('path,expected', [ + # Unicode characters should be encoded with uppercase percent-encoding + ('/中文', '/%E4%B8%AD%E6%96%87'), + # don't normalize existing percent encodings + ('/%c7%9f', '/%c7%9f'), + ]) + def test_percent_encode(self, handler, path, expected): + with handler() as rh: + ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) + ws.send('path') + assert ws.recv() == expected + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_remove_dot_segments(self, handler): + with handler() as rh: + # This isn't a comprehensive test, + # but it should be enough to check whether the handler is removing dot segments + ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) + assert ws.status == 101 + ws.send('path') + assert ws.recv() == '/test' + ws.close() + + # We are restricted to known HTTP status codes in http.HTTPStatus + # Redirects are not supported for websockets + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511)) + def test_raise_http_error(self, handler, status): + with handler() as rh: + with pytest.raises(HTTPError) as exc_info: + validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) + assert exc_info.value.status == status + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('params,extensions', [ + ({'timeout': 0.00001}, {}), + ({}, {'timeout': 0.00001}), + ]) + def test_timeout(self, handler, params, extensions): + with handler(**params) as rh: + with pytest.raises(TransportError): + validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_cookies(self, handler): + cookiejar = YoutubeDLCookieJar() + cookiejar.set_cookie(http.cookiejar.Cookie( + version=0, name='test', value='ytdlp', port=None, port_specified=False, + domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/', + path_specified=True, secure=False, expires=None, discard=False, comment=None, + comment_url=None, rest={})) + + with handler(cookiejar=cookiejar) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert 'cookie' not in json.loads(ws.recv()) + ws.close() + + ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_source_address(self, handler): + source_address = f'127.0.0.{random.randint(5, 255)}' + with handler(source_address=source_address) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('source_address') + assert source_address == ws.recv() + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_response_url(self, handler): + with handler() as rh: + url = f'{self.ws_base_url}/something' + ws = validate_and_send(rh, Request(url)) + assert ws.url == url + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_request_headers(self, handler): + with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: + # Global Headers + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + ws.close() + + # Per request headers, merged with global + ws = validate_and_send(rh, Request( + self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + assert headers['test2'] == 'changed' + assert headers['test3'] == 'test3' + ws.close() + + @pytest.mark.parametrize('client_cert', ( + {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'), + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'), + 'client_certificate_password': 'foobar', + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'), + 'client_certificate_password': 'foobar', + } + )) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_mtls(self, handler, client_cert): + with handler( + # Disable client-side validation of unacceptable self-signed testcert.pem + # The test is of a check on the server side, so unaffected + verify=False, + client_cert=client_cert + ) as rh: + validate_and_send(rh, Request(self.mtls_wss_base_url)).close() + + +def create_fake_ws_connection(raised): + import websockets.sync.client + + class FakeWsConnection(websockets.sync.client.ClientConnection): + def __init__(self, *args, **kwargs): + class FakeResponse: + body = b'' + headers = {} + status_code = 101 + reason_phrase = 'test' + + self.response = FakeResponse() + + def send(self, *args, **kwargs): + raise raised() + + def recv(self, *args, **kwargs): + raise raised() + + def close(self, *args, **kwargs): + return + + return FakeWsConnection() + + +@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) +class TestWebsocketsRequestHandler: + @pytest.mark.parametrize('raised,expected', [ + # https://websockets.readthedocs.io/en/stable/reference/exceptions.html + (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError), + # Requires a response object. Should be covered by HTTP error tests. + # (lambda: websockets.exceptions.InvalidStatus(), TransportError), + (lambda: websockets.exceptions.InvalidHandshake(), TransportError), + # These are subclasses of InvalidHandshake + (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError), + (lambda: websockets.exceptions.NegotiationError(), TransportError), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError), + (lambda: TimeoutError(), TransportError), + # These may be raised by our create_connection implementation, which should also be caught + (lambda: OSError(), TransportError), + (lambda: ssl.SSLError(), SSLError), + (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError), + (lambda: socks.ProxyError(), ProxyError), + ]) + def test_request_error_mapping(self, handler, monkeypatch, raised, expected): + import websockets.sync.client + + import yt_dlp.networking._websockets + with handler() as rh: + def fake_connect(*args, **kwargs): + raise raised() + monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None) + monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect) + with pytest.raises(expected) as exc_info: + rh.send(Request('ws://fake-url')) + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: TypeError(), RequestError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.send('test') + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.recv() + assert exc_info.type is expected diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 740826b452..85b282bd51 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -4052,6 +4052,7 @@ class YoutubeDL: return self._request_director.send(req) except NoSupportingHandlers as e: for ue in e.unsupported_errors: + # FIXME: This depends on the order of errors. if not (ue.handler and ue.msg): continue if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower(): @@ -4061,6 +4062,15 @@ class YoutubeDL: if 'unsupported proxy type: "https"' in ue.msg.lower(): raise RequestError( 'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests') + + elif ( + re.match(r'unsupported url scheme: "wss?"', ue.msg.lower()) + and 'websockets' not in self._request_director.handlers + ): + raise RequestError( + 'This request requires WebSocket support. ' + 'Ensure one of the following dependencies are installed: websockets', + cause=ue) from ue raise except SSLError as e: if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e): diff --git a/yt_dlp/downloader/niconico.py b/yt_dlp/downloader/niconico.py index 5720f6eb8f..fef8bff73a 100644 --- a/yt_dlp/downloader/niconico.py +++ b/yt_dlp/downloader/niconico.py @@ -6,7 +6,7 @@ from . import get_suitable_downloader from .common import FileDownloader from .external import FFmpegFD from ..networking import Request -from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get +from ..utils import DownloadError, str_or_none, try_get class NiconicoDmcFD(FileDownloader): @@ -64,7 +64,6 @@ class NiconicoLiveFD(FileDownloader): ws_url = info_dict['url'] ws_extractor = info_dict['ws'] ws_origin_host = info_dict['origin'] - cookies = info_dict.get('cookies') live_quality = info_dict.get('live_quality', 'high') live_latency = info_dict.get('live_latency', 'high') dl = FFmpegFD(self.ydl, self.params or {}) @@ -76,12 +75,7 @@ class NiconicoLiveFD(FileDownloader): def communicate_ws(reconnect): if reconnect: - ws = WebSocketsWrapper(ws_url, { - 'Cookies': str_or_none(cookies) or '', - 'Origin': f'https://{ws_origin_host}', - 'Accept': '*/*', - 'User-Agent': self.params['http_headers']['User-Agent'], - }) + ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'})) if self.ydl.params.get('verbose', False): self.to_screen('[debug] Sending startWatching request') ws.send(json.dumps({ diff --git a/yt_dlp/extractor/fc2.py b/yt_dlp/extractor/fc2.py index ba19b6cab4..bbc4b56931 100644 --- a/yt_dlp/extractor/fc2.py +++ b/yt_dlp/extractor/fc2.py @@ -2,11 +2,9 @@ import re from .common import InfoExtractor from ..compat import compat_parse_qs -from ..dependencies import websockets from ..networking import Request from ..utils import ( ExtractorError, - WebSocketsWrapper, js_to_json, traverse_obj, update_url_query, @@ -167,8 +165,6 @@ class FC2LiveIE(InfoExtractor): }] def _real_extract(self, url): - if not websockets: - raise ExtractorError('websockets library is not available. Please install it.', expected=True) video_id = self._match_id(url) webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id) @@ -199,13 +195,9 @@ class FC2LiveIE(InfoExtractor): ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']}) playlist_data = None - self.to_screen('%s: Fetching HLS playlist info via WebSocket' % video_id) - ws = WebSocketsWrapper(ws_url, { - 'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:], + ws = self._request_webpage(Request(ws_url, headers={ 'Origin': 'https://live.fc2.com', - 'Accept': '*/*', - 'User-Agent': self.get_param('http_headers')['User-Agent'], - }) + }), video_id, note='Fetching HLS playlist info via WebSocket') self.write_debug('Sending HLS server request') diff --git a/yt_dlp/extractor/niconico.py b/yt_dlp/extractor/niconico.py index fa2d709d28..797b5268af 100644 --- a/yt_dlp/extractor/niconico.py +++ b/yt_dlp/extractor/niconico.py @@ -8,12 +8,11 @@ import time from urllib.parse import urlparse from .common import InfoExtractor, SearchInfoExtractor -from ..dependencies import websockets +from ..networking import Request from ..networking.exceptions import HTTPError from ..utils import ( ExtractorError, OnDemandPagedList, - WebSocketsWrapper, bug_reports_message, clean_html, float_or_none, @@ -934,8 +933,6 @@ class NiconicoLiveIE(InfoExtractor): _KNOWN_LATENCY = ('high', 'low') def _real_extract(self, url): - if not websockets: - raise ExtractorError('websockets library is not available. Please install it.', expected=True) video_id = self._match_id(url) webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id) @@ -950,17 +947,13 @@ class NiconicoLiveIE(InfoExtractor): }) hostname = remove_start(urlparse(urlh.url).hostname, 'sp.') - cookies = try_get(urlh.url, self._downloader._calc_cookies) latency = try_get(self._configuration_arg('latency'), lambda x: x[0]) if latency not in self._KNOWN_LATENCY: latency = 'high' - ws = WebSocketsWrapper(ws_url, { - 'Cookies': str_or_none(cookies) or '', - 'Origin': f'https://{hostname}', - 'Accept': '*/*', - 'User-Agent': self.get_param('http_headers')['User-Agent'], - }) + ws = self._request_webpage( + Request(ws_url, headers={'Origin': f'https://{hostname}'}), + video_id=video_id, note='Connecting to WebSocket server') self.write_debug('[debug] Sending HLS server request') ws.send(json.dumps({ @@ -1034,7 +1027,6 @@ class NiconicoLiveIE(InfoExtractor): 'protocol': 'niconico_live', 'ws': ws, 'video_id': video_id, - 'cookies': cookies, 'live_latency': latency, 'origin': hostname, }) diff --git a/yt_dlp/networking/__init__.py b/yt_dlp/networking/__init__.py index aa8d0eabe4..96c5a0678f 100644 --- a/yt_dlp/networking/__init__.py +++ b/yt_dlp/networking/__init__.py @@ -21,3 +21,11 @@ except ImportError: pass except Exception as e: warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message()) + +try: + from . import _websockets +except ImportError: + pass +except Exception as e: + warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message()) + diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py new file mode 100644 index 0000000000..ad85554e45 --- /dev/null +++ b/yt_dlp/networking/_websockets.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import io +import logging +import ssl +import sys + +from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket +from .common import Response, register_rh, Features +from .exceptions import ( + CertificateVerifyError, + HTTPError, + RequestError, + SSLError, + TransportError, ProxyError, +) +from .websocket import WebSocketRequestHandler, WebSocketResponse +from ..compat import functools +from ..dependencies import websockets +from ..utils import int_or_none +from ..socks import ProxyError as SocksProxyError + +if not websockets: + raise ImportError('websockets is not installed') + +import websockets.version + +websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'))) +if websockets_version < (12, 0): + raise ImportError('Only websockets>=12.0 is supported') + +import websockets.sync.client +from websockets.uri import parse_uri + + +class WebsocketsResponseAdapter(WebSocketResponse): + + def __init__(self, wsw: websockets.sync.client.ClientConnection, url): + super().__init__( + fp=io.BytesIO(wsw.response.body or b''), + url=url, + headers=wsw.response.headers, + status=wsw.response.status_code, + reason=wsw.response.reason_phrase, + ) + self.wsw = wsw + + def close(self): + self.wsw.close() + super().close() + + def send(self, message): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + try: + return self.wsw.send(message) + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except TypeError as e: + raise RequestError(cause=e) from e + + def recv(self): + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + try: + return self.wsw.recv() + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e: + raise TransportError(cause=e) from e + + +@register_rh +class WebsocketsRH(WebSocketRequestHandler): + """ + Websockets request handler + https://websockets.readthedocs.io + https://github.com/python-websockets/websockets + """ + _SUPPORTED_URL_SCHEMES = ('wss', 'ws') + _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) + RH_NAME = 'websockets' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for name in ('websockets.client', 'websockets.server'): + logger = logging.getLogger(name) + handler = logging.StreamHandler(stream=sys.stdout) + handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + logger.addHandler(handler) + if self.verbose: + logger.setLevel(logging.DEBUG) + + def _check_extensions(self, extensions): + super()._check_extensions(extensions) + extensions.pop('timeout', None) + extensions.pop('cookiejar', None) + + def _send(self, request): + timeout = float(request.extensions.get('timeout') or self.timeout) + headers = self._merge_headers(request.headers) + if 'cookie' not in headers: + cookiejar = request.extensions.get('cookiejar') or self.cookiejar + cookie_header = cookiejar.get_cookie_header(request.url) + if cookie_header: + headers['cookie'] = cookie_header + + wsuri = parse_uri(request.url) + create_conn_kwargs = { + 'source_address': (self.source_address, 0) if self.source_address else None, + 'timeout': timeout + } + proxy = select_proxy(request.url, request.proxies or self.proxies or {}) + try: + if proxy: + socks_proxy_options = make_socks_proxy_opts(proxy) + sock = create_connection( + address=(socks_proxy_options['addr'], socks_proxy_options['port']), + _create_socket_func=functools.partial( + create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), + **create_conn_kwargs + ) + else: + sock = create_connection( + address=(wsuri.host, wsuri.port), + **create_conn_kwargs + ) + conn = websockets.sync.client.connect( + sock=sock, + uri=request.url, + additional_headers=headers, + open_timeout=timeout, + user_agent_header=None, + ssl_context=self._make_sslcontext() if wsuri.secure else None, + close_timeout=0, # not ideal, but prevents yt-dlp hanging + ) + return WebsocketsResponseAdapter(conn, url=request.url) + + # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html + except SocksProxyError as e: + raise ProxyError(cause=e) from e + except websockets.exceptions.InvalidURI as e: + raise RequestError(cause=e) from e + except ssl.SSLCertVerificationError as e: + raise CertificateVerifyError(cause=e) from e + except ssl.SSLError as e: + raise SSLError(cause=e) from e + except websockets.exceptions.InvalidStatus as e: + raise HTTPError( + Response( + fp=io.BytesIO(e.response.body), + url=request.url, + headers=e.response.headers, + status=e.response.status_code, + reason=e.response.reason_phrase), + ) from e + except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: + raise TransportError(cause=e) from e diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py new file mode 100644 index 0000000000..09fcf78ac2 --- /dev/null +++ b/yt_dlp/networking/websocket.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import abc + +from .common import Response, RequestHandler + + +class WebSocketResponse(Response): + + def send(self, message: bytes | str): + """ + Send a message to the server. + + @param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame. + """ + raise NotImplementedError + + def recv(self): + raise NotImplementedError + + +class WebSocketRequestHandler(RequestHandler, abc.ABC): + pass diff --git a/yt_dlp/utils/_legacy.py b/yt_dlp/utils/_legacy.py index dde02092c9..aa9f46d204 100644 --- a/yt_dlp/utils/_legacy.py +++ b/yt_dlp/utils/_legacy.py @@ -1,4 +1,6 @@ """No longer used and new code should not use. Exists only for API compat.""" +import asyncio +import atexit import platform import struct import sys @@ -32,6 +34,77 @@ has_certifi = bool(certifi) has_websockets = bool(websockets) +class WebSocketsWrapper: + """Wraps websockets module to use in non-async scopes""" + pool = None + + def __init__(self, url, headers=None, connect=True, **ws_kwargs): + self.loop = asyncio.new_event_loop() + # XXX: "loop" is deprecated + self.conn = websockets.connect( + url, extra_headers=headers, ping_interval=None, + close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'), **ws_kwargs) + if connect: + self.__enter__() + atexit.register(self.__exit__, None, None, None) + + def __enter__(self): + if not self.pool: + self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop) + return self + + def send(self, *args): + self.run_with_loop(self.pool.send(*args), self.loop) + + def recv(self, *args): + return self.run_with_loop(self.pool.recv(*args), self.loop) + + def __exit__(self, type, value, traceback): + try: + return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop) + finally: + self.loop.close() + self._cancel_all_tasks(self.loop) + + # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications + # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class + @staticmethod + def run_with_loop(main, loop): + if not asyncio.iscoroutine(main): + raise ValueError(f'a coroutine was expected, got {main!r}') + + try: + return loop.run_until_complete(main) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + if hasattr(loop, 'shutdown_default_executor'): + loop.run_until_complete(loop.shutdown_default_executor()) + + @staticmethod + def _cancel_all_tasks(loop): + to_cancel = asyncio.all_tasks(loop) + + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + # XXX: "loop" is removed in python 3.10+ + loop.run_until_complete( + asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during asyncio.run() shutdown', + 'exception': task.exception(), + 'task': task, + }) + + def load_plugins(name, suffix, namespace): from ..plugins import load_plugins ret = load_plugins(name, suffix) diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index 10c7c43110..b0164a8953 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -1,5 +1,3 @@ -import asyncio -import atexit import base64 import binascii import calendar @@ -54,7 +52,7 @@ from ..compat import ( compat_os_name, compat_shlex_quote, ) -from ..dependencies import websockets, xattr +from ..dependencies import xattr __name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module @@ -4923,77 +4921,6 @@ class Config: return self.parser.parse_args(self.all_args) -class WebSocketsWrapper: - """Wraps websockets module to use in non-async scopes""" - pool = None - - def __init__(self, url, headers=None, connect=True): - self.loop = asyncio.new_event_loop() - # XXX: "loop" is deprecated - self.conn = websockets.connect( - url, extra_headers=headers, ping_interval=None, - close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf')) - if connect: - self.__enter__() - atexit.register(self.__exit__, None, None, None) - - def __enter__(self): - if not self.pool: - self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop) - return self - - def send(self, *args): - self.run_with_loop(self.pool.send(*args), self.loop) - - def recv(self, *args): - return self.run_with_loop(self.pool.recv(*args), self.loop) - - def __exit__(self, type, value, traceback): - try: - return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop) - finally: - self.loop.close() - self._cancel_all_tasks(self.loop) - - # taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications - # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class - @staticmethod - def run_with_loop(main, loop): - if not asyncio.iscoroutine(main): - raise ValueError(f'a coroutine was expected, got {main!r}') - - try: - return loop.run_until_complete(main) - finally: - loop.run_until_complete(loop.shutdown_asyncgens()) - if hasattr(loop, 'shutdown_default_executor'): - loop.run_until_complete(loop.shutdown_default_executor()) - - @staticmethod - def _cancel_all_tasks(loop): - to_cancel = asyncio.all_tasks(loop) - - if not to_cancel: - return - - for task in to_cancel: - task.cancel() - - # XXX: "loop" is removed in python 3.10+ - loop.run_until_complete( - asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)) - - for task in to_cancel: - if task.cancelled(): - continue - if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'unhandled exception during asyncio.run() shutdown', - 'exception': task.exception(), - 'task': task, - }) - - def merge_headers(*dicts): """Merge dicts of http headers case insensitively, prioritizing the latter ones""" return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}