from __future__ import annotations
import asyncio
import struct
import threading
import warnings
from collections.abc import Callable
from contextlib import suppress
from enum import IntEnum
from functools import partial
from json import dumps, loads
from select import select
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeVar, Union
from ..aio import CURL_SOCKET_BAD, get_selector
from ..const import CurlECode, CurlInfo, CurlOpt, CurlWsFlag
from ..curl import Curl, CurlError
from ..utils import CurlCffiWarning
from .exceptions import SessionClosed, Timeout
from .models import Response
from .utils import not_set, set_curl_options
if TYPE_CHECKING:
from typing_extensions import Self
from ..const import CurlHttpVersion
from ..curl import CurlWsFrame
from .cookies import CookieTypes
from .headers import HeaderTypes
from .impersonate import BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .session import AsyncSession, ProxySpec
T = TypeVar("T")
ON_DATA_T = Callable[["WebSocket", bytes, CurlWsFrame], None]
ON_MESSAGE_T = Callable[["WebSocket", Union[bytes, str]], None]
ON_ERROR_T = Callable[["WebSocket", CurlError], None]
ON_OPEN_T = Callable[["WebSocket"], None]
ON_CLOSE_T = Callable[["WebSocket", int, str], None]
RECV_QUEUE_ITEM = tuple[Union[bytes, Exception], int]
SEND_QUEUE_ITEM = tuple[bytes, Union[CurlWsFlag, int]]
# Pre-bind compact separators with partial() instead of passing them at call time,
# because a user-supplied dumps() may not accept the ``separators`` parameter.
dumps = partial(dumps, separators=(",", ":"))
class WsCloseCode(IntEnum):
"""See: https://www.iana.org/assignments/websocket/websocket.xhtml"""
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
UNKNOWN = 1005
ABNORMAL_CLOSURE = 1006
INVALID_DATA = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
TLS_HANDSHAKE = 1015
UNAUTHORIZED = 3000
FORBIDDEN = 3003
TIMEOUT = 3008
class WebSocketError(CurlError):
"""WebSocket-specific error."""
def __init__(
self, message: str, code: Union[WsCloseCode, CurlECode, Literal[0]] = 0
):
super().__init__(message, code) # type: ignore
class WebSocketClosed(WebSocketError, SessionClosed):
"""WebSocket is already closed."""
class WebSocketTimeout(WebSocketError, Timeout):
"""WebSocket operation timed out."""
class BaseWebSocket:
__slots__ = (
"_curl",
"autoclose",
"_close_code",
"_close_reason",
"debug",
"closed",
)
def __init__(self, curl: Curl, *, autoclose: bool = True, debug: bool = False):
self._curl: Curl = curl
self.autoclose: bool = autoclose
self._close_code: Optional[int] = None
self._close_reason: Optional[str] = None
self.debug = debug
self.closed = False
@property
def curl(self):
if self._curl is not_set:
self._curl = Curl(debug=self.debug)
return self._curl
@property
def close_code(self) -> Optional[int]:
"""The WebSocket close code, if the connection has been closed."""
return self._close_code
@property
def close_reason(self) -> Optional[str]:
"""The WebSocket close reason, if the connection has been closed."""
return self._close_reason
@staticmethod
def _pack_close_frame(code: int, reason: bytes) -> bytes:
return struct.pack("!H", code) + reason
@staticmethod
def _unpack_close_frame(frame: bytes) -> tuple[int, str]:
if len(frame) < 2:
code = WsCloseCode.UNKNOWN
reason = ""
else:
try:
code = struct.unpack_from("!H", frame)[0]
reason = frame[2:].decode()
except UnicodeDecodeError as e:
raise WebSocketError(
"Invalid close message", WsCloseCode.INVALID_DATA
) from e
except Exception as e:
raise WebSocketError(
"Invalid close frame", WsCloseCode.PROTOCOL_ERROR
) from e
else:
if code == WsCloseCode.UNKNOWN or code < 1000 or code >= 5000:
raise WebSocketError(
f"Invalid close code: {code}", WsCloseCode.PROTOCOL_ERROR
)
return code, reason
def terminate(self):
"""Terminate the underlying connection."""
self.closed = True
self.curl.close()
EventTypeLiteral = Literal["open", "close", "data", "message", "error"]
class WebSocket(BaseWebSocket):
"""A WebSocket implementation using libcurl."""
__slots__ = (
"skip_utf8_validation",
"_emitters",
"keep_running",
)
def __init__(
self,
curl: Union[Curl, Any] = not_set,
*,
autoclose: bool = True,
skip_utf8_validation: bool = False,
debug: bool = False,
on_open: Optional[ON_OPEN_T] = None,
on_close: Optional[ON_CLOSE_T] = None,
on_data: Optional[ON_DATA_T] = None,
on_message: Optional[ON_MESSAGE_T] = None,
on_error: Optional[ON_ERROR_T] = None,
):
"""
Args:
autoclose: whether to close the WebSocket after receiving a close frame.
skip_utf8_validation: whether to skip UTF-8 validation for text frames in
run_forever().
debug: print extra curl debug info.
on_open: open callback, ``def on_open(ws)``
on_close: close callback, ``def on_close(ws, code, reason)``
on_data: raw data receive callback, ``def on_data(ws, data, frame)``
on_message: message receive callback, ``def on_message(ws, message)``
on_error: error callback, ``def on_error(ws, exception)``
"""
super().__init__(curl=curl, autoclose=autoclose, debug=debug)
self.skip_utf8_validation = skip_utf8_validation
self.keep_running = False
self._emitters: dict[EventTypeLiteral, Callable] = {}
if on_open:
self._emitters["open"] = on_open
if on_close:
self._emitters["close"] = on_close
if on_data:
self._emitters["data"] = on_data
if on_message:
self._emitters["message"] = on_message
if on_error:
self._emitters["error"] = on_error
def __iter__(self) -> WebSocket:
if self.closed:
raise WebSocketClosed("WebSocket is closed")
return self
def __next__(self) -> bytes:
msg, flags = self.recv()
if flags & CurlWsFlag.CLOSE:
raise StopIteration
return msg
def _emit(self, event_type: EventTypeLiteral, *args) -> None:
callback = self._emitters.get(event_type)
if callback:
try:
callback(self, *args)
except Exception as e:
error_callback = self._emitters.get("error")
if error_callback:
error_callback(self, e)
else:
warnings.warn(
f"WebSocket callback '{event_type}' failed",
CurlCffiWarning,
stacklevel=2,
)
def connect(
self,
url: str,
params: Optional[Union[dict, list, tuple]] = None,
headers: Optional[HeaderTypes] = None,
cookies: Optional[CookieTypes] = None,
auth: Optional[tuple[str, str]] = None,
timeout: Optional[Union[float, tuple[float, float], object]] = not_set,
allow_redirects: bool = True,
max_redirects: int = 30,
proxies: Optional[ProxySpec] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[tuple[str, str]] = None,
verify: Optional[bool] = None,
referer: Optional[str] = None,
accept_encoding: Optional[str] = "gzip, deflate, br",
impersonate: Optional[BrowserTypeLiteral] = None,
ja3: Optional[str] = None,
akamai: Optional[str] = None,
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
default_headers: bool = True,
quote: Union[str, Literal[False]] = "",
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, tuple[str, str]]] = None,
max_recv_speed: int = 0,
curl_options: Optional[dict[CurlOpt, str]] = None,
):
"""Connect to the WebSocket.
libcurl automatically handles pings and pongs.
ref: https://curl.se/libcurl/c/libcurl-ws.html
Args:
url: url for the requests.
params: query string for the requests.
headers: headers to send.
cookies: cookies to use.
auth: HTTP basic auth, a tuple of (username, password), only basic auth is
supported.
timeout: how many seconds to wait before giving up.
allow_redirects: whether to allow redirection.
max_redirects: max redirect counts, default 30, use -1 for unlimited.
proxies: dict of proxies to use, prefer to use ``proxy`` if they are the
same. format: ``{"http": proxy_url, "https": proxy_url}``.
            proxy: proxy to use, format: "http://user:pass@proxy_url".
Can't be used with `proxies` parameter.
proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
verify: whether to verify https certs.
referer: shortcut for setting referer header.
accept_encoding: shortcut for setting accept-encoding header.
impersonate: which browser version to impersonate.
ja3: ja3 string to impersonate.
akamai: akamai string to impersonate.
extra_fp: extra fingerprints options, in complement to ja3 and akamai str.
default_headers: whether to set default browser headers.
            quote: Set characters to be quoted, i.e. percent-encoded. The default safe
                string is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, those characters
                will be removed from the safe string, and thus quoted. If set to False,
                the url will be kept as is, without any automatic percent-encoding; you
                must encode the URL yourself.
http_version: limiting http version, defaults to http2.
interface: which interface to use.
cert: a tuple of (cert, key) filenames for client cert.
max_recv_speed: maximum receive speed, bytes per second.
curl_options: extra curl options to use.
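        Example (an illustrative sketch; the URL and header values are placeholders)::

            ws = WebSocket()
            ws.connect(
                "wss://echo.example.org",
                headers={"Origin": "https://example.org"},
                impersonate="chrome",
                timeout=10,
            )
            ws.send_str("ping")
            data, flags = ws.recv()
            ws.close()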
"""
curl = self.curl
set_curl_options(
curl=curl,
method="GET",
url=url,
params_list=[None, params],
headers_list=[None, headers],
cookies_list=[None, cookies],
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
max_redirects=max_redirects,
proxies_list=[None, proxies],
proxy=proxy,
proxy_auth=proxy_auth,
verify_list=[None, verify],
referer=referer,
accept_encoding=accept_encoding,
impersonate=impersonate,
ja3=ja3,
akamai=akamai,
extra_fp=extra_fp,
default_headers=default_headers,
quote=quote,
http_version=http_version,
interface=interface,
max_recv_speed=max_recv_speed,
cert=cert,
curl_options=curl_options,
)
# Magic number defined in: https://curl.se/docs/websocket.html
curl.setopt(CurlOpt.CONNECT_ONLY, 2)
curl.perform()
return self
def recv_fragment(self) -> tuple[bytes, CurlWsFrame]:
"""Receive a single curl websocket fragment as bytes."""
if self.closed:
raise WebSocketClosed("WebSocket is already closed")
chunk, frame = self.curl.ws_recv()
if frame.flags & CurlWsFlag.CLOSE:
try:
self._close_code, self._close_reason = self._unpack_close_frame(chunk)
except WebSocketError as e:
# Follow the spec to close the connection
# Errors do not respect autoclose
self._close_code = e.code
self.close(e.code)
raise
if self.autoclose:
self.close()
return chunk, frame
def recv(self) -> tuple[bytes, int]:
"""
Receive a frame as bytes. libcurl splits frames into fragments, so we have to
collect all the chunks for a frame.
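        Example (a minimal sketch; assumes ``connect()`` has already been called)::

            data, flags = ws.recv()
            if flags & CurlWsFlag.TEXT:
                print(data.decode())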
"""
chunks = []
flags = 0
sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
if sock_fd == CURL_SOCKET_BAD:
raise WebSocketError(
"Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
)
while True:
try:
# Try to receive the first fragment first
chunk, frame = self.recv_fragment()
flags = frame.flags
chunks.append(chunk)
if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
break
except CurlError as e:
if e.code == CurlECode.AGAIN:
# According to https://curl.se/libcurl/c/curl_ws_recv.html
# > in real application: wait for socket here, e.g. using select()
_, _, _ = select([sock_fd], [], [], 0.5)
else:
raise
return b"".join(chunks), flags
def recv_str(self) -> str:
"""Receive a text frame."""
data, flags = self.recv()
if not (flags & CurlWsFlag.TEXT):
raise WebSocketError("Not valid text frame", WsCloseCode.INVALID_DATA)
return data.decode()
def recv_json(self, *, loads: Callable[[str], T] = loads) -> T:
"""Receive a JSON frame.
Args:
loads: JSON decoder, default is json.loads.
"""
data = self.recv_str()
return loads(data)
def send(
self,
payload: Union[str, bytes, memoryview],
flags: CurlWsFlag = CurlWsFlag.BINARY,
):
"""Send a data frame.
Args:
payload: data to send.
flags: flags for the frame.
"""
if flags & CurlWsFlag.CLOSE:
self.keep_running = False
if self.closed:
raise WebSocketClosed("WebSocket is already closed")
# curl expects bytes
if isinstance(payload, str):
payload = payload.encode()
sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
if sock_fd == CURL_SOCKET_BAD:
raise WebSocketError(
"Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
)
# Loop checks for CurlECode.Again
# https://curl.se/libcurl/c/curl_ws_send.html
offset = 0
while offset < len(payload):
current_buffer = payload[offset:]
try:
n_sent = self.curl.ws_send(current_buffer, flags)
except CurlError as e:
if e.code == CurlECode.AGAIN:
_, writeable, _ = select([], [sock_fd], [], 0.5)
if not writeable:
raise WebSocketError("Socket write timeout") from e
continue
raise
offset += n_sent
return offset
def send_binary(self, payload: bytes):
"""Send a binary frame.
Args:
payload: binary data to send.
"""
return self.send(payload, CurlWsFlag.BINARY)
def send_bytes(self, payload: bytes):
"""Send a binary frame, alias of :meth:`send_binary`.
Args:
payload: binary data to send.
"""
return self.send(payload, CurlWsFlag.BINARY)
def send_str(self, payload: str):
"""Send a text frame.
Args:
payload: text data to send.
"""
return self.send(payload, CurlWsFlag.TEXT)
def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps):
"""Send a JSON frame.
Args:
payload: data to send.
dumps: JSON encoder, default is json.dumps.
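        Example (a minimal sketch with a custom, user-supplied encoder)::

            import json
            from functools import partial

            ws.send_json({"op": "subscribe"}, dumps=partial(json.dumps, sort_keys=True))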
"""
return self.send_str(dumps(payload))
def ping(self, payload: Union[str, bytes]):
"""Send a ping frame.
Args:
payload: data to send.
"""
return self.send(payload, CurlWsFlag.PING)
def run_forever(self, url: str = "", **kwargs):
"""Run the WebSocket forever. See :meth:`connect` for details on parameters.
libcurl automatically handles pings and pongs.
ref: https://curl.se/libcurl/c/libcurl-ws.html
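        Example (a minimal sketch; the URL and callbacks are placeholders)::

            def on_message(ws, message):
                print("received:", message)

            def on_error(ws, error):
                print("error:", error)

            ws = WebSocket(on_message=on_message, on_error=on_error)
            ws.run_forever("wss://echo.example.org")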
"""
if url:
self.connect(url, **kwargs)
sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
if sock_fd == CURL_SOCKET_BAD:
raise WebSocketError(
"Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
)
self._emit("open")
# Keep reading the messages and invoke callbacks
# TODO: Reconnect logic
chunks = []
self.keep_running = True
while self.keep_running:
try:
chunk, frame = self.recv_fragment()
flags = frame.flags
self._emit("data", chunk, frame)
chunks.append(chunk)
if not (frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0):
continue
# Avoid unnecessary computation
if "message" in self._emitters:
# Concatenate collected chunks with the final message
msg = b"".join(chunks)
if (flags & CurlWsFlag.TEXT) and not self.skip_utf8_validation:
try:
msg = msg.decode() # type: ignore
except UnicodeDecodeError as e:
self._close_code = WsCloseCode.INVALID_DATA
self.close(WsCloseCode.INVALID_DATA)
raise WebSocketError(
"Invalid UTF-8", WsCloseCode.INVALID_DATA
) from e
if (flags & CurlWsFlag.BINARY) or (flags & CurlWsFlag.TEXT):
self._emit("message", msg)
chunks = [] # Reset chunks for next message
if flags & CurlWsFlag.CLOSE:
self.keep_running = False
self._emit("close", self._close_code or 0, self._close_reason or "")
except CurlError as e:
if e.code == CurlECode.AGAIN:
_, _, _ = select([sock_fd], [], [], 0.5)
else:
self._emit("error", e)
if not self.closed:
code = WsCloseCode.UNKNOWN
if isinstance(e, WebSocketError):
code = e.code
self.close(code)
raise
def close(self, code: int = WsCloseCode.OK, message: bytes = b""):
"""Close the connection.
Args:
code: close code.
message: close reason.
"""
        # Check the raw attribute; the ``curl`` property would lazily create a new handle.
        if self._curl is not_set:
            return
# TODO: As per spec, we should wait for the server to close the connection
# But this is not a requirement
msg = self._pack_close_frame(code, message)
self.send(msg, CurlWsFlag.CLOSE)
# The only way to close the connection appears to be curl_easy_cleanup
self.terminate()
class AsyncWebSocket(BaseWebSocket):
"""
An asyncio WebSocket implementation using libcurl.
NOTE: This object represents a single WebSocket connection. Once closed,
it cannot be reopened. A new instance must be created to reconnect.
"""
__slots__ = (
"session",
"_loop",
"_sock_fd",
"_close_lock",
"_terminate_lock",
"_read_task",
"_write_task",
"_close_handle",
"_receive_queue",
"_send_queue",
"_max_send_batch_size",
"_coalesce_frames",
"retry_on_recv_error",
"_yield_interval",
"_use_fair_scheduling",
"_yield_mask",
"_recv_error_retries",
"_terminated",
"_terminated_event",
)
# Match libcurl's documented max frame size limit.
_MAX_CURL_FRAME_SIZE: ClassVar[int] = 65535
_MAX_RECV_RETRIES: ClassVar[int] = 3
_RECV_RETRY_DELAY: ClassVar[float] = 0.3
def __init__(
self,
session: AsyncSession,
curl: Curl,
*,
autoclose: bool = True,
debug: bool = False,
recv_queue_size: int = 512,
send_queue_size: int = 256,
max_send_batch_size: int = 256,
coalesce_frames: bool = False,
retry_on_recv_error: bool = False,
yield_interval: float = 0.001,
fair_scheduling: bool = False,
yield_mask: int = 63,
) -> None:
"""Initializes an Async WebSocket session.
This class should not be instantiated directly. It is intended to be created
via the `AsyncSession.ws_connect()` method, which correctly handles setup and
initialization of the underlying I/O tasks.
Important:
This WebSocket implementation uses a decoupled I/O model. Network
operations occur in background tasks. As a result, network-related
errors that occur during a `send()` operation will not be raised by
`send()`. Instead, they are placed into the receive queue and will be
raised by the next call to `recv()`.
Args:
session (AsyncSession): An instantiated AsyncSession object.
curl (Curl): The underlying Curl to use.
autoclose (bool, optional): Close the WS on receiving a close frame.
debug (bool, optional): Enable debug messages. Defaults to False.
            recv_queue_size (int, optional): The maximum number of incoming WebSocket
                messages to buffer internally. This queue stores messages received on
                the Curl socket that are waiting to be consumed by calling `recv()`.
send_queue_size (int, optional): The maximum number of outgoing WebSocket
messages to buffer before applying network backpressure. When you call
`send(...)` the message is placed in this queue and transmitted when
the Curl socket is next available for sending.
max_send_batch_size (int, optional): The max number of messages per batch.
coalesce_frames (bool, optional): Combine multiple frames into a batch.
retry_on_recv_error (bool, optional): Retry recv on some transient errors.
yield_interval (float, optional): How often to yield control in seconds.
            fair_scheduling (bool, optional): Change the default scheduling ratio of
                roughly 5:1 in favor of `recv` over `send` to a fairer 1:1 ratio. This
                decreases recv throughput.
yield_mask (int, optional): A bitmask that sets the yield frequency for
cooperative multitasking, checked every `yield_mask + 1` operations.
Must be a power of two minus one (e.g., `63`, `127`, `255`) for
efficient bitwise checks. Lower values increase fairness; higher values
increase throughput.
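        Example (a minimal sketch; assumes ``AsyncSession.ws_connect()`` forwards these
        keyword arguments, and the URL is a placeholder)::

            async with AsyncSession() as session:
                ws = await session.ws_connect(
                    "wss://echo.example.org",
                    recv_queue_size=1024,
                )
                await ws.send_str("hello")
                data, flags = await ws.recv()
                await ws.close()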
"""
super().__init__(curl=curl, autoclose=autoclose, debug=debug)
self.session: AsyncSession[Response] = session
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._sock_fd: int = -1
self._close_lock: asyncio.Lock = asyncio.Lock()
self._terminate_lock: threading.Lock = threading.Lock()
self._terminated_event: asyncio.Event = asyncio.Event()
self._read_task: Optional[asyncio.Task[None]] = None
self._write_task: Optional[asyncio.Task[None]] = None
self._close_handle: Optional[asyncio.Handle] = None
self._receive_queue: asyncio.Queue[RECV_QUEUE_ITEM] = asyncio.Queue(
maxsize=recv_queue_size
)
self._send_queue: asyncio.Queue[SEND_QUEUE_ITEM] = asyncio.Queue(
maxsize=send_queue_size
)
self._max_send_batch_size: int = max_send_batch_size
self._coalesce_frames: bool = coalesce_frames
self.retry_on_recv_error: bool = retry_on_recv_error
self._yield_interval: float = yield_interval
self._use_fair_scheduling: bool = fair_scheduling
self._yield_mask: int = yield_mask
self._recv_error_retries: int = 0
self._terminated: bool = False
@property
def loop(self) -> asyncio.AbstractEventLoop:
"""Get a reference to the running event loop"""
if self._loop is None:
self._loop = get_selector(asyncio.get_running_loop())
return self._loop
@property
def send_queue_size(self) -> int:
"""Returns the current number of items in the send queue."""
return self._send_queue.qsize()
def is_alive(self) -> bool:
"""
Checks if the background I/O tasks are still running.
Returns `False` if either the read or write task has terminated due
to an error or a clean shutdown.
Note: This is a snapshot in time. A return value of `True` does not
guarantee the next network operation will succeed, but `False`
definitively indicates the connection is no longer active.
"""
if self.closed or self._terminated:
return False
if self._read_task and self._read_task.done():
return False
return not (self._write_task and self._write_task.done())
def __del__(self) -> None:
"""Warn if the user forgets to close the connection."""
if getattr(self, "closed", True):
return
with suppress(Exception):
warnings.warn(
f"Unclosed WebSocket {self!r} was garbage collected. "
"Always call await ws.close() to ensure clean shutdown.",
ResourceWarning,
stacklevel=2,
)
def __aiter__(self) -> Self:
if self.closed:
raise WebSocketClosed("WebSocket has been closed")
return self
async def __anext__(self) -> bytes:
msg, flags = await self.recv()
if (msg is None) or (flags & CurlWsFlag.CLOSE):
raise StopAsyncIteration
return msg
def _start_io_tasks(self) -> None:
"""Start the read/write I/O loop tasks.
This should be called only once after object creation by the factory.
Once started, the tasks cannot be restarted again, this is a one-shot.
Raises:
WebSocketError: The WebSocket FD was invalid.
"""
# Return early if already started
if self._read_task is not None:
return
# Get the currently active socket FD
self._sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
if self._sock_fd == CURL_SOCKET_BAD:
raise WebSocketError(
"Invalid active socket.", code=CurlECode.NO_CONNECTION_AVAILABLE
)
# Get an identifier for the websocket from its object id
ws_id: str = f"WebSocket-{id(self):#x}"
# Start the I/O loop tasks
self._read_task = self.loop.create_task(self._read_loop(), name=f"{ws_id}-read")
self._write_task = self.loop.create_task(
self._write_loop(), name=f"{ws_id}-write"
)
async def recv(
self, *, timeout: Optional[float] = None
) -> tuple[Optional[bytes], int]:
"""Receive a frame as bytes.
This method waits for and returns the next complete data frame from the
receive queue.
Args:
timeout: how many seconds to wait before giving up.
Raises:
WebSocketClosed: If `recv()` is called on a closed connection after
the receive queue is empty.
WebSocketTimeout: If the operation times out.
WebSocketError: A protocol or network error that occurred in a
background I/O task, including errors from previous `send()`
operations.
Returns:
tuple[bytes, int]: A tuple with the received payload and flags.
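        Example (a minimal sketch of a receive loop with a timeout)::

            while True:
                try:
                    data, flags = await ws.recv(timeout=30)
                except WebSocketTimeout:
                    await ws.ping(b"keepalive")
                    continue
                if flags & CurlWsFlag.CLOSE:
                    break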
"""
if self.closed and self._receive_queue.empty():
raise WebSocketClosed("WebSocket is closed")
try:
result, flags = await asyncio.wait_for(self._receive_queue.get(), timeout)
if isinstance(result, Exception):
raise result
return result, flags
except asyncio.TimeoutError as e:
raise WebSocketTimeout(
"WebSocket recv() timed out", CurlECode.OPERATION_TIMEDOUT
) from e
async def recv_str(self, *, timeout: Optional[float] = None) -> str:
"""Receive a text frame.
Args:
timeout: how many seconds to wait before giving up.
"""
data, flags = await self.recv(timeout=timeout)
if data is None or not (flags & CurlWsFlag.TEXT):
raise WebSocketError("Not a valid text frame", WsCloseCode.INVALID_DATA)
try:
return data.decode("utf-8")
except UnicodeDecodeError as e:
raise WebSocketError(
"Invalid UTF-8 in text frame", WsCloseCode.INVALID_DATA
) from e
async def recv_json(
self,
*,
loads: Callable[[Union[str, bytes]], T] = loads,
timeout: Optional[float] = None,
) -> T:
"""Receive a JSON frame.
Args:
loads: JSON decoder, default is json.loads.
timeout: how many seconds to wait before giving up.
"""
data, flags = await self.recv(timeout=timeout)
if data is None:
raise WebSocketError(
"Received empty frame, cannot decode JSON", WsCloseCode.INVALID_DATA
)
if flags & CurlWsFlag.TEXT:
try:
return loads(data.decode("utf-8"))
except UnicodeDecodeError as e:
raise WebSocketError(
"Invalid UTF-8 in JSON text frame", WsCloseCode.INVALID_DATA
) from e
return loads(data)
async def send(
self,
payload: Union[str, bytes, bytearray, memoryview],
flags: CurlWsFlag = CurlWsFlag.BINARY,
):
"""Send a data frame.
This method is a lightweight, non-blocking call that places the payload
into a send queue. The actual network transmission is handled by a
background task.
        To guarantee that all your messages have been sent, call ``await ws.flush(...)``.
The max frame size supported by libcurl is `65535` bytes. Larger frames
will be broken down and sent in chunks of that size.
Args:
payload: data to send.
flags: flags for the frame.
Raises:
WebSocketClosed: The WebSocket is closed.
NOTE:
Due to the asynchronous nature of this client, network errors
(e.g., connection dropped) that occur during the actual transmission
will NOT be raised by this method. They will be raised by a
subsequent call to `recv()`. Always ensure you are actively
receiving data to handle potential connection errors.
Also: If the network is slow and the internal send queue becomes full,
this method will block until there is space in the queue.
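        Example (a minimal sketch; ``send_json()`` only enqueues, ``flush()`` waits for
        the writer task to hand everything to the socket)::

            for i in range(100):
                await ws.send_json({"seq": i})
            await ws.flush(timeout=10)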
"""
if self.closed:
raise WebSocketClosed("WebSocket is closed")
# cURL expects bytes
if isinstance(payload, str):
payload = payload.encode("utf-8")
elif isinstance(payload, (bytearray, memoryview)):
payload = bytes(payload)
try:
self._send_queue.put_nowait((payload, flags))
except asyncio.QueueFull:
await self._send_queue.put((payload, flags))
async def send_binary(self, payload: bytes) -> None:
"""Send a binary frame.
Args:
payload: binary data to send.
For more info, see the docstring for `send(...)`
"""
return await self.send(payload, CurlWsFlag.BINARY)
async def send_bytes(self, payload: bytes) -> None:
"""Send a binary frame, alias of :meth:`send_binary`.
Args:
payload: binary data to send.
For more info, see the docstring for `send(...)`
"""
return await self.send(payload, CurlWsFlag.BINARY)
async def send_str(self, payload: str) -> None:
"""Send a text frame.
Args:
payload: text data to send.
For more info, see the docstring for `send(...)`
"""
return await self.send(payload, CurlWsFlag.TEXT)
async def send_json(
self, payload: Any, *, dumps: Callable[[Any], str] = dumps
) -> None:
"""Send a JSON frame.
Args:
payload: data to send.
dumps: JSON encoder, default is `json.dumps(...)`.
For more info, see the docstring for `send(...)`
"""
return await self.send_str(dumps(payload))
async def ping(self, payload: Union[str, bytes]):
"""Send a ping frame.
Args:
payload: data to send.
For more info, see the docstring for `send(...)`
"""
return await self.send(payload, CurlWsFlag.PING)
async def close(
self, code: int = WsCloseCode.OK, message: bytes = b"", timeout: float = 5.0
) -> None:
"""
Performs a graceful WebSocket closing handshake and terminates the connection.
This method sends a WebSocket close frame to the peer, waits for queued
outgoing messages to be sent, and then shuts down the connection. This is
the recommended way to close the session.
Args:
code (int, optional): Close code. Defaults to `WsCloseCode.OK`.
message (bytes, optional): Close reason. Defaults to `b""`.
            timeout (float, optional): How long, in seconds, to wait for the close
                handshake and final cleanup to complete.
"""
async with self._close_lock:
if self.closed:
return
self.closed = True
try:
if self._write_task and not self._write_task.done():
close_frame = self._pack_close_frame(code, message)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
self._send_queue.put((close_frame, CurlWsFlag.CLOSE)),
timeout=timeout,
)
with suppress(WebSocketTimeout, WebSocketError):
await self.flush(timeout)
finally:
self.terminate()
# Wait for the termination completion signal
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(self._terminated_event.wait(), timeout)
def terminate(self) -> None:
"""
Immediately terminates the connection without a graceful handshake.
This method is a forceful shutdown that cancels all background I/O tasks
and cleans up resources. It should be used for final cleanup or after an
unrecoverable error. Unlike `close()`, it does not attempt to send a close
frame or wait for pending messages. It schedules the cleanup to run on the
event loop and returns immediately. It does not wait for cleanup completion.
This method is thread-safe, task-safe, and idempotent.
"""
with self._terminate_lock:
if self._terminated:
return
self._terminated = True
# Terminate the connection in a thread-safe way
if self._loop and self.loop.is_running():
self._close_handle = self.loop.call_soon_threadsafe(
lambda: self.loop.create_task(self._terminate_helper())
)
# The event loop is not running
else:
super().terminate()
if self.session and not self.session._closed:
# WebSocket curls CANNOT be reused
self.session.push_curl(None)
self._terminated_event.set()
async def _read_loop(self) -> None:
"""The main asynchronous task for reading incoming WebSocket frames.
This method is fully event-driven. It waits for the underlying socket
to become readable, and upon being woken by the event loop, it drains
all buffered data from libcurl until it receives an EAGAIN error. This
error signals that the buffer is empty, and the loop returns to an
idle state, waiting for the next readability event.
To ensure cooperative multitasking during high-volume message streams,
the loop yields control to the asyncio event loop periodically.
If the receive queue becomes full, await `self._receive_queue.put(...)`
will block the reader loop and stall the socket read task. Thus, appropriate
queue sizes should be set by the user, even though the defaults are generous
and should be suitable for most use cases.
"""
# Cache locals to avoid repeated attribute lookups
curl_ws_recv = self._curl.ws_recv
queue_put_nowait = self._receive_queue.put_nowait
queue_put = self._receive_queue.put
loop = self.loop
fair_scheduling = self._use_fair_scheduling
yield_mask = self._yield_mask
yield_interval = self._yield_interval
chunks: list[bytes] = []
msg_counter = 0
try:
# The outer loop waits for readability events.
while not self.closed:
# Wait for the socket to be readable.
read_future = loop.create_future()
try:
loop.add_reader(self._sock_fd, read_future.set_result, None)
except Exception as exc:
with suppress(asyncio.QueueFull):
queue_put_nowait(
(
WebSocketError(
f"add_reader failed: {exc}",
CurlECode.NO_CONNECTION_AVAILABLE,
),
0,
)
)
self.terminate()
return
try:
await read_future
finally:
# Remove the reader immediately after waking.
if self._sock_fd != -1:
loop.remove_reader(self._sock_fd)
# There is data, so we now read until it's empty.
start_time = loop.time()
while True:
try:
chunk, frame = curl_ws_recv()
flags: int = frame.flags
if self._recv_error_retries > 0:
self._recv_error_retries = 0
# If a CLOSE frame is received, the reader is done.
if flags & CurlWsFlag.CLOSE:
with suppress(asyncio.QueueFull):
queue_put_nowait((chunk, flags))
await self._handle_close_frame(chunk)
return
# Collect the chunk
chunks.append(chunk)
# If the message is complete, process and dispatch it
if frame.bytesleft <= 0 and (flags & CurlWsFlag.CONT) == 0:
message = b"".join(chunks)
chunks.clear()
try:
queue_put_nowait((message, flags))
except asyncio.QueueFull:
await queue_put((message, flags))
msg_counter += 1
op_check: bool = (msg_counter & yield_mask) == 0
time_check: bool = loop.time() - start_time > yield_interval
if (
(op_check or time_check)
if fair_scheduling
else (op_check and time_check)
):
await asyncio.sleep(0)
start_time = loop.time()
except CurlError as e:
# Normal EAGAIN from Curl
if e.code == CurlECode.AGAIN:
break
# Transient error, can be retried
if (
e.code == CurlECode.RECV_ERROR
and self.retry_on_recv_error
and self._recv_error_retries < self._MAX_RECV_RETRIES
):
self._recv_error_retries += 1
await asyncio.sleep(
self._RECV_RETRY_DELAY * self._recv_error_retries
)
continue
# Unrecoverable, place error on queue and cleanup
with suppress(asyncio.QueueFull):
queue_put_nowait((e, 0))
self.terminate()
return
except asyncio.CancelledError:
pass
except Exception as e:
if not self.closed:
with suppress(asyncio.QueueFull):
queue_put_nowait((e, 0))
self.terminate()
finally:
with suppress(asyncio.QueueFull):
queue_put_nowait((WebSocketClosed("Connection closed."), 0))
async def _write_loop(self) -> None:
"""
The high-level send manager. It efficiently gathers pending messages
from the send queue and orchestrates their transmission.
This method runs a continuous loop that consumes messages from the
`_send_queue`. To improve performance and reduce system call overhead,
it implements an adaptive batching strategy. It greedily gathers
multiple pending messages from the queue and then coalesces the
payloads of messages that share the same flags (e.g., all text frames)
into a single, larger payload, ONLY if `coalesce_frames=True` and the
frame is not a CONTROL frame, as the spec requires them to be whole.
It will batch as many as possible, then iterate over the batch and send
the frames, one at a time. This batching and coalescing significantly
improves throughput for high volumes of small messages where the message
boundaries do not matter. The final, consolidated payloads are then passed
to the `_send_payload` method for transmission.
"""
control_frame_flags: int = CurlWsFlag.CLOSE | CurlWsFlag.PING
send_payload = self._send_payload
queue_get = self._send_queue.get
queue_get_nowait = self._send_queue.get_nowait
try:
while True:
payload, flags = await queue_get()
# Build the rest of the batch without awaiting.
batch = [(payload, flags)]
if not flags & CurlWsFlag.CLOSE:
while len(batch) < self._max_send_batch_size:
try:
payload, frame = queue_get_nowait()
batch.append((payload, frame))
if frame & CurlWsFlag.CLOSE:
break
except asyncio.QueueEmpty:
break
try:
# Process the batch depending on the coalescing strategy
if self._coalesce_frames:
data_to_coalesce: dict[int, list[bytes]] = {}
for payload, frame in batch:
if frame & control_frame_flags:
# Flush any pending data before the control frame.
for frame_group, payloads in data_to_coalesce.items():
if not await send_payload(
b"".join(payloads), frame_group
):
return
data_to_coalesce.clear()
if not await send_payload(payload, frame):
return
else:
data_to_coalesce.setdefault(frame, []).append(payload)
# Send any remaining data at the end of the batch.
for frame_group, payloads in data_to_coalesce.items():
if not await send_payload(b"".join(payloads), frame_group):
return
else:
# Send each message in the batch, preserving frame boundaries.
for payload, frame in batch:
if not await send_payload(payload, frame):
return
finally:
# Mark all processed items as done.
for _ in range(len(batch)):
self._send_queue.task_done()
# Exit cleanly after sending a CLOSE frame.
if batch[-1][1] & CurlWsFlag.CLOSE:
break
except asyncio.CancelledError:
pass
except Exception as e:
if not self.closed:
with suppress(asyncio.QueueFull):
self._receive_queue.put_nowait((e, 0))
finally:
# If the loop exits unexpectedly, ensure we terminate the connection.
if not self.closed:
self.terminate()
async def _send_payload(self, payload: bytes, flags: CurlWsFlag) -> bool:
"""
The low-level I/O Handler. It transmits a single payload, handling
fragmentation, backpressure (EAGAIN), and cooperative multitasking.
Returns False on a non-recoverable error.
Args:
payload: The complete byte payload to be sent.
flags: The `CurlWsFlag` indicating the frame type (e.g., `TEXT`, `BINARY`).
"""
# Cache locals to reduce lookup cost
curl_ws_send = self._curl.ws_send
queue_put_nowait = self._receive_queue.put_nowait
loop = self.loop
view = memoryview(payload)
offset = 0
write_ops = 0
start_time: float = loop.time()
while offset < len(view):
# Cooperatively yield to the event loop to prevent starvation.
if (write_ops & self._yield_mask) == 0 or (
loop.time() - start_time
) > self._yield_interval:
await asyncio.sleep(0)
start_time = loop.time()
try:
chunk = view[offset : offset + self._MAX_CURL_FRAME_SIZE]
n_sent = curl_ws_send(chunk, flags)
if n_sent == 0:
with suppress(asyncio.QueueFull):
queue_put_nowait(
(
WebSocketError(
"ws_send returned 0 bytes",
CurlECode.SEND_ERROR,
),
0,
)
)
return False
offset += n_sent
write_ops += 1
except CurlError as e:
if e.code != CurlECode.AGAIN:
# Non-recoverable error: queue it for the user and signal failure.
with suppress(asyncio.QueueFull):
queue_put_nowait((e, 0))
return False
# EAGAIN: wait until the socket is writable.
write_future = loop.create_future()
try:
loop.add_writer(self._sock_fd, write_future.set_result, None)
except Exception as exc:
with suppress(asyncio.QueueFull):
queue_put_nowait(
(
WebSocketError(
f"add_writer failed: {exc}",
CurlECode.NO_CONNECTION_AVAILABLE,
),
0,
)
)
return False
try:
await write_future
finally:
if self._sock_fd != -1:
loop.remove_writer(self._sock_fd)
return True
async def flush(self, timeout: Optional[float] = None) -> None:
"""Waits until all items in the send queue have been processed.
This ensures that all messages passed to `send()` have been handed off to the
underlying socket for transmission. It does not guarantee that the data has
been received by the remote peer.
Args:
timeout (Optional[float], optional): The maximum number of seconds to wait
for the queue to drain.
Raises:
WebSocketTimeout: If the send queue is not fully processed within the
specified ``timeout`` period.
WebSocketError: If the writer task has already terminated while unsent
messages remain in the queue.
"""
if (
self._write_task
and self._write_task.done()
and not self._send_queue.empty()
):
raise WebSocketError(
"Cannot flush, writer task has terminated unexpectedly."
)
try:
await asyncio.wait_for(self._send_queue.join(), timeout=timeout)
except asyncio.TimeoutError as e:
raise WebSocketTimeout("Timed out waiting for send queue to flush.") from e
async def _terminate_helper(self) -> None:
"""Utility method for connection termination"""
tasks_to_cancel: set[asyncio.Task[None]] = set()
max_timeout: int = 3
try:
# Cancel all the I/O tasks
for io_task in (self._read_task, self._write_task):
try:
if io_task and not io_task.done():
io_task.cancel()
tasks_to_cancel.add(io_task)
except (asyncio.CancelledError, RuntimeError):
...
# Wait for cancellation but don't get stuck
if tasks_to_cancel:
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
asyncio.gather(*tasks_to_cancel, return_exceptions=True),
timeout=max_timeout,
)
# Drain the send_queue
while not self._send_queue.empty():
try:
self._send_queue.get_nowait()
self._send_queue.task_done()
except (asyncio.QueueEmpty, ValueError):
break
# Remove the reader/writer if still registered
if self._sock_fd != -1:
with suppress(Exception):
self.loop.remove_reader(self._sock_fd)
with suppress(Exception):
self.loop.remove_writer(self._sock_fd)
self._sock_fd = -1
# Close the Curl connection
super().terminate()
if self.session and not self.session._closed:
# WebSocket curls CANNOT be reused
self.session.push_curl(None)
finally:
self._terminated_event.set()
async def _handle_close_frame(self, message: bytes) -> None:
"""Unpack and handle the closing frame, then initiate shutdown."""
try:
self._close_code, self._close_reason = self._unpack_close_frame(message)
except WebSocketError as e:
self._close_code = e.code
if self.autoclose and not self.closed:
await self.close(self._close_code or WsCloseCode.OK)
else:
# If not sending a reply, we must still terminate the connection.
self.terminate()