[rh] Remove additional logging handlers on close (#9032)

Fixes https://github.com/yt-dlp/yt-dlp/issues/8922

Authored by: coletdjnz
This commit is contained in:
coletdjnz 2024-02-18 11:32:34 +13:00 committed by GitHub
parent 73fcfa39f5
commit 0085e2bab8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 64 additions and 6 deletions

View file

@ -13,6 +13,7 @@ import http.client
import http.cookiejar import http.cookiejar
import http.server import http.server
import io import io
import logging
import pathlib import pathlib
import random import random
import ssl import ssl
@ -752,6 +753,25 @@ class TestClientCertificate:
}) })
class TestRequestHandlerMisc:
"""Misc generic tests for request handlers, not related to request or validation testing"""
@pytest.mark.parametrize('handler,logger_name', [
('Requests', 'urllib3'),
('Websockets', 'websockets.client'),
('Websockets', 'websockets.server')
], indirect=['handler'])
def test_remove_logging_handler(self, handler, logger_name):
# Ensure any logging handlers, which may contain a YoutubeDL instance,
# are removed when we close the request handler
# See: https://github.com/yt-dlp/yt-dlp/issues/8922
logging_handlers = logging.getLogger(logger_name).handlers
before_count = len(logging_handlers)
rh = handler()
assert len(logging_handlers) == before_count + 1
rh.close()
assert len(logging_handlers) == before_count
class TestUrllibRequestHandler(TestRequestHandlerBase): class TestUrllibRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('handler', ['Urllib'], indirect=True) @pytest.mark.parametrize('handler', ['Urllib'], indirect=True)
def test_file_urls(self, handler): def test_file_urls(self, handler):
@ -827,6 +847,7 @@ class TestUrllibRequestHandler(TestRequestHandlerBase):
assert not isinstance(exc_info.value, TransportError) assert not isinstance(exc_info.value, TransportError)
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
class TestRequestsRequestHandler(TestRequestHandlerBase): class TestRequestsRequestHandler(TestRequestHandlerBase):
@pytest.mark.parametrize('raised,expected', [ @pytest.mark.parametrize('raised,expected', [
(lambda: requests.exceptions.ConnectTimeout(), TransportError), (lambda: requests.exceptions.ConnectTimeout(), TransportError),
@ -843,7 +864,6 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
(lambda: requests.exceptions.RequestException(), RequestError) (lambda: requests.exceptions.RequestException(), RequestError)
# (lambda: requests.exceptions.TooManyRedirects(), HTTPError) - Needs a response object # (lambda: requests.exceptions.TooManyRedirects(), HTTPError) - Needs a response object
]) ])
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
def test_request_error_mapping(self, handler, monkeypatch, raised, expected): def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
with handler() as rh: with handler() as rh:
def mock_get_instance(*args, **kwargs): def mock_get_instance(*args, **kwargs):
@ -877,7 +897,6 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
'3 bytes read, 5 more expected' '3 bytes read, 5 more expected'
), ),
]) ])
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
from requests.models import Response as RequestsResponse from requests.models import Response as RequestsResponse
from urllib3.response import HTTPResponse as Urllib3Response from urllib3.response import HTTPResponse as Urllib3Response
@ -896,6 +915,21 @@ class TestRequestsRequestHandler(TestRequestHandlerBase):
assert exc_info.type is expected assert exc_info.type is expected
def test_close(self, handler, monkeypatch):
rh = handler()
session = rh._get_instance(cookiejar=rh.cookiejar)
called = False
original_close = session.close
def mock_close(*args, **kwargs):
nonlocal called
called = True
return original_close(*args, **kwargs)
monkeypatch.setattr(session, 'close', mock_close)
rh.close()
assert called
def run_validation(handler, error, req, **handler_kwargs): def run_validation(handler, error, req, **handler_kwargs):
with handler(**handler_kwargs) as rh: with handler(**handler_kwargs) as rh:
@ -1205,6 +1239,19 @@ class TestRequestDirector:
assert director.send(Request('http://')).read() == b'' assert director.send(Request('http://')).read() == b''
assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported' assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported'
def test_close(self, monkeypatch):
director = RequestDirector(logger=FakeLogger())
director.add_handler(FakeRH(logger=FakeLogger()))
called = False
def mock_close(*args, **kwargs):
nonlocal called
called = True
monkeypatch.setattr(director.handlers[FakeRH.RH_KEY], 'close', mock_close)
director.close()
assert called
# XXX: do we want to move this to test_YoutubeDL.py? # XXX: do we want to move this to test_YoutubeDL.py?
class TestYoutubeDLNetworking: class TestYoutubeDLNetworking:

View file

@ -258,10 +258,10 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
# Forward urllib3 debug messages to our logger # Forward urllib3 debug messages to our logger
logger = logging.getLogger('urllib3') logger = logging.getLogger('urllib3')
handler = Urllib3LoggingHandler(logger=self._logger) self.__logging_handler = Urllib3LoggingHandler(logger=self._logger)
handler.setFormatter(logging.Formatter('requests: %(message)s')) self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s'))
handler.addFilter(Urllib3LoggingFilter()) self.__logging_handler.addFilter(Urllib3LoggingFilter())
logger.addHandler(handler) logger.addHandler(self.__logging_handler)
# TODO: Use a logger filter to suppress pool reuse warning instead # TODO: Use a logger filter to suppress pool reuse warning instead
logger.setLevel(logging.ERROR) logger.setLevel(logging.ERROR)
@ -276,6 +276,9 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
def close(self): def close(self):
self._clear_instances() self._clear_instances()
# Remove the logging handler that contains a reference to our logger
# See: https://github.com/yt-dlp/yt-dlp/issues/8922
logging.getLogger('urllib3').removeHandler(self.__logging_handler)
def _check_extensions(self, extensions): def _check_extensions(self, extensions):
super()._check_extensions(extensions) super()._check_extensions(extensions)

View file

@ -90,10 +90,12 @@ class WebsocketsRH(WebSocketRequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.__logging_handlers = {}
for name in ('websockets.client', 'websockets.server'): for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name) logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout) handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
self.__logging_handlers[name] = handler
logger.addHandler(handler) logger.addHandler(handler)
if self.verbose: if self.verbose:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -103,6 +105,12 @@ class WebsocketsRH(WebSocketRequestHandler):
extensions.pop('timeout', None) extensions.pop('timeout', None)
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
def close(self):
# Remove the logging handler that contains a reference to our logger
# See: https://github.com/yt-dlp/yt-dlp/issues/8922
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
def _send(self, request): def _send(self, request):
timeout = float(request.extensions.get('timeout') or self.timeout) timeout = float(request.extensions.get('timeout') or self.timeout)
headers = self._merge_headers(request.headers) headers = self._merge_headers(request.headers)