mirror of
https://github.com/yt-dlp/yt-dlp
synced 2024-12-26 21:59:08 +01:00
[test] Workaround websocket server hanging (#9467)
Authored by: coletdjnz
This commit is contained in:
parent
f2868b26e9
commit
f849d77ab5
1 changed files with 33 additions and 20 deletions
|
@ -32,8 +32,6 @@ from yt_dlp.networking.exceptions import (
|
||||||
)
|
)
|
||||||
from yt_dlp.utils.networking import HTTPHeaderDict
|
from yt_dlp.utils.networking import HTTPHeaderDict
|
||||||
|
|
||||||
from test.conftest import validate_and_send
|
|
||||||
|
|
||||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,7 +64,9 @@ def process_request(self, request):
|
||||||
|
|
||||||
def create_websocket_server(**ws_kwargs):
|
def create_websocket_server(**ws_kwargs):
|
||||||
import websockets.sync.server
|
import websockets.sync.server
|
||||||
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
|
wsd = websockets.sync.server.serve(
|
||||||
|
websocket_handler, '127.0.0.1', 0,
|
||||||
|
process_request=process_request, open_timeout=2, **ws_kwargs)
|
||||||
ws_port = wsd.socket.getsockname()[1]
|
ws_port = wsd.socket.getsockname()[1]
|
||||||
ws_server_thread = threading.Thread(target=wsd.serve_forever)
|
ws_server_thread = threading.Thread(target=wsd.serve_forever)
|
||||||
ws_server_thread.daemon = True
|
ws_server_thread.daemon = True
|
||||||
|
@ -100,6 +100,19 @@ def create_mtls_wss_websocket_server():
|
||||||
return create_websocket_server(ssl_context=sslctx)
|
return create_websocket_server(ssl_context=sslctx)
|
||||||
|
|
||||||
|
|
||||||
|
def ws_validate_and_send(rh, req):
|
||||||
|
rh.validate(req)
|
||||||
|
max_tries = 3
|
||||||
|
for i in range(max_tries):
|
||||||
|
try:
|
||||||
|
return rh.send(req)
|
||||||
|
except TransportError as e:
|
||||||
|
if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
|
||||||
|
# websockets server sometimes hangs on new connections
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
||||||
class TestWebsSocketRequestHandlerConformance:
|
class TestWebsSocketRequestHandlerConformance:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -119,7 +132,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
def test_basic_websockets(self, handler):
|
def test_basic_websockets(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
assert 'upgrade' in ws.headers
|
assert 'upgrade' in ws.headers
|
||||||
assert ws.status == 101
|
assert ws.status == 101
|
||||||
ws.send('foo')
|
ws.send('foo')
|
||||||
|
@ -131,7 +144,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
def test_send_types(self, handler, msg, opcode):
|
def test_send_types(self, handler, msg, opcode):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
ws.send(msg)
|
ws.send(msg)
|
||||||
assert int(ws.recv()) == opcode
|
assert int(ws.recv()) == opcode
|
||||||
ws.close()
|
ws.close()
|
||||||
|
@ -140,10 +153,10 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_verify_cert(self, handler):
|
def test_verify_cert(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with pytest.raises(CertificateVerifyError):
|
with pytest.raises(CertificateVerifyError):
|
||||||
validate_and_send(rh, Request(self.wss_base_url))
|
ws_validate_and_send(rh, Request(self.wss_base_url))
|
||||||
|
|
||||||
with handler(verify=False) as rh:
|
with handler(verify=False) as rh:
|
||||||
ws = validate_and_send(rh, Request(self.wss_base_url))
|
ws = ws_validate_and_send(rh, Request(self.wss_base_url))
|
||||||
assert ws.status == 101
|
assert ws.status == 101
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
|
@ -151,7 +164,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_ssl_error(self, handler):
|
def test_ssl_error(self, handler):
|
||||||
with handler(verify=False) as rh:
|
with handler(verify=False) as rh:
|
||||||
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
|
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
|
||||||
validate_and_send(rh, Request(self.bad_wss_host))
|
ws_validate_and_send(rh, Request(self.bad_wss_host))
|
||||||
assert not issubclass(exc_info.type, CertificateVerifyError)
|
assert not issubclass(exc_info.type, CertificateVerifyError)
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@ -163,7 +176,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
])
|
])
|
||||||
def test_percent_encode(self, handler, path, expected):
|
def test_percent_encode(self, handler, path, expected):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
||||||
ws.send('path')
|
ws.send('path')
|
||||||
assert ws.recv() == expected
|
assert ws.recv() == expected
|
||||||
assert ws.status == 101
|
assert ws.status == 101
|
||||||
|
@ -174,7 +187,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
# This isn't a comprehensive test,
|
# This isn't a comprehensive test,
|
||||||
# but it should be enough to check whether the handler is removing dot segments
|
# but it should be enough to check whether the handler is removing dot segments
|
||||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
||||||
assert ws.status == 101
|
assert ws.status == 101
|
||||||
ws.send('path')
|
ws.send('path')
|
||||||
assert ws.recv() == '/test'
|
assert ws.recv() == '/test'
|
||||||
|
@ -187,7 +200,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_raise_http_error(self, handler, status):
|
def test_raise_http_error(self, handler, status):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with pytest.raises(HTTPError) as exc_info:
|
with pytest.raises(HTTPError) as exc_info:
|
||||||
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
||||||
assert exc_info.value.status == status
|
assert exc_info.value.status == status
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@ -198,7 +211,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_timeout(self, handler, params, extensions):
|
def test_timeout(self, handler, params, extensions):
|
||||||
with handler(**params) as rh:
|
with handler(**params) as rh:
|
||||||
with pytest.raises(TransportError):
|
with pytest.raises(TransportError):
|
||||||
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
def test_cookies(self, handler):
|
def test_cookies(self, handler):
|
||||||
|
@ -210,18 +223,18 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
comment_url=None, rest={}))
|
comment_url=None, rest={}))
|
||||||
|
|
||||||
with handler(cookiejar=cookiejar) as rh:
|
with handler(cookiejar=cookiejar) as rh:
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
ws.send('headers')
|
ws.send('headers')
|
||||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
ws.send('headers')
|
ws.send('headers')
|
||||||
assert 'cookie' not in json.loads(ws.recv())
|
assert 'cookie' not in json.loads(ws.recv())
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
||||||
ws.send('headers')
|
ws.send('headers')
|
||||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||||
ws.close()
|
ws.close()
|
||||||
|
@ -231,7 +244,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
verify_address_availability(source_address)
|
verify_address_availability(source_address)
|
||||||
with handler(source_address=source_address) as rh:
|
with handler(source_address=source_address) as rh:
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
ws.send('source_address')
|
ws.send('source_address')
|
||||||
assert source_address == ws.recv()
|
assert source_address == ws.recv()
|
||||||
ws.close()
|
ws.close()
|
||||||
|
@ -240,7 +253,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_response_url(self, handler):
|
def test_response_url(self, handler):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
url = f'{self.ws_base_url}/something'
|
url = f'{self.ws_base_url}/something'
|
||||||
ws = validate_and_send(rh, Request(url))
|
ws = ws_validate_and_send(rh, Request(url))
|
||||||
assert ws.url == url
|
assert ws.url == url
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
|
@ -248,14 +261,14 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
def test_request_headers(self, handler):
|
def test_request_headers(self, handler):
|
||||||
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
|
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
|
||||||
# Global Headers
|
# Global Headers
|
||||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||||
ws.send('headers')
|
ws.send('headers')
|
||||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||||
assert headers['test1'] == 'test'
|
assert headers['test1'] == 'test'
|
||||||
ws.close()
|
ws.close()
|
||||||
|
|
||||||
# Per request headers, merged with global
|
# Per request headers, merged with global
|
||||||
ws = validate_and_send(rh, Request(
|
ws = ws_validate_and_send(rh, Request(
|
||||||
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
|
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
|
||||||
ws.send('headers')
|
ws.send('headers')
|
||||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||||
|
@ -288,7 +301,7 @@ class TestWebsSocketRequestHandlerConformance:
|
||||||
verify=False,
|
verify=False,
|
||||||
client_cert=client_cert
|
client_cert=client_cert
|
||||||
) as rh:
|
) as rh:
|
||||||
validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
|
ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
|
||||||
|
|
||||||
|
|
||||||
def create_fake_ws_connection(raised):
|
def create_fake_ws_connection(raised):
|
||||||
|
|
Loading…
Reference in a new issue