from __future__ import annotations

import contextlib
import functools
import io
import logging
import ssl
import sys

from ._helper import (
    create_connection,
    create_socks_proxy_socket,
    make_socks_proxy_opts,
    select_proxy,
)
from .common import Features, Response, register_rh
from .exceptions import (
    CertificateVerifyError,
    HTTPError,
    ProxyError,
    RequestError,
    SSLError,
    TransportError,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..dependencies import websockets
from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none

# Import-time guard: refuse to load this module when the `websockets` object
# obtained from ..dependencies is falsy.
if not websockets:
    raise ImportError('websockets is not installed')

import websockets.version

# Dot-separated version string converted component-wise with int_or_none
# (presumably non-numeric components become None - confirm against ..utils).
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
# Anything older than 13.0 is rejected at import time.
if websockets_version < (13, 0):
    raise ImportError('Only websockets>=13.0 is supported')

import websockets.sync.client
from websockets.uri import parse_uri

# In websockets Connection, recv_exc and recv_events_exc are defined
# after the recv events handler thread is started [1].
# On our CI using PyPy, in some cases a race condition may occur
# where the recv events handler thread tries to use these attributes before they are defined [2].
# 1: https://github.com/python-websockets/websockets/blame/de768cf65e7e2b1a3b67854fb9e08816a5ff7050/src/websockets/sync/connection.py#L93
# 2: "AttributeError: 'ClientConnection' object has no attribute 'recv_events_exc'. Did you mean: 'recv_events'?"
#
# Workaround: define the attribute on the class so that a lookup on an instance
# finds a value even before the instance attribute has been assigned.
# The patch is best-effort; any failure while applying it is suppressed.
# NOTE(review): only recv_exc is patched here; recv_events_exc mentioned above is
# presumably the attribute name used by older websockets releases - confirm.
import websockets.sync.connection  # isort: split
with contextlib.suppress(Exception):
    websockets.sync.connection.Connection.recv_exc = None


class WebsocketsResponseAdapter(WebSocketResponse):
    """
    Response object backed by a `websockets` sync client connection.

    The response fields (body, headers, status and reason) are taken from the
    connection's `response` attribute; send(), recv() and close() are forwarded
    to the connection, with its exceptions re-raised as networking exceptions.
    """

    def __init__(self, ws: websockets.sync.client.ClientConnection, url):
        resp = ws.response
        super().__init__(
            fp=io.BytesIO(resp.body or b''),
            url=url,
            headers=resp.headers,
            status=resp.status_code,
            reason=resp.reason_phrase,
        )
        self._ws = ws

    def close(self):
        """Close the wrapped connection first, then the response itself."""
        connection = self._ws
        connection.close()
        super().close()

    def send(self, message):
        """
        Send `message` over the wrapped connection and return its result.

        Raises TransportError, ProxyError or RequestError depending on the
        exception raised by the connection.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
        try:
            result = self._ws.send(message)
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except TypeError as err:
            raise RequestError(cause=err) from err
        else:
            return result

    def recv(self):
        """
        Receive the next message from the wrapped connection and return it.

        Raises ProxyError or TransportError depending on the exception raised
        by the connection.
        """
        # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
        try:
            received = self._ws.recv()
        except SocksProxyError as err:
            raise ProxyError(cause=err) from err
        except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as err:
            raise TransportError(cause=err) from err
        else:
            return received


@register_rh
class WebsocketsRH(WebSocketRequestHandler):
    """
    Websockets request handler
    https://websockets.readthedocs.io
    https://github.com/python-websockets/websockets

    Opens ws/wss connections with the `websockets` sync client, either directly
    or through a SOCKS proxy, and returns them wrapped in a
    WebsocketsResponseAdapter.
    """
    _SUPPORTED_URL_SCHEMES = ('wss', 'ws')
    _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
    _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
    RH_NAME = 'websockets'

    def __init__(self, *args, **kwargs):
        """
        Initialise the handler and attach a stdout logging handler to the
        'websockets.client' and 'websockets.server' loggers.

        The attached handlers are remembered per logger name so that close()
        can detach them again. When `self.verbose` is truthy the loggers are
        additionally switched to DEBUG level.
        """
        super().__init__(*args, **kwargs)
        self.__logging_handlers = {}
        for name in ('websockets.client', 'websockets.server'):
            logger = logging.getLogger(name)
            handler = logging.StreamHandler(stream=sys.stdout)
            handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
            self.__logging_handlers[name] = handler
            logger.addHandler(handler)
            if self.verbose:
                logger.setLevel(logging.DEBUG)

    def _check_extensions(self, extensions):
        """
        Run the base-class check, then remove the extension keys handled here
        ('timeout', 'cookiejar', 'legacy_ssl') from `extensions` in place.
        """
        super()._check_extensions(extensions)
        extensions.pop('timeout', None)
        extensions.pop('cookiejar', None)
        extensions.pop('legacy_ssl', None)

    def close(self):
        """Detach the logging handlers that __init__ attached."""
        # Remove the logging handler that contains a reference to our logger
        # See: https://github.com/yt-dlp/yt-dlp/issues/8922
        for name, handler in self.__logging_handlers.items():
            logging.getLogger(name).removeHandler(handler)

    def _send(self, request):
        """
        Open a websocket connection for `request`.

        A 'cookie' header is added from the request's cookiejar when the merged
        headers do not already contain one. The socket is created through a
        SOCKS proxy when one is selected for the URL, otherwise directly.

        Returns:
            WebsocketsResponseAdapter wrapping the established connection.

        Raises:
            RequestError: the URL is rejected by websockets (InvalidURI).
            ProxyError: the SOCKS proxy reported an error.
            CertificateVerifyError: TLS certificate verification failed.
            SSLError: any other TLS error.
            HTTPError: the server answered the handshake with an invalid status.
            TransportError: any other OSError, TimeoutError or websockets error.
        """
        timeout = self._calculate_timeout(request)
        headers = self._merge_headers(request.headers)
        if 'cookie' not in headers:
            cookiejar = self._get_cookiejar(request)
            cookie_header = cookiejar.get_cookie_header(request.url)
            if cookie_header:
                headers['cookie'] = cookie_header

        # parse_uri() raises InvalidURI itself; translate it here, as otherwise it
        # would escape as a raw websockets exception instead of a RequestError.
        try:
            wsuri = parse_uri(request.url)
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e

        create_conn_kwargs = {
            'source_address': (self.source_address, 0) if self.source_address else None,
            'timeout': timeout,
        }
        proxy = select_proxy(request.url, self._get_proxies(request))
        try:
            if proxy:
                # Connect to the proxy address; the proxy socket factory is bound
                # to the actual websocket destination.
                socks_proxy_options = make_socks_proxy_opts(proxy)
                sock = create_connection(
                    address=(socks_proxy_options['addr'], socks_proxy_options['port']),
                    _create_socket_func=functools.partial(
                        create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
                    **create_conn_kwargs,
                )
            else:
                sock = create_connection(
                    address=(wsuri.host, wsuri.port),
                    **create_conn_kwargs,
                )
            ssl_ctx = self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl'))
            conn = websockets.sync.client.connect(
                sock=sock,
                uri=request.url,
                additional_headers=headers,
                open_timeout=timeout,
                user_agent_header=None,
                # The SSL context is only passed for secure (wss) URIs
                ssl=ssl_ctx if wsuri.secure else None,
                close_timeout=0,  # not ideal, but prevents yt-dlp hanging
            )
            return WebsocketsResponseAdapter(conn, url=request.url)

        # Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
        except SocksProxyError as e:
            raise ProxyError(cause=e) from e
        except websockets.exceptions.InvalidURI as e:
            raise RequestError(cause=e) from e
        except ssl.SSLCertVerificationError as e:
            raise CertificateVerifyError(cause=e) from e
        except ssl.SSLError as e:
            raise SSLError(cause=e) from e
        except websockets.exceptions.InvalidStatus as e:
            raise HTTPError(
                Response(
                    fp=io.BytesIO(e.response.body),
                    url=request.url,
                    headers=e.response.headers,
                    status=e.response.status_code,
                    reason=e.response.reason_phrase),
            ) from e
        except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
            raise TransportError(cause=e) from e