Source code for rsocket.routing.request_router

from asyncio import Future
from dataclasses import dataclass
from inspect import signature
from typing import Callable, Any, Dict, Optional

from rsocket.exceptions import RSocketUnknownRoute, RSocketEmptyRoute
from rsocket.extensions.composite_metadata import CompositeMetadata
from rsocket.frame import FrameType
from rsocket.frame_helpers import safe_len
from rsocket.helpers import create_future
from rsocket.payload import Payload
from rsocket.rsocket import RSocket

# Names exported by ``from rsocket.routing.request_router import *``.
__all__ = ['RequestRouter']

# Type alias used to annotate the functions passed to the route decorators.
decorated_method = Callable[[RSocket, Payload, CompositeMetadata], Any]


class RouteInfo:
    """Pair a route handler with its introspected call signature.

    Attributes:
        method: the handler callable exactly as it was passed in.
        signature: the :class:`inspect.Signature` of ``method``, computed
            once at construction time.
    """

    def __init__(self, method):
        handler: Callable = method
        self.method = handler
        # Inspect the handler up front; the signature is stored next to it.
        self.signature = signature(handler)


def decorator_factory(container: dict, route: str):
    """Build a decorator that registers a function under ``route``.

    :param container: mapping of route name to :class:`RouteInfo`; the
        decorated function is wrapped in a ``RouteInfo`` and stored here.
    :param route: name to register the decorated function under.
    :return: a decorator that records the function in ``container`` and
        returns the function itself, unchanged.

    The returned decorator raises :class:`RSocketEmptyRoute` when
    ``safe_len(route)`` is 0, and :class:`KeyError` when ``route`` is
    already present in ``container``.
    """

    def decorator(function: decorated_method):
        if safe_len(route) == 0:
            raise RSocketEmptyRoute(function.__name__)

        if route in container:
            # Exceptions do not interpolate "%s" placeholders from extra
            # positional arguments, so the message is formatted here.
            raise KeyError(f'Duplicate route "{route}" already registered')

        container[route] = RouteInfo(function)
        return function

    return decorator


@dataclass
class Handlers:
    """Fallback handlers, one slot per request type.

    Every field defaults to ``None``. The annotations describe the values
    the router in this module stores here: a ``RouteInfo`` wrapping the
    fallback function, or ``None`` when no fallback was registered.
    """
    # The forward-reference string keeps the annotation lazy; ``callable``
    # (the builtin function) is not a type and described the wrong value.
    response: Optional['RouteInfo'] = None
    stream: Optional['RouteInfo'] = None
    channel: Optional['RouteInfo'] = None
    fire_and_forget: Optional['RouteInfo'] = None
    metadata_push: Optional['RouteInfo'] = None


class RequestRouter:
    """
    Used to define routes for RSocket endpoints. Pass this to
    :class:`RoutingRequestHandler <rsocket.routing.routing_request_handler.RoutingRequestHandler>`
    to instantiate a handler using these routes.
    """

    __slots__ = (
        '_channel_routes',
        '_stream_routes',
        '_response_routes',
        '_fnf_routes',
        '_metadata_push',
        '_route_map_by_frame_type',
        '_payload_deserializer',
        '_payload_serializer',
        '_unknown'
    )

    def __init__(self,
                 payload_deserializer=lambda cls, _: _,
                 payload_serializer=lambda cls, _: _):
        """Create an empty router.

        :param payload_deserializer: called as ``(annotation, payload)`` to
            convert the incoming payload for handler parameters whose
            annotation is neither ``Payload`` nor missing. The default
            returns the payload unchanged.
        :param payload_serializer: called as ``(return_annotation, result)``
            to convert a request-response result that is not already a
            ``Payload``. The default returns the result unchanged.
        """
        self._payload_serializer = payload_serializer
        self._payload_deserializer = payload_deserializer

        self._channel_routes: Dict[str, RouteInfo] = {}
        self._stream_routes: Dict[str, RouteInfo] = {}
        self._response_routes: Dict[str, RouteInfo] = {}
        self._fnf_routes: Dict[str, RouteInfo] = {}
        self._metadata_push: Dict[str, RouteInfo] = {}

        # Fallback handlers used when a route name is not registered.
        self._unknown = Handlers()

        # Lets route() pick the right table directly from the frame type.
        self._route_map_by_frame_type: Dict[int, Dict[str, RouteInfo]] = {
            FrameType.REQUEST_CHANNEL: self._channel_routes,
            FrameType.REQUEST_FNF: self._fnf_routes,
            FrameType.REQUEST_STREAM: self._stream_routes,
            FrameType.REQUEST_RESPONSE: self._response_routes,
            FrameType.METADATA_PUSH: self._metadata_push,
        }

    def response(self, route: str):
        """Return a decorator registering a request-response handler for ``route``."""
        return decorator_factory(self._response_routes, route)

    def response_unknown(self):
        """Return a decorator registering the request-response fallback handler."""

        def wrapper(function):
            self._unknown.response = RouteInfo(function)
            return function

        return wrapper

    def stream(self, route: str):
        """Return a decorator registering a request-stream handler for ``route``."""
        return decorator_factory(self._stream_routes, route)

    def stream_unknown(self):
        """Return a decorator registering the request-stream fallback handler."""

        def wrapper(function):
            self._unknown.stream = RouteInfo(function)
            return function

        return wrapper

    def channel(self, route: str):
        """Return a decorator registering a request-channel handler for ``route``."""
        return decorator_factory(self._channel_routes, route)

    def channel_unknown(self):
        """Return a decorator registering the request-channel fallback handler."""

        def wrapper(function):
            self._unknown.channel = RouteInfo(function)
            return function

        return wrapper

    def fire_and_forget(self, route: str):
        """Return a decorator registering a fire-and-forget handler for ``route``."""
        return decorator_factory(self._fnf_routes, route)

    def fire_and_forget_unknown(self):
        """Return a decorator registering the fire-and-forget fallback handler."""

        def wrapper(function):
            self._unknown.fire_and_forget = RouteInfo(function)
            return function

        return wrapper

    def metadata_push(self, route: str):
        """Return a decorator registering a metadata-push handler for ``route``."""
        return decorator_factory(self._metadata_push, route)

    def metadata_push_unknown(self):
        """Return a decorator registering the metadata-push fallback handler."""

        def wrapper(function):
            self._unknown.metadata_push = RouteInfo(function)
            return function

        return wrapper

    async def route(self,
                    frame_type: FrameType,
                    route: str,
                    payload: Payload,
                    composite_metadata: CompositeMetadata):
        """Invoke the handler registered for ``route`` and return its result.

        The handler is looked up in the table for ``frame_type``; when the
        route is not registered there, the fallback handler for that frame
        type is used instead.

        :param frame_type: selects which route table (and fallback) to use.
        :param route: route name to look up.
        :param payload: request payload handed to the handler.
        :param composite_metadata: request metadata handed to the handler.
        :return: the awaited handler result. For ``REQUEST_RESPONSE`` a
            result that is not a ``Future`` is passed through the payload
            serializer (unless it already is a ``Payload``) and wrapped
            with ``create_future``.
        :raises RSocketUnknownRoute: if neither a route nor a fallback exists.
        :raises KeyError: if ``frame_type`` has no route table.
        """
        routes = self._route_map_by_frame_type[frame_type]

        # Registered values are always RouteInfo instances, so None here
        # can only mean "not registered".
        route_info = routes.get(route)
        if route_info is None:
            route_info = self._get_unknown_route(frame_type)

        if route_info is None:
            raise RSocketUnknownRoute(route)

        route_kwargs = self._collect_route_arguments(route_info, payload, composite_metadata)

        result = await route_info.method(**route_kwargs)

        if frame_type == FrameType.REQUEST_RESPONSE and not isinstance(result, Future):
            if not isinstance(result, Payload):
                result = self._payload_serializer(route_info.signature.return_annotation, result)

            return create_future(result)

        return result

    def _collect_route_arguments(self,
                                 route_info: RouteInfo,
                                 payload: Payload,
                                 composite_metadata: CompositeMetadata):
        """Build the keyword arguments for calling ``route_info.method``.

        A parameter named ``composite_metadata`` or annotated with
        ``CompositeMetadata`` receives the metadata. Every other parameter
        receives the payload, converted by the payload deserializer when
        its annotation is neither ``Payload`` nor missing.

        :return: dict mapping each parameter name to the value to pass.
        """
        route_kwargs = {}

        for name, parameter in route_info.signature.parameters.items():
            # Compare the annotation, not the Parameter object itself, and
            # key by the parameter's own name so the ** call binds to it.
            if name == 'composite_metadata' or parameter.annotation is CompositeMetadata:
                route_kwargs[name] = composite_metadata
            else:
                payload_data = payload

                if parameter.annotation not in (Payload, parameter.empty):
                    payload_data = self._payload_deserializer(parameter.annotation, payload)

                route_kwargs[name] = payload_data

        return route_kwargs

    def _get_unknown_route(self, frame_type: FrameType) -> Optional[RouteInfo]:
        """Return the fallback handler stored for ``frame_type``.

        :return: the stored ``RouteInfo``, or ``None`` when no fallback was
            registered or the frame type has no fallback slot.
        """
        if frame_type == FrameType.REQUEST_RESPONSE:
            return self._unknown.response
        elif frame_type == FrameType.REQUEST_STREAM:
            return self._unknown.stream
        elif frame_type == FrameType.REQUEST_CHANNEL:
            return self._unknown.channel
        elif frame_type == FrameType.REQUEST_FNF:
            return self._unknown.fire_and_forget
        elif frame_type == FrameType.METADATA_PUSH:
            return self._unknown.metadata_push

        return None