mirror of
https://github.com/yt-dlp/yt-dlp
synced 2024-12-26 21:59:08 +01:00
[test] Workaround websocket server hanging (#9467)
Authored by: coletdjnz
This commit is contained in:
parent
f2868b26e9
commit
f849d77ab5
1 changed files with 33 additions and 20 deletions
|
@ -32,8 +32,6 @@ from yt_dlp.networking.exceptions import (
|
|||
)
|
||||
from yt_dlp.utils.networking import HTTPHeaderDict
|
||||
|
||||
from test.conftest import validate_and_send
|
||||
|
||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
|
@ -66,7 +64,9 @@ def process_request(self, request):
|
|||
|
||||
def create_websocket_server(**ws_kwargs):
|
||||
import websockets.sync.server
|
||||
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
|
||||
wsd = websockets.sync.server.serve(
|
||||
websocket_handler, '127.0.0.1', 0,
|
||||
process_request=process_request, open_timeout=2, **ws_kwargs)
|
||||
ws_port = wsd.socket.getsockname()[1]
|
||||
ws_server_thread = threading.Thread(target=wsd.serve_forever)
|
||||
ws_server_thread.daemon = True
|
||||
|
@ -100,6 +100,19 @@ def create_mtls_wss_websocket_server():
|
|||
return create_websocket_server(ssl_context=sslctx)
|
||||
|
||||
|
||||
def ws_validate_and_send(rh, req):
|
||||
rh.validate(req)
|
||||
max_tries = 3
|
||||
for i in range(max_tries):
|
||||
try:
|
||||
return rh.send(req)
|
||||
except TransportError as e:
|
||||
if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
|
||||
# websockets server sometimes hangs on new connections
|
||||
continue
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
||||
class TestWebsSocketRequestHandlerConformance:
|
||||
@classmethod
|
||||
|
@ -119,7 +132,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||
def test_basic_websockets(self, handler):
|
||||
with handler() as rh:
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
assert 'upgrade' in ws.headers
|
||||
assert ws.status == 101
|
||||
ws.send('foo')
|
||||
|
@ -131,7 +144,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||
def test_send_types(self, handler, msg, opcode):
|
||||
with handler() as rh:
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send(msg)
|
||||
assert int(ws.recv()) == opcode
|
||||
ws.close()
|
||||
|
@ -140,10 +153,10 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_verify_cert(self, handler):
|
||||
with handler() as rh:
|
||||
with pytest.raises(CertificateVerifyError):
|
||||
validate_and_send(rh, Request(self.wss_base_url))
|
||||
ws_validate_and_send(rh, Request(self.wss_base_url))
|
||||
|
||||
with handler(verify=False) as rh:
|
||||
ws = validate_and_send(rh, Request(self.wss_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.wss_base_url))
|
||||
assert ws.status == 101
|
||||
ws.close()
|
||||
|
||||
|
@ -151,7 +164,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_ssl_error(self, handler):
|
||||
with handler(verify=False) as rh:
|
||||
with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
|
||||
validate_and_send(rh, Request(self.bad_wss_host))
|
||||
ws_validate_and_send(rh, Request(self.bad_wss_host))
|
||||
assert not issubclass(exc_info.type, CertificateVerifyError)
|
||||
|
||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||
|
@ -163,7 +176,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
])
|
||||
def test_percent_encode(self, handler, path, expected):
|
||||
with handler() as rh:
|
||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
||||
ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
||||
ws.send('path')
|
||||
assert ws.recv() == expected
|
||||
assert ws.status == 101
|
||||
|
@ -174,7 +187,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
with handler() as rh:
|
||||
# This isn't a comprehensive test,
|
||||
# but it should be enough to check whether the handler is removing dot segments
|
||||
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
||||
ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
||||
assert ws.status == 101
|
||||
ws.send('path')
|
||||
assert ws.recv() == '/test'
|
||||
|
@ -187,7 +200,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_raise_http_error(self, handler, status):
|
||||
with handler() as rh:
|
||||
with pytest.raises(HTTPError) as exc_info:
|
||||
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
||||
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
||||
assert exc_info.value.status == status
|
||||
|
||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||
|
@ -198,7 +211,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_timeout(self, handler, params, extensions):
|
||||
with handler(**params) as rh:
|
||||
with pytest.raises(TransportError):
|
||||
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
||||
ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
||||
|
||||
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||
def test_cookies(self, handler):
|
||||
|
@ -210,18 +223,18 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
comment_url=None, rest={}))
|
||||
|
||||
with handler(cookiejar=cookiejar) as rh:
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||
ws.close()
|
||||
|
||||
with handler() as rh:
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
assert 'cookie' not in json.loads(ws.recv())
|
||||
ws.close()
|
||||
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
||||
ws.send('headers')
|
||||
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||
ws.close()
|
||||
|
@ -231,7 +244,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||
verify_address_availability(source_address)
|
||||
with handler(source_address=source_address) as rh:
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('source_address')
|
||||
assert source_address == ws.recv()
|
||||
ws.close()
|
||||
|
@ -240,7 +253,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_response_url(self, handler):
|
||||
with handler() as rh:
|
||||
url = f'{self.ws_base_url}/something'
|
||||
ws = validate_and_send(rh, Request(url))
|
||||
ws = ws_validate_and_send(rh, Request(url))
|
||||
assert ws.url == url
|
||||
ws.close()
|
||||
|
||||
|
@ -248,14 +261,14 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
def test_request_headers(self, handler):
|
||||
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
|
||||
# Global Headers
|
||||
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
|
||||
ws.send('headers')
|
||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||
assert headers['test1'] == 'test'
|
||||
ws.close()
|
||||
|
||||
# Per request headers, merged with global
|
||||
ws = validate_and_send(rh, Request(
|
||||
ws = ws_validate_and_send(rh, Request(
|
||||
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
|
||||
ws.send('headers')
|
||||
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||
|
@ -288,7 +301,7 @@ class TestWebsSocketRequestHandlerConformance:
|
|||
verify=False,
|
||||
client_cert=client_cert
|
||||
) as rh:
|
||||
validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
|
||||
ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
|
||||
|
||||
|
||||
def create_fake_ws_connection(raised):
|
||||
|
|
Loading…
Reference in a new issue