1432 lines
53 KiB
Python
1432 lines
53 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import struct
|
|
import threading
|
|
import warnings
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from enum import IntEnum
|
|
from functools import partial
|
|
from json import dumps, loads
|
|
from select import select
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeVar, Union
|
|
|
|
from ..aio import CURL_SOCKET_BAD, get_selector
|
|
from ..const import CurlECode, CurlInfo, CurlOpt, CurlWsFlag
|
|
from ..curl import Curl, CurlError
|
|
from ..utils import CurlCffiWarning
|
|
from .exceptions import SessionClosed, Timeout
|
|
from .models import Response
|
|
from .utils import not_set, set_curl_options
|
|
|
|
if TYPE_CHECKING:
    from typing_extensions import Self

    from ..const import CurlHttpVersion
    from ..curl import CurlWsFrame
    from .cookies import CookieTypes
    from .headers import HeaderTypes
    from .impersonate import BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
    from .session import AsyncSession, ProxySpec

    # Typing-only aliases. They must stay inside this guard: ``CurlWsFrame`` is
    # only imported for type checking, so evaluating ``ON_DATA_T`` at runtime
    # would raise NameError. In this module they are used purely in
    # annotations, which ``from __future__ import annotations`` keeps lazy.
    T = TypeVar("T")

    # Callback signatures accepted by ``WebSocket(on_...=...)``.
    ON_DATA_T = Callable[["WebSocket", bytes, CurlWsFrame], None]
    ON_MESSAGE_T = Callable[["WebSocket", Union[bytes, str]], None]
    ON_ERROR_T = Callable[["WebSocket", CurlError], None]
    ON_OPEN_T = Callable[["WebSocket"], None]
    ON_CLOSE_T = Callable[["WebSocket", int, str], None]

    # (payload-or-exception, flags) items of AsyncWebSocket's receive queue and
    # (payload, flags) items of its send queue.
    RECV_QUEUE_ITEM = tuple[Union[bytes, Exception], int]
    SEND_QUEUE_ITEM = tuple[bytes, Union[CurlWsFlag, int]]
|
|
|
|
|
|
# Default JSON encoder for ``send_json``: compact output without the spaces
# that ``json.dumps`` inserts by default. The separators are bound here with a
# ``partial`` rather than passed at the call site, because a user-supplied
# ``dumps`` handed to ``send_json`` may not accept a ``separators`` keyword.
_COMPACT_JSON_SEPARATORS = (",", ":")
dumps = partial(dumps, separators=_COMPACT_JSON_SEPARATORS)
|
|
|
|
|
|
class WsCloseCode(IntEnum):
    """WebSocket close status codes.

    Registry: https://www.iana.org/assignments/websocket/websocket.xhtml
    """

    # 1000-2999: codes reserved for the WebSocket protocol itself.
    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    # 1005, 1006 and 1015 are reporting-only: never sent inside a close frame.
    UNKNOWN = 1005
    ABNORMAL_CLOSURE = 1006
    INVALID_DATA = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
    TLS_HANDSHAKE = 1015

    # 3000-3999: codes registered with IANA for libraries and frameworks.
    UNAUTHORIZED = 3000
    FORBIDDEN = 3003
    TIMEOUT = 3008
|
|
|
|
|
|
class WebSocketError(CurlError):
    """Error raised by the WebSocket clients.

    ``code`` is a WebSocket close code, a libcurl error code, or ``0`` when
    neither applies.
    """

    def __init__(
        self,
        message: str,
        code: Union[WsCloseCode, CurlECode, Literal[0]] = 0,
    ):
        # Deliberately cooperative: subclasses mix in further exception bases
        # (e.g. SessionClosed, Timeout), so the MRO picks the next __init__.
        super().__init__(message, code)  # type: ignore
|
|
|
|
|
|
class WebSocketClosed(WebSocketError, SessionClosed):
    """Raised when operating on a WebSocket that has already been closed."""
|
|
|
|
|
|
class WebSocketTimeout(WebSocketError, Timeout):
    """Raised when a WebSocket operation does not finish in time."""
|
|
|
|
|
|
class BaseWebSocket:
    """State and helpers shared by the sync and async WebSocket clients."""

    __slots__ = (
        "_curl",
        "autoclose",
        "_close_code",
        "_close_reason",
        "debug",
        "closed",
    )

    def __init__(self, curl: Curl, *, autoclose: bool = True, debug: bool = False):
        self._curl: Curl = curl
        self.autoclose: bool = autoclose
        self._close_code: Optional[int] = None
        self._close_reason: Optional[str] = None
        self.debug = debug
        self.closed = False

    @property
    def curl(self):
        """The underlying Curl handle, created on first access if none was given."""
        if self._curl is not_set:
            self._curl = Curl(debug=self.debug)
        return self._curl

    @property
    def close_code(self) -> Optional[int]:
        """The WebSocket close code, if the connection has been closed."""
        return self._close_code

    @property
    def close_reason(self) -> Optional[str]:
        """The WebSocket close reason, if the connection has been closed."""
        return self._close_reason

    @staticmethod
    def _pack_close_frame(code: int, reason: Union[bytes, str]) -> bytes:
        """Build a close frame payload: a 2-byte big-endian code plus the reason.

        Args:
            code: close status code.
            reason: close reason; a ``str`` is encoded as UTF-8.
        """
        if isinstance(reason, str):
            reason = reason.encode()
        return struct.pack("!H", code) + reason

    @staticmethod
    def _is_valid_close_code(code: int) -> bool:
        """Return whether ``code`` may legitimately appear in a close frame.

        Per RFC 6455 section 7.4 and the IANA registry: below 1000 is unused,
        1004 is reserved, 1005/1006/1015 must never be sent in a close frame,
        1016-2999 are unassigned, 3000-3999 are registered and 4000-4999 are
        for private use.
        """
        if 3000 <= code <= 4999:
            return True
        return 1000 <= code <= 1014 and code not in (1004, 1005, 1006)

    @staticmethod
    def _unpack_close_frame(frame: bytes) -> tuple[int, str]:
        """Parse a received close frame payload into ``(code, reason)``.

        An empty payload means the peer sent no status code and yields
        ``(WsCloseCode.UNKNOWN, "")``.

        Raises:
            WebSocketError: ``INVALID_DATA`` if the reason is not valid UTF-8;
                ``PROTOCOL_ERROR`` if the payload is a single byte (a body
                must start with a 2-byte code) or the code is not allowed.
        """
        if not frame:
            return WsCloseCode.UNKNOWN, ""

        if len(frame) < 2:
            # RFC 6455 5.5.1: if there is a body, it starts with a 2-byte code.
            raise WebSocketError("Invalid close frame", WsCloseCode.PROTOCOL_ERROR)

        try:
            code = struct.unpack_from("!H", frame)[0]
            reason = frame[2:].decode()
        except UnicodeDecodeError as e:
            raise WebSocketError(
                "Invalid close message", WsCloseCode.INVALID_DATA
            ) from e
        except Exception as e:
            raise WebSocketError(
                "Invalid close frame", WsCloseCode.PROTOCOL_ERROR
            ) from e

        if not BaseWebSocket._is_valid_close_code(code):
            raise WebSocketError(
                f"Invalid close code: {code}", WsCloseCode.PROTOCOL_ERROR
            )

        return code, reason

    def terminate(self):
        """Terminate the underlying connection."""
        self.closed = True
        # Use the private attribute: going through the ``curl`` property would
        # allocate a brand-new handle only to close it again.
        if self._curl is not not_set:
            self._curl.close()
|
|
|
|
|
|
# Event names a sync WebSocket can dispatch to user callbacks; these are the
# keys of ``WebSocket._emitters``.
EventTypeLiteral = Literal["open", "close", "data", "message", "error"]
|
|
|
|
|
|
class WebSocket(BaseWebSocket):
    """A WebSocket implementation using libcurl."""

    __slots__ = (
        "skip_utf8_validation",
        "_emitters",
        "keep_running",
    )

    def __init__(
        self,
        curl: Union[Curl, Any] = not_set,
        *,
        autoclose: bool = True,
        skip_utf8_validation: bool = False,
        debug: bool = False,
        on_open: Optional[ON_OPEN_T] = None,
        on_close: Optional[ON_CLOSE_T] = None,
        on_data: Optional[ON_DATA_T] = None,
        on_message: Optional[ON_MESSAGE_T] = None,
        on_error: Optional[ON_ERROR_T] = None,
    ):
        """
        Args:
            curl: Curl handle to use; if omitted, one is created on first use.
            autoclose: whether to close the WebSocket after receiving a close frame.
            skip_utf8_validation: whether to skip UTF-8 validation for text frames in
                run_forever().
            debug: print extra curl debug info.

            on_open: open callback, ``def on_open(ws)``
            on_close: close callback, ``def on_close(ws, code, reason)``
            on_data: raw data receive callback, ``def on_data(ws, data, frame)``
            on_message: message receive callback, ``def on_message(ws, message)``
            on_error: error callback, ``def on_error(ws, exception)``
        """
        super().__init__(curl=curl, autoclose=autoclose, debug=debug)
        self.skip_utf8_validation = skip_utf8_validation
        self.keep_running = False

        self._emitters: dict[EventTypeLiteral, Callable] = {}
        if on_open:
            self._emitters["open"] = on_open
        if on_close:
            self._emitters["close"] = on_close
        if on_data:
            self._emitters["data"] = on_data
        if on_message:
            self._emitters["message"] = on_message
        if on_error:
            self._emitters["error"] = on_error

    def __iter__(self) -> WebSocket:
        if self.closed:
            raise WebSocketClosed("WebSocket is closed")
        return self

    def __next__(self) -> bytes:
        msg, flags = self.recv()
        if flags & CurlWsFlag.CLOSE:
            raise StopIteration
        return msg

    def _emit(self, event_type: EventTypeLiteral, *args) -> None:
        """Invoke the callback registered for ``event_type``, if any.

        Exceptions from the callback go to the ``error`` callback when one is
        registered; otherwise they are reported as a CurlCffiWarning.
        """
        callback = self._emitters.get(event_type)
        if not callback:
            return
        try:
            callback(self, *args)
        except Exception as e:
            error_callback = self._emitters.get("error")
            if error_callback:
                error_callback(self, e)
            else:
                # Include the exception so the failure is diagnosable.
                warnings.warn(
                    f"WebSocket callback '{event_type}' failed: {e!r}",
                    CurlCffiWarning,
                    stacklevel=2,
                )

    def connect(
        self,
        url: str,
        params: Optional[Union[dict, list, tuple]] = None,
        headers: Optional[HeaderTypes] = None,
        cookies: Optional[CookieTypes] = None,
        auth: Optional[tuple[str, str]] = None,
        timeout: Optional[Union[float, tuple[float, float], object]] = not_set,
        allow_redirects: bool = True,
        max_redirects: int = 30,
        proxies: Optional[ProxySpec] = None,
        proxy: Optional[str] = None,
        proxy_auth: Optional[tuple[str, str]] = None,
        verify: Optional[bool] = None,
        referer: Optional[str] = None,
        accept_encoding: Optional[str] = "gzip, deflate, br",
        impersonate: Optional[BrowserTypeLiteral] = None,
        ja3: Optional[str] = None,
        akamai: Optional[str] = None,
        extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
        default_headers: bool = True,
        quote: Union[str, Literal[False]] = "",
        http_version: Optional[CurlHttpVersion] = None,
        interface: Optional[str] = None,
        cert: Optional[Union[str, tuple[str, str]]] = None,
        max_recv_speed: int = 0,
        curl_options: Optional[dict[CurlOpt, str]] = None,
    ):
        """Connect to the WebSocket.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Args:
            url: url to connect to.
            params: query string for the requests.
            headers: headers to send.
            cookies: cookies to use.
            auth: HTTP basic auth, a tuple of (username, password), only basic auth is
                supported.
            timeout: how many seconds to wait before giving up.
            allow_redirects: whether to allow redirection.
            max_redirects: max redirect counts, default 30, use -1 for unlimited.
            proxies: dict of proxies to use, prefer to use ``proxy`` if they are the
                same. format: ``{"http": proxy_url, "https": proxy_url}``.
            proxy: proxy to use, format: "http://user:pass@proxy_url".
                Can't be used with `proxies` parameter.
            proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
            verify: whether to verify https certs.
            referer: shortcut for setting referer header.
            accept_encoding: shortcut for setting accept-encoding header.
            impersonate: which browser version to impersonate.
            ja3: ja3 string to impersonate.
            akamai: akamai string to impersonate.
            extra_fp: extra fingerprints options, in complement to ja3 and akamai str.
            default_headers: whether to set default browser headers.
            quote: Set characters to be quoted, i.e. percent-encoded. Default safe
                string is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, the character
                will be removed from the safe string, thus quoted. If set to False, the
                url will be kept as is, without any automatic percent-encoding, you must
                encode the URL yourself.
            http_version: limiting http version, defaults to http2.
            interface: which interface to use.
            cert: a tuple of (cert, key) filenames for client cert.
            max_recv_speed: maximum receive speed, bytes per second.
            curl_options: extra curl options to use.

        Returns:
            ``self``, to allow chaining.
        """
        curl = self.curl

        set_curl_options(
            curl=curl,
            method="GET",
            url=url,
            params_list=[None, params],
            headers_list=[None, headers],
            cookies_list=[None, cookies],
            auth=auth,
            timeout=timeout,
            allow_redirects=allow_redirects,
            max_redirects=max_redirects,
            proxies_list=[None, proxies],
            proxy=proxy,
            proxy_auth=proxy_auth,
            verify_list=[None, verify],
            referer=referer,
            accept_encoding=accept_encoding,
            impersonate=impersonate,
            ja3=ja3,
            akamai=akamai,
            extra_fp=extra_fp,
            default_headers=default_headers,
            quote=quote,
            http_version=http_version,
            interface=interface,
            max_recv_speed=max_recv_speed,
            cert=cert,
            curl_options=curl_options,
        )

        # Magic number defined in: https://curl.se/docs/websocket.html
        curl.setopt(CurlOpt.CONNECT_ONLY, 2)
        curl.perform()
        return self

    def recv_fragment(self) -> tuple[bytes, CurlWsFrame]:
        """Receive a single curl websocket fragment as bytes."""

        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        chunk, frame = self.curl.ws_recv()

        if frame.flags & CurlWsFlag.CLOSE:
            try:
                self._close_code, self._close_reason = self._unpack_close_frame(chunk)
            except WebSocketError as e:
                # Follow the spec to close the connection
                # Errors do not respect autoclose
                self._close_code = e.code
                self.close(e.code)
                raise
            if self.autoclose:
                self.close()

        return chunk, frame

    def recv(self) -> tuple[bytes, int]:
        """
        Receive a frame as bytes. libcurl splits frames into fragments, so we have to
        collect all the chunks for a frame.
        """
        chunks = []
        flags = 0

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        while True:
            try:
                # Try to receive the first fragment first
                chunk, frame = self.recv_fragment()
                flags = frame.flags
                chunks.append(chunk)
                if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
                    break
            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    # According to https://curl.se/libcurl/c/curl_ws_recv.html
                    # > in real application: wait for socket here, e.g. using select()
                    _, _, _ = select([sock_fd], [], [], 0.5)
                else:
                    raise

        return b"".join(chunks), flags

    def recv_str(self) -> str:
        """Receive a text frame."""
        data, flags = self.recv()
        if not (flags & CurlWsFlag.TEXT):
            raise WebSocketError("Not valid text frame", WsCloseCode.INVALID_DATA)
        return data.decode()

    def recv_json(self, *, loads: Callable[[str], T] = loads) -> T:
        """Receive a JSON frame.

        Args:
            loads: JSON decoder, default is json.loads.
        """
        data = self.recv_str()
        return loads(data)

    def send(
        self,
        payload: Union[str, bytes, memoryview],
        flags: CurlWsFlag = CurlWsFlag.BINARY,
    ):
        """Send a data frame.

        Args:
            payload: data to send; ``str`` is encoded as UTF-8. May be empty.
            flags: flags for the frame.

        Returns:
            The number of payload bytes handed to libcurl.

        Raises:
            WebSocketClosed: the WebSocket is already closed.
            WebSocketError: there is no active socket, or the socket stayed
                unwritable for 0.5s after libcurl reported ``AGAIN``.
        """
        if flags & CurlWsFlag.CLOSE:
            self.keep_running = False

        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        # curl expects bytes
        if isinstance(payload, str):
            payload = payload.encode()

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        # Loop checks for CurlECode.AGAIN
        # https://curl.se/libcurl/c/curl_ws_send.html
        # The body always runs at least once so that empty frames (e.g. an
        # empty ping or text message) are actually sent.
        total = len(payload)
        offset = 0
        while True:
            try:
                n_sent = self.curl.ws_send(payload[offset:], flags)
            except CurlError as e:
                if e.code != CurlECode.AGAIN:
                    raise
                _, writeable, _ = select([], [sock_fd], [], 0.5)
                if not writeable:
                    raise WebSocketError("Socket write timeout") from e
                continue

            offset += n_sent
            if offset >= total:
                return offset

    def send_binary(self, payload: bytes):
        """Send a binary frame.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_bytes(self, payload: bytes):
        """Send a binary frame, alias of :meth:`send_binary`.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_str(self, payload: str):
        """Send a text frame.

        Args:
            payload: text data to send.
        """
        return self.send(payload, CurlWsFlag.TEXT)

    def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps):
        """Send a JSON frame.

        Args:
            payload: data to send.
            dumps: JSON encoder, default is json.dumps.
        """
        return self.send_str(dumps(payload))

    def ping(self, payload: Union[str, bytes]):
        """Send a ping frame.

        Args:
            payload: data to send.
        """
        return self.send(payload, CurlWsFlag.PING)

    def run_forever(self, url: str = "", **kwargs):
        """Run the WebSocket forever. See :meth:`connect` for details on parameters.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Raises:
            WebSocketError: there is no active socket, or a received text
                message is not valid UTF-8.
            CurlError: any other receive error; the connection is closed
                first (best effort) and the ``error`` callback is invoked.
        """

        if url:
            self.connect(url, **kwargs)

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        self._emit("open")

        # Keep reading the messages and invoke callbacks
        # TODO: Reconnect logic
        chunks = []
        self.keep_running = True
        while self.keep_running:
            try:
                chunk, frame = self.recv_fragment()
                flags = frame.flags
                self._emit("data", chunk, frame)

                chunks.append(chunk)
                if not (frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0):
                    continue

                # Avoid unnecessary computation
                if "message" in self._emitters:
                    # Concatenate collected chunks with the final message
                    msg = b"".join(chunks)

                    if (flags & CurlWsFlag.TEXT) and not self.skip_utf8_validation:
                        try:
                            msg = msg.decode()  # type: ignore
                        except UnicodeDecodeError as e:
                            self._close_code = WsCloseCode.INVALID_DATA
                            self.close(WsCloseCode.INVALID_DATA)
                            raise WebSocketError(
                                "Invalid UTF-8", WsCloseCode.INVALID_DATA
                            ) from e

                    if (flags & CurlWsFlag.BINARY) or (flags & CurlWsFlag.TEXT):
                        self._emit("message", msg)

                # Reset for the next message even when no "message" callback is
                # registered, so completed frames are not accumulated.
                chunks = []

                if flags & CurlWsFlag.CLOSE:
                    self.keep_running = False
                    self._emit("close", self._close_code or 0, self._close_reason or "")

            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    # Nothing buffered yet: wait for the socket to be readable.
                    _, _, _ = select([sock_fd], [], [], 0.5)
                    continue

                self._emit("error", e)
                if not self.closed:
                    # 1005, 1006 and 1015 must never be sent in a close frame,
                    # and libcurl error codes are not close codes at all.
                    unsendable = (
                        WsCloseCode.UNKNOWN,
                        WsCloseCode.ABNORMAL_CLOSURE,
                        WsCloseCode.TLS_HANDSHAKE,
                    )
                    code = WsCloseCode.INTERNAL_ERROR
                    if (
                        isinstance(e, WebSocketError)
                        and 1000 <= e.code < 5000
                        and e.code not in unsendable
                    ):
                        code = e.code
                    # The connection may already be broken; a failing close
                    # handshake must not replace the error being reported.
                    with suppress(CurlError):
                        self.close(code)
                raise

    def close(self, code: int = WsCloseCode.OK, message: bytes = b""):
        """Close the connection.

        Sends a close frame, then releases the curl handle. This is a no-op if
        no curl handle was ever created or the WebSocket is already closed
        (e.g. by ``autoclose``).

        Args:
            code: close code.
            message: close reason.
        """
        # Check the private attribute: the ``curl`` property lazily creates a
        # handle, so ``self.curl is not_set`` could never be true.
        if self._curl is not_set:
            return
        if self.closed:
            # Already closed: stay idempotent, but still stop run_forever().
            self.keep_running = False
            return

        # TODO: As per spec, we should wait for the server to close the connection
        # But this is not a requirement
        msg = self._pack_close_frame(code, message)
        try:
            self.send(msg, CurlWsFlag.CLOSE)
        finally:
            # The only way to close the connection appears to be curl_easy_cleanup.
            # Do it even when sending the close frame failed, so the handle is
            # not leaked and the WebSocket is marked closed.
            self.terminate()
|
|
|
|
|
|
class AsyncWebSocket(BaseWebSocket):
|
|
"""
|
|
An asyncio WebSocket implementation using libcurl.
|
|
|
|
NOTE: This object represents a single WebSocket connection. Once closed,
|
|
it cannot be reopened. A new instance must be created to reconnect.
|
|
"""
|
|
|
|
__slots__ = (
|
|
"session",
|
|
"_loop",
|
|
"_sock_fd",
|
|
"_close_lock",
|
|
"_terminate_lock",
|
|
"_read_task",
|
|
"_write_task",
|
|
"_close_handle",
|
|
"_receive_queue",
|
|
"_send_queue",
|
|
"_max_send_batch_size",
|
|
"_coalesce_frames",
|
|
"retry_on_recv_error",
|
|
"_yield_interval",
|
|
"_use_fair_scheduling",
|
|
"_yield_mask",
|
|
"_recv_error_retries",
|
|
"_terminated",
|
|
"_terminated_event",
|
|
)
|
|
|
|
# Match libcurl's documented max frame size limit.
|
|
_MAX_CURL_FRAME_SIZE: ClassVar[int] = 65535
|
|
_MAX_RECV_RETRIES: ClassVar[int] = 3
|
|
_RECV_RETRY_DELAY: ClassVar[float] = 0.3
|
|
|
|
def __init__(
    self,
    session: AsyncSession,
    curl: Curl,
    *,
    autoclose: bool = True,
    debug: bool = False,
    recv_queue_size: int = 512,
    send_queue_size: int = 256,
    max_send_batch_size: int = 256,
    coalesce_frames: bool = False,
    retry_on_recv_error: bool = False,
    yield_interval: float = 0.001,
    fair_scheduling: bool = False,
    yield_mask: int = 63,
) -> None:
    """Create an asyncio WebSocket around an already-connected Curl handle.

    Do not construct this class yourself; use `AsyncSession.ws_connect()`,
    which performs the setup and starts the background I/O tasks.

    Important:
        Network I/O is decoupled from the public API and runs in background
        tasks. A network error that happens while a message queued by
        `send()` is being transmitted is therefore NOT raised by `send()`;
        it is placed in the receive queue and raised by the next `recv()`.

    Args:
        session (AsyncSession): The owning AsyncSession.
        curl (Curl): The connected Curl handle backing this WebSocket.
        autoclose (bool, optional): Close the WS on receiving a close frame.
        debug (bool, optional): Enable debug messages. Defaults to False.
        recv_queue_size (int, optional): Maximum number of received messages
            buffered until consumed with `recv()`.
        send_queue_size (int, optional): Maximum number of outgoing messages
            buffered before `send(...)` has to wait for room (backpressure);
            queued messages are written when the socket is next writable.
        max_send_batch_size (int, optional): Max number of messages per batch.
        coalesce_frames (bool, optional): Combine multiple frames into a batch.
        retry_on_recv_error (bool, optional): Retry recv on some transient errors.
        yield_interval (float, optional): How often to yield control, in seconds.
        fair_scheduling (bool, optional): Replace the default ~5:1 recv:send
            ratio with a fairer 1:1 ratio, at the cost of recv throughput.
        yield_mask (int, optional): Bitmask setting how often the I/O loops
            yield, checked every `yield_mask + 1` operations. Must be a power
            of two minus one (e.g. `63`, `127`, `255`). Lower values favour
            fairness, higher values favour throughput.
    """
    super().__init__(curl=curl, autoclose=autoclose, debug=debug)

    # Owning session, plus the event loop and socket resolved later on.
    self.session: AsyncSession[Response] = session
    self._loop: Optional[asyncio.AbstractEventLoop] = None
    self._sock_fd: int = -1

    # Shutdown coordination. The terminate lock is a threading.Lock because
    # terminate() is meant to be callable from other threads as well.
    self._close_lock: asyncio.Lock = asyncio.Lock()
    self._terminate_lock: threading.Lock = threading.Lock()
    self._terminated_event: asyncio.Event = asyncio.Event()

    # Background I/O tasks and the handle of the scheduled termination.
    self._read_task: Optional[asyncio.Task[None]] = None
    self._write_task: Optional[asyncio.Task[None]] = None
    self._close_handle: Optional[asyncio.Handle] = None

    # Bounded queues between the public API and the I/O tasks.
    self._receive_queue: asyncio.Queue[RECV_QUEUE_ITEM] = asyncio.Queue(
        maxsize=recv_queue_size,
    )
    self._send_queue: asyncio.Queue[SEND_QUEUE_ITEM] = asyncio.Queue(
        maxsize=send_queue_size,
    )

    # Tuning knobs consumed by the I/O loops.
    self._max_send_batch_size: int = max_send_batch_size
    self._coalesce_frames: bool = coalesce_frames
    self.retry_on_recv_error: bool = retry_on_recv_error
    self._yield_interval: float = yield_interval
    self._use_fair_scheduling: bool = fair_scheduling
    self._yield_mask: int = yield_mask

    # Runtime state.
    self._recv_error_retries: int = 0
    self._terminated: bool = False
|
|
|
|
@property
def loop(self) -> asyncio.AbstractEventLoop:
    """The event loop this WebSocket runs on, resolved and cached on first use."""
    cached = self._loop
    if cached is None:
        cached = self._loop = get_selector(asyncio.get_running_loop())
    return cached
|
|
|
|
@property
def send_queue_size(self) -> int:
    """Number of outgoing frames currently waiting in the send queue."""
    pending = self._send_queue
    return pending.qsize()
|
|
|
|
def is_alive(self) -> bool:
    """
    Report whether the connection and its background I/O tasks are running.

    `False` means the WebSocket is closed or terminated, or that the read or
    write task has finished (cleanly or because of an error).

    Note: This is only a snapshot. `True` does not guarantee the next network
    operation will succeed, but `False` means the connection is definitely
    no longer usable.
    """
    if self.closed or self._terminated:
        return False

    for io_task in (self._read_task, self._write_task):
        if io_task and io_task.done():
            return False

    return True
|
|
|
|
def __del__(self) -> None:
    """Emit a ResourceWarning if an open WebSocket is garbage collected."""
    # ``closed`` may be missing if construction never completed; treat that
    # as closed so no warning is emitted.
    still_open = not getattr(self, "closed", True)
    if still_open:
        # Never let a finalizer raise.
        with suppress(Exception):
            warnings.warn(
                f"Unclosed WebSocket {self!r} was garbage collected. "
                "Always call await ws.close() to ensure clean shutdown.",
                ResourceWarning,
                stacklevel=2,
            )
|
|
|
|
def __aiter__(self) -> Self:
    """Allow ``async for msg in ws``; refused once the WebSocket is closed."""
    if not self.closed:
        return self
    raise WebSocketClosed("WebSocket has been closed")
|
|
|
|
async def __anext__(self) -> bytes:
    """Return the next message payload; stop iterating on a close frame."""
    payload, flags = await self.recv()
    if payload is not None and not flags & CurlWsFlag.CLOSE:
        return payload
    raise StopAsyncIteration
|
|
|
|
def _start_io_tasks(self) -> None:
    """Spawn the background read and write loop tasks.

    Intended to be called exactly once by the factory right after creation.
    This is one-shot: later calls return immediately and the tasks are never
    restarted.

    Raises:
        WebSocketError: The WebSocket FD was invalid.
    """
    # Already started: nothing to do.
    if self._read_task is not None:
        return

    # Look up the socket libcurl is using for this connection.
    fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
    self._sock_fd = fd
    if fd == CURL_SOCKET_BAD:
        raise WebSocketError(
            "Invalid active socket.", code=CurlECode.NO_CONNECTION_AVAILABLE
        )

    # Name the tasks after this object's id to ease debugging.
    prefix = f"WebSocket-{id(self):#x}"
    event_loop = self.loop
    self._read_task = event_loop.create_task(
        self._read_loop(), name=f"{prefix}-read"
    )
    self._write_task = event_loop.create_task(
        self._write_loop(), name=f"{prefix}-write"
    )
|
|
|
|
async def recv(
    self, *, timeout: Optional[float] = None
) -> tuple[Optional[bytes], int]:
    """Wait for the next complete frame and return it as bytes.

    Frames are taken from the receive queue that the background reader fills.

    Args:
        timeout: how many seconds to wait before giving up.

    Raises:
        WebSocketClosed: The WebSocket is closed and nothing is left in the
            receive queue.
        WebSocketTimeout: No frame arrived within ``timeout``.
        WebSocketError: A protocol or network error reported by a background
            I/O task, including failures of earlier `send()` calls.

    Returns:
        tuple[bytes, int]: The received payload and its flags.
    """
    if self.closed and self._receive_queue.empty():
        raise WebSocketClosed("WebSocket is closed")

    try:
        item, flags = await asyncio.wait_for(self._receive_queue.get(), timeout)
        # Errors from the I/O tasks travel through the queue as exceptions.
        if isinstance(item, Exception):
            raise item
    except asyncio.TimeoutError as exc:
        raise WebSocketTimeout(
            "WebSocket recv() timed out", CurlECode.OPERATION_TIMEDOUT
        ) from exc

    return item, flags
|
|
|
|
async def recv_str(self, *, timeout: Optional[float] = None) -> str:
    """Wait for the next text frame and return it decoded as UTF-8.

    Args:
        timeout: how many seconds to wait before giving up.
    """
    data, flags = await self.recv(timeout=timeout)
    is_text_frame = data is not None and bool(flags & CurlWsFlag.TEXT)
    if not is_text_frame:
        raise WebSocketError("Not a valid text frame", WsCloseCode.INVALID_DATA)

    try:
        text = data.decode("utf-8")
    except UnicodeDecodeError as exc:
        raise WebSocketError(
            "Invalid UTF-8 in text frame", WsCloseCode.INVALID_DATA
        ) from exc
    return text
|
|
|
|
async def recv_json(
    self,
    *,
    loads: Callable[[Union[str, bytes]], T] = loads,
    timeout: Optional[float] = None,
) -> T:
    """Wait for the next frame and decode it as JSON.

    Text frames are decoded as UTF-8 first; other frames are handed to
    ``loads`` as raw bytes.

    Args:
        loads: JSON decoder, default is json.loads.
        timeout: how many seconds to wait before giving up.
    """
    data, flags = await self.recv(timeout=timeout)
    if data is None:
        raise WebSocketError(
            "Received empty frame, cannot decode JSON", WsCloseCode.INVALID_DATA
        )

    if not flags & CurlWsFlag.TEXT:
        return loads(data)

    try:
        text = data.decode("utf-8")
        return loads(text)
    except UnicodeDecodeError as exc:
        raise WebSocketError(
            "Invalid UTF-8 in JSON text frame", WsCloseCode.INVALID_DATA
        ) from exc
|
|
|
|
async def send(
    self,
    payload: Union[str, bytes, bytearray, memoryview],
    flags: CurlWsFlag = CurlWsFlag.BINARY,
):
    """Queue a data frame for transmission.

    This only places the payload on the send queue; a background task does the
    actual network write. Use `await ws.flush(...)` to be sure everything
    queued so far has been sent.

    libcurl supports frames of at most `65535` bytes; larger payloads are
    split and sent in chunks of that size.

    Args:
        payload: data to send.
        flags: flags for the frame.

    Raises:
        WebSocketClosed: The WebSocket is closed.

    NOTE:
        Because transmission happens asynchronously, network errors during
        the actual write (e.g. a dropped connection) are NOT raised here;
        a later `recv()` raises them. Keep receiving so such errors surface.

        If the network is slow and the send queue fills up, this call waits
        until there is room in the queue.
    """
    if self.closed:
        raise WebSocketClosed("WebSocket is closed")

    # Normalise to immutable bytes: cURL expects bytes, and copying mutable
    # buffers keeps later caller-side changes out of the queued frame.
    if isinstance(payload, str):
        data = payload.encode("utf-8")
    elif isinstance(payload, (bytearray, memoryview)):
        data = bytes(payload)
    else:
        data = payload

    item = (data, flags)
    try:
        # Fast path without a suspension point when the queue has room.
        self._send_queue.put_nowait(item)
    except asyncio.QueueFull:
        await self._send_queue.put(item)
|
|
|
|
async def send_binary(self, payload: bytes) -> None:
    """Queue a binary frame.

    Args:
        payload: binary data to send.

    See `send(...)` for queueing and error-reporting details.
    """
    return await self.send(payload, flags=CurlWsFlag.BINARY)
|
|
|
|
async def send_bytes(self, payload: bytes) -> None:
    """Queue a binary frame; same as :meth:`send_binary`.

    Args:
        payload: binary data to send.

    See `send(...)` for queueing and error-reporting details.
    """
    return await self.send(payload, flags=CurlWsFlag.BINARY)
|
|
|
|
async def send_str(self, payload: str) -> None:
    """Queue a text frame.

    Args:
        payload: text data to send.

    See `send(...)` for queueing and error-reporting details.
    """
    return await self.send(payload, flags=CurlWsFlag.TEXT)
|
|
|
|
async def send_json(
    self, payload: Any, *, dumps: Callable[[Any], str] = dumps
) -> None:
    """Serialize ``payload`` to JSON and queue it as a text frame.

    Args:
        payload: data to send.
        dumps: JSON encoder, default is `json.dumps(...)`.

    See `send(...)` for queueing and error-reporting details.
    """
    text = dumps(payload)
    return await self.send_str(text)
|
|
|
|
async def ping(self, payload: Union[str, bytes]):
    """Queue a ping frame.

    Args:
        payload: data to send.

    See `send(...)` for queueing and error-reporting details.
    """
    return await self.send(payload, flags=CurlWsFlag.PING)
|
|
|
|
async def close(
    self, code: int = WsCloseCode.OK, message: bytes = b"", timeout: float = 5.0
) -> None:
    """
    Gracefully close the WebSocket, then terminate the connection.

    If the writer task is still running, a close frame is queued behind any
    pending messages and the queue is flushed (best effort). The connection
    is then terminated. This is the recommended way to end a session;
    calling it again after the first call is a no-op.

    Args:
        code (int, optional): Close code. Defaults to `WsCloseCode.OK`.
        message (bytes, optional): Close reason. Defaults to `b""`.
        timeout (float, optional): Seconds to wait for each closing step.
    """
    async with self._close_lock:
        if self.closed:
            return
        self.closed = True

        try:
            writer = self._write_task
            if writer is not None and not writer.done():
                frame = self._pack_close_frame(code, message)
                # Enqueue behind pending data; give up quietly on timeout.
                with suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(
                        self._send_queue.put((frame, CurlWsFlag.CLOSE)),
                        timeout=timeout,
                    )
                # Best effort: let the writer drain the queue.
                with suppress(WebSocketTimeout, WebSocketError):
                    await self.flush(timeout)
        finally:
            self.terminate()
            # terminate() only schedules cleanup; wait (bounded) for it.
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(self._terminated_event.wait(), timeout)
|
|
|
|
def terminate(self) -> None:
    """
    Immediately terminates the connection without a graceful handshake.

    This method is a forceful shutdown that cancels all background I/O tasks
    and cleans up resources. It should be used for final cleanup or after an
    unrecoverable error. Unlike `close()`, it does not attempt to send a close
    frame or wait for pending messages.

    If the event loop is running, the async cleanup (`_terminate_helper`) is
    scheduled on it and this method returns immediately without waiting for
    completion. If the loop is not running, the curl handle is released
    synchronously and `_terminated_event` is set before returning.

    This method is thread-safe, task-safe, and idempotent.
    """
    with self._terminate_lock:
        if self._terminated:
            return
        self._terminated = True

    # Terminate the connection in a thread-safe way
    if self._loop and self.loop.is_running():

        def _start_terminate_helper() -> None:
            # asyncio keeps only weak references to tasks, so hold a strong
            # reference on the instance; otherwise the cleanup task could be
            # garbage-collected before it finishes.
            self._terminate_task = self.loop.create_task(self._terminate_helper())

        # call_soon_threadsafe lets this be invoked from any thread.
        self._close_handle = self.loop.call_soon_threadsafe(_start_terminate_helper)

    # The event loop is not running
    else:
        super().terminate()
        if self.session and not self.session._closed:
            # WebSocket curls CANNOT be reused
            self.session.push_curl(None)
        self._terminated_event.set()
|
|
|
|
async def _read_loop(self) -> None:
    """The main asynchronous task for reading incoming WebSocket frames.

    This method is fully event-driven. It waits for the underlying socket
    to become readable, and upon being woken by the event loop, it drains
    all buffered data from libcurl until it receives an EAGAIN error. This
    error signals that the buffer is empty, and the loop returns to an
    idle state, waiting for the next readability event.

    Fragments are accumulated in ``chunks`` until a frame arrives with no
    bytes left and no ``CONT`` flag. The joined message is then put on
    ``self._receive_queue`` as a ``(message, flags)`` tuple. Errors are put
    on the same queue as ``(exception, 0)`` tuples. Whenever the loop exits,
    a ``WebSocketClosed`` item is queued; it is dropped if the queue is full.

    To ensure cooperative multitasking during high-volume message streams,
    the loop yields control to the asyncio event loop periodically. Two checks
    decide when:

    - ``msg_counter & self._yield_mask == 0``
    - more than ``self._yield_interval`` seconds have passed since the last
      yield

    With ``self._use_fair_scheduling`` either check triggers a yield;
    otherwise both must hold.

    If the receive queue becomes full, ``await self._receive_queue.put(...)``
    will block the reader loop and stall the socket read task. Thus,
    appropriate queue sizes should be set by the user, even though the
    defaults are generous and should be suitable for most use cases.
    """

    # Cache locals to avoid repeated attribute lookups
    curl_ws_recv = self._curl.ws_recv
    queue_put_nowait = self._receive_queue.put_nowait
    queue_put = self._receive_queue.put
    loop = self.loop
    fair_scheduling = self._use_fair_scheduling
    yield_mask = self._yield_mask
    yield_interval = self._yield_interval

    # Fragments of the message currently being reassembled.
    chunks: list[bytes] = []
    # Count of complete messages dispatched; feeds the yield_mask check.
    msg_counter = 0
    try:
        # The outer loop waits for readability events.
        while not self.closed:
            # Wait for the socket to be readable.
            read_future = loop.create_future()

            try:
                loop.add_reader(self._sock_fd, read_future.set_result, None)
            except Exception as exc:
                # Registration failed: report it to the consumer and tear down.
                with suppress(asyncio.QueueFull):
                    queue_put_nowait(
                        (
                            WebSocketError(
                                f"add_reader failed: {exc}",
                                CurlECode.NO_CONNECTION_AVAILABLE,
                            ),
                            0,
                        )
                    )
                self.terminate()
                return

            try:
                await read_future
            finally:
                # Remove the reader immediately after waking.
                # A value of -1 means _terminate_helper already unregistered it.
                if self._sock_fd != -1:
                    loop.remove_reader(self._sock_fd)

            # There is data, so we now read until it's empty.
            start_time = loop.time()
            while True:
                try:
                    chunk, frame = curl_ws_recv()
                    flags: int = frame.flags
                    # A successful read resets the transient-error retry budget.
                    if self._recv_error_retries > 0:
                        self._recv_error_retries = 0

                    # If a CLOSE frame is received, the reader is done.
                    # The raw close payload is also handed to the consumer.
                    if flags & CurlWsFlag.CLOSE:
                        with suppress(asyncio.QueueFull):
                            queue_put_nowait((chunk, flags))
                        await self._handle_close_frame(chunk)
                        return

                    # Collect the chunk
                    chunks.append(chunk)

                    # If the message is complete, process and dispatch it
                    if frame.bytesleft <= 0 and (flags & CurlWsFlag.CONT) == 0:
                        message = b"".join(chunks)
                        chunks.clear()
                        # Try the fast path first; only await (which blocks this
                        # loop) when the receive queue is full.
                        try:
                            queue_put_nowait((message, flags))
                        except asyncio.QueueFull:
                            await queue_put((message, flags))

                        msg_counter += 1
                        op_check: bool = (msg_counter & yield_mask) == 0
                        time_check: bool = loop.time() - start_time > yield_interval

                        # Fair scheduling yields if either check fires;
                        # otherwise both must fire.
                        if (
                            (op_check or time_check)
                            if fair_scheduling
                            else (op_check and time_check)
                        ):
                            await asyncio.sleep(0)
                            start_time = loop.time()

                except CurlError as e:
                    # Normal EAGAIN from Curl
                    if e.code == CurlECode.AGAIN:
                        break

                    # Transient error, can be retried
                    if (
                        e.code == CurlECode.RECV_ERROR
                        and self.retry_on_recv_error
                        and self._recv_error_retries < self._MAX_RECV_RETRIES
                    ):
                        self._recv_error_retries += 1
                        # Linear backoff: the delay grows with each consecutive retry.
                        await asyncio.sleep(
                            self._RECV_RETRY_DELAY * self._recv_error_retries
                        )
                        continue

                    # Unrecoverable, place error on queue and cleanup
                    with suppress(asyncio.QueueFull):
                        queue_put_nowait((e, 0))
                    self.terminate()
                    return

    except asyncio.CancelledError:
        # Cancellation (e.g. from _terminate_helper) ends the loop quietly.
        pass
    except Exception as e:
        if not self.closed:
            with suppress(asyncio.QueueFull):
                queue_put_nowait((e, 0))
            self.terminate()
    finally:
        # Best-effort sentinel so a waiting consumer learns the reader stopped.
        with suppress(asyncio.QueueFull):
            queue_put_nowait((WebSocketClosed("Connection closed."), 0))
|
|
|
|
async def _write_loop(self) -> None:
    """
    The high-level send manager. It efficiently gathers pending messages
    from the send queue and orchestrates their transmission.

    This method runs a continuous loop that consumes messages from the
    `_send_queue`. To reduce system call overhead, it uses an adaptive
    batching strategy. After one awaited `get()`, it greedily takes further
    messages with `get_nowait()`, up to `self._max_send_batch_size` items.
    It stops early at a CLOSE frame.

    Only when `coalesce_frames=True`, payloads of non-control messages that
    share the same flags (e.g., all text frames) are joined into a single,
    larger payload. CLOSE and PING frames are never merged, because the spec
    requires control frames to be sent whole. Data gathered before a control
    frame is sent before that frame. Coalesced groups are emitted in
    first-seen flag order, so relative ordering between messages with
    *different* flags in the same batch is not preserved.

    Without coalescing, each message in the batch is sent separately,
    preserving frame boundaries. The consolidated payloads are passed to
    `_send_payload` for transmission.

    Every item taken from the queue is marked with `task_done()`, even if
    sending failed, so `flush()` (which joins the queue) does not hang on it.
    The loop ends after sending a CLOSE frame, when `_send_payload` reports a
    failure, or on cancellation. If the connection is not marked closed at
    that point, it is terminated.
    """
    # Frame types that must be sent whole and never merged.
    control_frame_flags: int = CurlWsFlag.CLOSE | CurlWsFlag.PING
    send_payload = self._send_payload
    queue_get = self._send_queue.get
    queue_get_nowait = self._send_queue.get_nowait

    try:
        while True:
            payload, flags = await queue_get()

            # Build the rest of the batch without awaiting.
            # Note: `frame` below holds the message's flags, not a frame object.
            batch = [(payload, flags)]
            if not flags & CurlWsFlag.CLOSE:
                while len(batch) < self._max_send_batch_size:
                    try:
                        payload, frame = queue_get_nowait()
                        batch.append((payload, frame))
                        if frame & CurlWsFlag.CLOSE:
                            break

                    except asyncio.QueueEmpty:
                        break

            try:
                # Process the batch depending on the coalescing strategy
                if self._coalesce_frames:
                    # flags -> payloads awaiting a joined send
                    data_to_coalesce: dict[int, list[bytes]] = {}
                    for payload, frame in batch:
                        if frame & control_frame_flags:
                            # Flush any pending data before the control frame.
                            for frame_group, payloads in data_to_coalesce.items():
                                if not await send_payload(
                                    b"".join(payloads), frame_group
                                ):
                                    return
                            data_to_coalesce.clear()

                            if not await send_payload(payload, frame):
                                return
                        else:
                            data_to_coalesce.setdefault(frame, []).append(payload)

                    # Send any remaining data at the end of the batch.
                    for frame_group, payloads in data_to_coalesce.items():
                        if not await send_payload(b"".join(payloads), frame_group):
                            return
                else:
                    # Send each message in the batch, preserving frame boundaries.
                    for payload, frame in batch:
                        if not await send_payload(payload, frame):
                            return

            finally:
                # Mark all processed items as done.
                # This runs even on failure, so joiners of the queue are released.
                for _ in range(len(batch)):
                    self._send_queue.task_done()

            # Exit cleanly after sending a CLOSE frame.
            # A CLOSE frame is always the last item of its batch.
            if batch[-1][1] & CurlWsFlag.CLOSE:
                break

    except asyncio.CancelledError:
        # Cancellation (e.g. from _terminate_helper) ends the loop quietly.
        pass

    except Exception as e:
        # Surface unexpected writer errors to the consumer via the receive queue.
        if not self.closed:
            with suppress(asyncio.QueueFull):
                self._receive_queue.put_nowait((e, 0))

    finally:
        # If the loop exits unexpectedly, ensure we terminate the connection.
        if not self.closed:
            self.terminate()
|
|
|
|
async def _send_payload(self, payload: bytes, flags: CurlWsFlag) -> bool:
    """
    The low-level I/O Handler. It transmits a single payload, handling
    fragmentation, backpressure (EAGAIN), and cooperative multitasking.

    The payload is passed to `ws_send` in slices of at most
    `self._MAX_CURL_FRAME_SIZE` bytes, advancing by the byte count each call
    reports. On EAGAIN it waits for the socket to become writable and retries
    from the same offset.

    Args:
        payload: The complete byte payload to be sent.
        flags: The `CurlWsFlag` indicating the frame type (e.g., `TEXT`, `BINARY`).

    Returns:
        True once the whole payload has been handed to libcurl. False on a
        non-recoverable error, after the error has been put on the receive
        queue as `(exception, 0)`. The error is dropped if the queue is full.
    """

    # Cache locals to reduce lookup cost
    curl_ws_send = self._curl.ws_send
    queue_put_nowait = self._receive_queue.put_nowait
    loop = self.loop

    # memoryview slicing avoids copying the payload for each slice.
    view = memoryview(payload)
    offset = 0
    write_ops = 0
    start_time: float = loop.time()

    while offset < len(view):
        # Cooperatively yield to the event loop to prevent starvation.
        # Because write_ops starts at 0, this also yields once before the first
        # write of every payload (and again after an EAGAIN wait while write_ops
        # is still a multiple of the mask).
        if (write_ops & self._yield_mask) == 0 or (
            loop.time() - start_time
        ) > self._yield_interval:
            await asyncio.sleep(0)
            start_time = loop.time()

        try:
            # NOTE(review): each slice is passed to ws_send with the same `flags`;
            # whether libcurl continues one message or starts a new frame per
            # call depends on curl_ws_send semantics - confirm for payloads
            # larger than _MAX_CURL_FRAME_SIZE.
            chunk = view[offset : offset + self._MAX_CURL_FRAME_SIZE]
            n_sent = curl_ws_send(chunk, flags)
            if n_sent == 0:
                # Treat zero progress as a hard failure to avoid spinning.
                with suppress(asyncio.QueueFull):
                    queue_put_nowait(
                        (
                            WebSocketError(
                                "ws_send returned 0 bytes",
                                CurlECode.SEND_ERROR,
                            ),
                            0,
                        )
                    )
                return False
            offset += n_sent
            write_ops += 1

        except CurlError as e:
            if e.code != CurlECode.AGAIN:
                # Non-recoverable error: queue it for the user and signal failure.
                with suppress(asyncio.QueueFull):
                    queue_put_nowait((e, 0))
                return False

            # EAGAIN: wait until the socket is writable.
            write_future = loop.create_future()

            try:
                loop.add_writer(self._sock_fd, write_future.set_result, None)
            except Exception as exc:
                with suppress(asyncio.QueueFull):
                    queue_put_nowait(
                        (
                            WebSocketError(
                                f"add_writer failed: {exc}",
                                CurlECode.NO_CONNECTION_AVAILABLE,
                            ),
                            0,
                        )
                    )
                return False

            try:
                await write_future
            finally:
                # Unregister right away; -1 means the socket was already cleaned up.
                if self._sock_fd != -1:
                    loop.remove_writer(self._sock_fd)
    return True
|
|
|
|
async def flush(self, timeout: Optional[float] = None) -> None:
    """Block until every queued outgoing message has been processed.

    Returning means each message given to ``send()`` has been handed to the
    underlying socket. It says nothing about whether the remote peer has
    received it.

    Args:
        timeout (Optional[float], optional): Upper bound, in seconds, on how
            long to wait for the send queue to drain. ``None`` waits forever.

    Raises:
        WebSocketTimeout: The send queue was not fully processed within
            ``timeout`` seconds.
        WebSocketError: The writer task has already finished while unsent
            messages are still queued.
    """
    writer = self._write_task
    writer_gone = writer and writer.done()
    if writer_gone and not self._send_queue.empty():
        raise WebSocketError(
            "Cannot flush, writer task has terminated unexpectedly."
        )

    queue_drained = self._send_queue.join()
    try:
        await asyncio.wait_for(queue_drained, timeout=timeout)
    except asyncio.TimeoutError as exc:
        raise WebSocketTimeout("Timed out waiting for send queue to flush.") from exc
|
|
|
|
async def _terminate_helper(self) -> None:
    """Run the asynchronous part of connection teardown.

    The steps are:

    1. Cancel the reader and writer tasks and wait up to 3 seconds for them
       to finish.
    2. Discard anything still waiting in the send queue.
    3. Unregister the socket from the event loop.
    4. Close the curl handle and release it from the session.

    ``self._terminated_event`` is set on exit, even if a step fails.
    """
    pending: set[asyncio.Task[None]] = set()
    cancel_wait_limit: int = 3

    try:
        # Ask every still-running I/O task to stop.
        for task in (self._read_task, self._write_task):
            try:
                if task and not task.done():
                    task.cancel()
                    pending.add(task)
            except (asyncio.CancelledError, RuntimeError):
                ...

        # Give the cancelled tasks a bounded amount of time to unwind.
        if pending:
            gathered = asyncio.gather(*pending, return_exceptions=True)
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(gathered, timeout=cancel_wait_limit)

        # Throw away unsent messages, keeping the unfinished-task count balanced.
        send_queue = self._send_queue
        while not send_queue.empty():
            try:
                send_queue.get_nowait()
                send_queue.task_done()
            except (asyncio.QueueEmpty, ValueError):
                break

        # Drop any socket callbacks that are still registered.
        fd = self._sock_fd
        if fd != -1:
            with suppress(Exception):
                self.loop.remove_reader(fd)
            with suppress(Exception):
                self.loop.remove_writer(fd)
            self._sock_fd = -1

        # Close the Curl connection
        super().terminate()
        session = self.session
        if session and not session._closed:
            # WebSocket curls CANNOT be reused
            session.push_curl(None)

    finally:
        self._terminated_event.set()
|
|
|
|
async def _handle_close_frame(self, message: bytes) -> None:
    """Record the peer's close code and reason, then shut the connection down.

    The parsed code and reason are stored in ``self._close_code`` and
    ``self._close_reason``. If parsing raises ``WebSocketError``, that error's
    code is stored instead and the reason is left unchanged.

    If ``autoclose`` is enabled and the connection is not already closed, a
    CLOSE frame is sent back via ``close()``. Otherwise the connection is
    terminated without replying.

    The reply echoes the received code, except for codes that may never
    appear in a Close frame on the wire. Per RFC 6455 section 7.4.1 these
    are 1005 (no status received), 1006 (abnormal closure) and 1015 (TLS
    failure). Those codes, and a missing or zero code, are answered with
    ``WsCloseCode.OK``. ``self._close_code`` still keeps the value that was
    received.

    Args:
        message: Raw payload of the received CLOSE frame.
    """
    try:
        self._close_code, self._close_reason = self._unpack_close_frame(message)
    except WebSocketError as e:
        self._close_code = e.code

    if self.autoclose and not self.closed:
        # These codes are for local reporting only and must not be sent.
        reserved_codes = (WsCloseCode.UNKNOWN, WsCloseCode.ABNORMAL_CLOSURE, 1015)
        reply_code = self._close_code
        if not reply_code or reply_code in reserved_codes:
            reply_code = WsCloseCode.OK
        await self.close(reply_code)
    else:
        # If not sending a reply, we must still terminate the connection.
        self.terminate()
|