[test] Workaround websocket server hanging (#9467)

Authored by: coletdjnz
This commit is contained in:
coletdjnz 2024-03-16 16:57:21 +13:00 committed by GitHub
parent f2868b26e9
commit f849d77ab5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -32,8 +32,6 @@ from yt_dlp.networking.exceptions import (
) )
from yt_dlp.utils.networking import HTTPHeaderDict from yt_dlp.utils.networking import HTTPHeaderDict
from test.conftest import validate_and_send
TEST_DIR = os.path.dirname(os.path.abspath(__file__)) TEST_DIR = os.path.dirname(os.path.abspath(__file__))
@ -66,7 +64,9 @@ def process_request(self, request):
def create_websocket_server(**ws_kwargs): def create_websocket_server(**ws_kwargs):
import websockets.sync.server import websockets.sync.server
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) wsd = websockets.sync.server.serve(
websocket_handler, '127.0.0.1', 0,
process_request=process_request, open_timeout=2, **ws_kwargs)
ws_port = wsd.socket.getsockname()[1] ws_port = wsd.socket.getsockname()[1]
ws_server_thread = threading.Thread(target=wsd.serve_forever) ws_server_thread = threading.Thread(target=wsd.serve_forever)
ws_server_thread.daemon = True ws_server_thread.daemon = True
@ -100,6 +100,19 @@ def create_mtls_wss_websocket_server():
return create_websocket_server(ssl_context=sslctx) return create_websocket_server(ssl_context=sslctx)
def ws_validate_and_send(rh, req):
rh.validate(req)
max_tries = 3
for i in range(max_tries):
try:
return rh.send(req)
except TransportError as e:
if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
# websockets server sometimes hangs on new connections
continue
raise
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
class TestWebsSocketRequestHandlerConformance: class TestWebsSocketRequestHandlerConformance:
@classmethod @classmethod
@ -119,7 +132,7 @@ class TestWebsSocketRequestHandlerConformance:
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_basic_websockets(self, handler): def test_basic_websockets(self, handler):
with handler() as rh: with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
assert 'upgrade' in ws.headers assert 'upgrade' in ws.headers
assert ws.status == 101 assert ws.status == 101
ws.send('foo') ws.send('foo')
@ -131,7 +144,7 @@ class TestWebsSocketRequestHandlerConformance:
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_send_types(self, handler, msg, opcode): def test_send_types(self, handler, msg, opcode):
with handler() as rh: with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send(msg) ws.send(msg)
assert int(ws.recv()) == opcode assert int(ws.recv()) == opcode
ws.close() ws.close()
@ -140,10 +153,10 @@ class TestWebsSocketRequestHandlerConformance:
def test_verify_cert(self, handler): def test_verify_cert(self, handler):
with handler() as rh: with handler() as rh:
with pytest.raises(CertificateVerifyError): with pytest.raises(CertificateVerifyError):
validate_and_send(rh, Request(self.wss_base_url)) ws_validate_and_send(rh, Request(self.wss_base_url))
with handler(verify=False) as rh: with handler(verify=False) as rh:
ws = validate_and_send(rh, Request(self.wss_base_url)) ws = ws_validate_and_send(rh, Request(self.wss_base_url))
assert ws.status == 101 assert ws.status == 101
ws.close() ws.close()
@ -151,7 +164,7 @@ class TestWebsSocketRequestHandlerConformance:
def test_ssl_error(self, handler): def test_ssl_error(self, handler):
with handler(verify=False) as rh: with handler(verify=False) as rh:
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info: with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
validate_and_send(rh, Request(self.bad_wss_host)) ws_validate_and_send(rh, Request(self.bad_wss_host))
assert not issubclass(exc_info.type, CertificateVerifyError) assert not issubclass(exc_info.type, CertificateVerifyError)
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@ -163,7 +176,7 @@ class TestWebsSocketRequestHandlerConformance:
]) ])
def test_percent_encode(self, handler, path, expected): def test_percent_encode(self, handler, path, expected):
with handler() as rh: with handler() as rh:
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
ws.send('path') ws.send('path')
assert ws.recv() == expected assert ws.recv() == expected
assert ws.status == 101 assert ws.status == 101
@ -174,7 +187,7 @@ class TestWebsSocketRequestHandlerConformance:
with handler() as rh: with handler() as rh:
# This isn't a comprehensive test, # This isn't a comprehensive test,
# but it should be enough to check whether the handler is removing dot segments # but it should be enough to check whether the handler is removing dot segments
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
assert ws.status == 101 assert ws.status == 101
ws.send('path') ws.send('path')
assert ws.recv() == '/test' assert ws.recv() == '/test'
@ -187,7 +200,7 @@ class TestWebsSocketRequestHandlerConformance:
def test_raise_http_error(self, handler, status): def test_raise_http_error(self, handler, status):
with handler() as rh: with handler() as rh:
with pytest.raises(HTTPError) as exc_info: with pytest.raises(HTTPError) as exc_info:
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
assert exc_info.value.status == status assert exc_info.value.status == status
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
@ -198,7 +211,7 @@ class TestWebsSocketRequestHandlerConformance:
def test_timeout(self, handler, params, extensions): def test_timeout(self, handler, params, extensions):
with handler(**params) as rh: with handler(**params) as rh:
with pytest.raises(TransportError): with pytest.raises(TransportError):
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
def test_cookies(self, handler): def test_cookies(self, handler):
@ -210,18 +223,18 @@ class TestWebsSocketRequestHandlerConformance:
comment_url=None, rest={})) comment_url=None, rest={}))
with handler(cookiejar=cookiejar) as rh: with handler(cookiejar=cookiejar) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers') ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close() ws.close()
with handler() as rh: with handler() as rh:
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers') ws.send('headers')
assert 'cookie' not in json.loads(ws.recv()) assert 'cookie' not in json.loads(ws.recv())
ws.close() ws.close()
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
ws.send('headers') ws.send('headers')
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
ws.close() ws.close()
@ -231,7 +244,7 @@ class TestWebsSocketRequestHandlerConformance:
source_address = f'127.0.0.{random.randint(5, 255)}' source_address = f'127.0.0.{random.randint(5, 255)}'
verify_address_availability(source_address) verify_address_availability(source_address)
with handler(source_address=source_address) as rh: with handler(source_address=source_address) as rh:
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('source_address') ws.send('source_address')
assert source_address == ws.recv() assert source_address == ws.recv()
ws.close() ws.close()
@ -240,7 +253,7 @@ class TestWebsSocketRequestHandlerConformance:
def test_response_url(self, handler): def test_response_url(self, handler):
with handler() as rh: with handler() as rh:
url = f'{self.ws_base_url}/something' url = f'{self.ws_base_url}/something'
ws = validate_and_send(rh, Request(url)) ws = ws_validate_and_send(rh, Request(url))
assert ws.url == url assert ws.url == url
ws.close() ws.close()
@ -248,14 +261,14 @@ class TestWebsSocketRequestHandlerConformance:
def test_request_headers(self, handler): def test_request_headers(self, handler):
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
# Global Headers # Global Headers
ws = validate_and_send(rh, Request(self.ws_base_url)) ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers') ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv())) headers = HTTPHeaderDict(json.loads(ws.recv()))
assert headers['test1'] == 'test' assert headers['test1'] == 'test'
ws.close() ws.close()
# Per request headers, merged with global # Per request headers, merged with global
ws = validate_and_send(rh, Request( ws = ws_validate_and_send(rh, Request(
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
ws.send('headers') ws.send('headers')
headers = HTTPHeaderDict(json.loads(ws.recv())) headers = HTTPHeaderDict(json.loads(ws.recv()))
@ -288,7 +301,7 @@ class TestWebsSocketRequestHandlerConformance:
verify=False, verify=False,
client_cert=client_cert client_cert=client_cert
) as rh: ) as rh:
validate_and_send(rh, Request(self.mtls_wss_base_url)).close() ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
def create_fake_ws_connection(raised): def create_fake_ws_connection(raised):