import asyncio
from collections import deque
from typing import Callable, Deque, Dict, List, Optional, Union
from urllib.parse import urlparse
import wsproto
import wsproto.events
from aioquic.asyncio import QuicConnectionProtocol
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import (
DataReceived,
H3Event,
HeadersReceived,
PushPromiseReceived,
)
from aioquic.quic.events import QuicEvent
from starlette.websockets import WebSocket, WebSocketDisconnect
from wsproto.utilities import LocalProtocolError
from rsocket.exceptions import RSocketTransportError
from rsocket.frame import Frame, serialize_with_frame_size_header, KeepAliveFrame
from rsocket.helpers import wrap_transport_exception, cancel_if_task_exists
from rsocket.logger import logger
from rsocket.transports.abstract_messaging import AbstractMessagingTransport
class ClientWebSocket:
def __init__(
self, http: H3Connection, stream_id: int, transmit: Callable[[], None]
) -> None:
self.http = http
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
self.stream_id = stream_id
self.subprotocol: Optional[str] = None
self.transmit = transmit
self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT)
async def send_bytes(self, message: bytes) -> None:
data = self.websocket.send(wsproto.events.BytesMessage(data=message))
self.http.send_data(stream_id=self.stream_id, data=data, end_stream=False)
self.transmit()
async def receive_bytes(self) -> bytes:
return await self.queue.get()
async def close(self, code: int = 1000, reason: str = "") -> None:
data = self.websocket.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
self.http.send_data(stream_id=self.stream_id, data=data, end_stream=True)
self.transmit()
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, HeadersReceived):
for header, value in event.headers:
if header == b"sec-websocket-protocol":
self.subprotocol = value.decode()
elif isinstance(event, DataReceived):
self.websocket.receive_data(event.data)
for ws_event in self.websocket.events():
self.websocket_event_received(ws_event)
def websocket_event_received(self, event: wsproto.events.Event) -> None:
if isinstance(event, wsproto.events.BytesMessage):
self.queue.put_nowait(event.data)
class URL:
def __init__(self, url: str) -> None:
parsed = urlparse(url)
self.authority = parsed.netloc
self.full_path = parsed.path or "/"
if parsed.query:
self.full_path += "?" + parsed.query
self.scheme = parsed.scheme
[docs]
class RSocketHttp3ClientProtocol(QuicConnectionProtocol):
"""
RSocket transport over client side http3 connection.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pushes: Dict[int, Deque[H3Event]] = {}
self._http: Optional[H3Connection] = None
self._request_events: Dict[int, Deque[H3Event]] = {}
self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
self._websockets: Dict[int, ClientWebSocket] = {}
self._http = H3Connection(self._quic)
async def websocket(
self, url: str, subprotocols: Optional[List[str]] = None
) -> ClientWebSocket:
parsed_url = URL(url)
stream_id = self._quic.get_next_available_stream_id()
websocket = ClientWebSocket(
http=self._http, stream_id=stream_id, transmit=self.transmit
)
self._websockets[stream_id] = websocket
headers = [
(b":method", b"CONNECT"),
(b":scheme", b"https"),
(b":authority", parsed_url.authority.encode()),
(b":path", parsed_url.full_path.encode()),
(b":protocol", b"websocket"),
(b"user-agent", b'rsocket'),
(b"sec-websocket-version", b"13"),
]
if subprotocols:
headers.append(
(b"sec-websocket-protocol", ", ".join(subprotocols).encode())
)
self._http.send_headers(stream_id=stream_id, headers=headers)
self.transmit()
return websocket
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, (HeadersReceived, DataReceived)):
stream_id = event.stream_id
if stream_id in self._request_events:
# http
self._request_events[event.stream_id].append(event)
if event.stream_ended:
request_waiter = self._request_waiter.pop(stream_id)
request_waiter.set_result(self._request_events.pop(stream_id))
elif stream_id in self._websockets:
# websocket
websocket = self._websockets[stream_id]
websocket.http_event_received(event)
elif event.push_id in self.pushes:
# push
self.pushes[event.push_id].append(event)
elif isinstance(event, PushPromiseReceived):
self.pushes[event.push_id] = deque()
self.pushes[event.push_id].append(event)
def quic_event_received(self, event: QuicEvent) -> None:
events = self._http.handle_event(event)
for http_event in events:
self.http_event_received(http_event)
[docs]
class Http3TransportWebsocket(AbstractMessagingTransport):
"""
RSocket transport over server side http3 connection.
"""
def __init__(self, websocket: Union[WebSocket, ClientWebSocket]):
super().__init__()
self._websocket = websocket
self._listener = asyncio.create_task(self.incoming_data_listener())
self._disconnect_event = asyncio.Event()
async def send_frame(self, frame: Frame):
with wrap_transport_exception():
try:
data = serialize_with_frame_size_header(frame)
try:
await self._websocket.send_bytes(data)
except LocalProtocolError as exception:
if (not ((str(exception).endswith('ConnectionState.REMOTE_CLOSING.')
or str(exception).endswith('ConnectionState.CLOSED.'))
and isinstance(frame, KeepAliveFrame))):
raise RSocketTransportError(str(frame)) from exception
await asyncio.sleep(0)
except WebSocketDisconnect:
self._disconnect_event.set()
async def close(self):
await cancel_if_task_exists(self._listener)
# await self._websocket.close()
async def incoming_data_listener(self):
try:
while True:
try:
data = await self._websocket.receive_bytes()
except WebSocketDisconnect:
self._disconnect_event.set()
break
async for frame in self._frame_parser.receive_data(data):
self._incoming_frame_queue.put_nowait(frame)
except asyncio.CancelledError:
logger().debug('Asyncio task canceled: incoming_data_listener')
except WebSocketDisconnect:
pass
except Exception:
self._incoming_frame_queue.put_nowait(RSocketTransportError())
async def wait_for_disconnect(self):
await self._disconnect_event.wait()
def requires_length_header(self) -> bool:
return True