#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2018-01-19
# @Filename: command.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
from __future__ import annotations
import asyncio
import random
from contextlib import suppress
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
import aio_pika as apika
import aiormq
from .exceptions import CluError
# Names exported by this module.
__all__ = [
    "TCPProtocol",
    "PeriodicTCPServer",
    "TCPStreamServer",
    "TCPStreamPeriodicServer",
    "TCPStreamClient",
    "open_connection",
    "TopicListener",
    "ReconnectingTCPClientProtocol",
]

# Generic type variable.
T = TypeVar("T")

# Type of the callbacks invoked on connection events; the servers in this
# module call them with the transport of the affected connection.
ConnectionCallbackType = Callable[[Any], Any]

# Type of the stream-server data callback; it is called with the transport
# and the raw bytes that were read.
DataReceivedCallbackType = Callable[[Any, bytes], Any]
[docs]
class TCPProtocol(asyncio.Protocol):
    """A TCP server/client based on asyncio protocols.

    This is a high-level implementation of the client and server asyncio
    protocols. See `asyncio protocol
    <https://docs.python.org/3/library/asyncio-protocol.html>`__ for details.

    When used as a server, one protocol instance is shared by every
    connection, and ``transports`` lists the transports currently accepted.

    Parameters
    ----------
    loop
        The event loop. The current event loop is used by default.
    connection_callback
        Callback to call when a new client connects. It receives the
        transport of the accepted connection.
    data_received_callback
        Callback to call when new data is received. It receives the data
        decoded as a string.
    max_connections
        How many clients the server accepts. If `None`, unlimited connections
        are allowed.

    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        connection_callback: Optional[ConnectionCallbackType] = None,
        data_received_callback: Optional[Callable[[str], Any]] = None,
        max_connections: Optional[int] = None,
    ):
        self.connection_callback = connection_callback
        self.data_received_callback = data_received_callback

        self.transports = []
        self.max_connections = max_connections

        self.loop = loop or asyncio.get_event_loop()

    @classmethod
    async def create_server(cls, host: str, port: int, **kwargs):
        """Returns a `~asyncio.Server` connection.

        ``kwargs`` are passed to the class constructor. The server is
        already serving when it is returned.

        """

        # ``or`` (rather than a ``get`` default) also covers ``loop=None``.
        loop = kwargs.get("loop") or asyncio.get_event_loop()

        new_tcp = cls(**kwargs)

        # The same protocol instance handles every incoming connection.
        server = await loop.create_server(lambda: new_tcp, host, port)
        await server.start_serving()

        return server

    @classmethod
    async def create_client(cls, host: str, port: int, **kwargs):
        """Returns a `~asyncio.Transport` and `~asyncio.Protocol`.

        ``kwargs`` are passed to the class constructor.

        Raises
        ------
        KeyError
            If ``connection_callback`` is passed.

        """

        if "connection_callback" in kwargs:
            raise KeyError("connection_callback not allowed when creating a client.")

        loop = kwargs.get("loop") or asyncio.get_event_loop()

        # Instantiate normally so that ``__init__`` runs; ``cls.__new__``
        # alone would leave the protocol without any of its attributes.
        new_tcp = cls(**kwargs)
        transport, protocol = await loop.create_connection(lambda: new_tcp, host, port)

        return transport, protocol

    def connection_made(self, transport: asyncio.Transport):
        """Receives a connection and calls the connection callback.

        If ``max_connections`` has been reached the client is notified and
        the transport is closed without calling the connection callback.

        """

        limit = self.max_connections
        if limit is not None and len(self.transports) >= limit:
            transport.write(b"Maximum number of connections reached.")
            transport.close()
            return

        self.transports.append(transport)

        if self.connection_callback:
            self.connection_callback(transport)

    def data_received(self, data: bytes):
        """Decodes the received data and passes it to the data callback."""

        if self.data_received_callback:
            self.data_received_callback(data.decode())

    def connection_lost(self, exc):
        """Called when a connection is lost.

        The protocol is not told which transport was lost, so this drops
        every transport that is closing. Otherwise closed clients would
        keep counting towards ``max_connections``.

        """

        self.transports = [t for t in self.transports if not t.is_closing()]
[docs]
class PeriodicTCPServer(TCPProtocol):
    """A TCP server that runs a callback periodically.

    Parameters
    ----------
    periodic_callback
        Callback to run every iteration. It is called once per connected
        transport, with the transport as its only argument, and may be a
        coroutine function.
    sleep_time
        The delay between two calls to ``periodic_callback``.
    kwargs
        Parameters to pass to `TCPProtocol`

    """

    def __init__(
        self,
        periodic_callback: Optional[ConnectionCallbackType] = None,
        sleep_time: float = 1,
        **kwargs,
    ):
        self._periodic_callback = periodic_callback
        self.sleep_time = sleep_time

        # Task running ``_emit_periodic``; set by ``create_server``.
        self.periodic_task = None

        super().__init__(**kwargs)

    @classmethod
    async def create_client(cls, *args, **kwargs):
        """Not available: this class can only be used as a server."""

        raise NotImplementedError(
            "create_client is not implemented for PeriodicTCPServer."
        )

    @classmethod
    async def create_server(cls, host: str, port: int, *args, **kwargs):
        """Returns a `~asyncio.Server` connection.

        Besides starting the server, this schedules the task that emits the
        periodic callback and stores it in ``periodic_task``.

        """

        # ``or`` (rather than a ``get`` default) also covers ``loop=None``.
        loop = kwargs.get("loop") or asyncio.get_event_loop()

        new_tcp = cls(*args, **kwargs)

        server = await loop.create_server(lambda: new_tcp, host, port)
        await server.start_serving()

        new_tcp.periodic_task = asyncio.create_task(new_tcp._emit_periodic())

        return server

    @property
    def periodic_callback(self) -> Optional[ConnectionCallbackType]:
        """Returns the periodic callback."""

        return self._periodic_callback

    @periodic_callback.setter
    def periodic_callback(self, func: Callable[[asyncio.Transport], Any]):
        """Sets the periodic callback."""

        self._periodic_callback = func

    async def _emit_periodic(self):
        """Calls the periodic callback for each transport, forever."""

        while True:
            callback = self.periodic_callback

            if callback is not None:
                is_coroutine = asyncio.iscoroutinefunction(callback)

                # Iterate over a snapshot: clients may connect while a
                # coroutine callback is being awaited.
                for transport in list(self.transports):
                    # Do not notify clients that have already gone away.
                    if transport.is_closing():
                        continue

                    if is_coroutine:
                        await callback(transport)
                    else:
                        callback(transport)

            await asyncio.sleep(self.sleep_time)
[docs]
class TCPStreamServer(object):
    """A TCP server based on asyncio streams.

    This is a high-level implementation of the asyncio server using
    streams. See `asyncio streams
    <https://docs.python.org/3/library/asyncio-stream.html>`__ for details.

    Parameters
    ----------
    host
        The server host.
    port
        The server port.
    connection_callback
        Callback to call when a new client connects or disconnects. It
        receives the transport and may be a coroutine function.
    data_received_callback
        Callback to call when a new data is received. It receives the
        transport and the bytes read, and may be a coroutine function.
    loop
        The event loop. The current event loop is used by default.
    max_connections
        How many clients the server accepts. If `None`, unlimited connections
        are allowed.

    """

    def __init__(
        self,
        host: str,
        port: int,
        connection_callback: Optional[ConnectionCallbackType] = None,
        data_received_callback: Optional[DataReceivedCallbackType] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        max_connections: Optional[int] = None,
    ):
        self.host = host
        self.port = port

        # Maps the transport of each connected client to its stream writer.
        self.transports = {}

        self.loop = loop or asyncio.get_event_loop()

        self.max_connections = max_connections

        self.connection_callback = connection_callback
        self.data_received_callback = data_received_callback

        # The `asyncio.Server`. Created when `.start` is run.
        self._server = None

    async def start(self) -> asyncio.AbstractServer:
        """Starts the server and returns a `~asyncio.Server` connection."""

        self._server = await asyncio.start_server(
            self.connection_made,
            self.host,
            self.port,
        )

        return self._server

    def stop(self):
        """Stops the server."""

        assert self._server
        self._server.close()

    def serve_forever(self):
        """Exposes ``TCPStreamServer.server.serve_forever``."""

        assert self._server
        return self._server.serve_forever()

    def is_serving(self) -> bool:
        """Returns whether the underlying server is accepting connections."""

        assert self._server
        return self._server.is_serving()

    async def _do_callback(self, cb, *args, **kwargs):
        """Calls a function or coroutine callback."""

        if asyncio.iscoroutinefunction(cb):
            return await asyncio.create_task(cb(*args, **kwargs))
        else:
            return cb(*args, **kwargs)

    async def connection_made(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        """Called when a new client connects to the server.

        Stores the writer protocol in ``transports``, calls the connection
        callback, if any, and starts a loop to read any incoming data. When
        the client disconnects the writer is closed and the connection
        callback is called again.

        If ``max_connections`` has been reached the client is notified and
        the connection is closed.

        """

        if self.max_connections and len(self.transports) >= self.max_connections:
            writer.write("Max number of connections reached.\n".encode())
            with suppress(ConnectionResetError):
                await writer.drain()
            # Close the rejected stream; otherwise the socket stays open.
            writer.close()
            with suppress(ConnectionResetError):
                await writer.wait_closed()
            return

        transport = writer.transport
        self.transports[transport] = writer

        try:
            if self.connection_callback:
                await self._do_callback(self.connection_callback, transport)

            while True:
                try:
                    data = await reader.readuntil()
                except (asyncio.IncompleteReadError, ConnectionResetError):
                    break

                # NOTE(review): a complete line read together with EOF is
                # discarded here; confirm that this is intended.
                if data == b"" or reader.at_eof():
                    break

                if self.data_received_callback:
                    await self._do_callback(
                        self.data_received_callback,
                        transport,
                        data,
                    )
        finally:
            # Runs even if a callback raised, so that the client does not
            # stay registered (and counted against ``max_connections``).
            self.transports.pop(transport, None)
            writer.close()

        if self.connection_callback:
            await self._do_callback(self.connection_callback, transport)

        with suppress(ConnectionResetError):
            await writer.wait_closed()
[docs]
class TCPStreamClient:
    """Holds the reader and writer streams of a connection to a TCP server.

    Both streams are `None` until `.open_connection` has been awaited.

    """

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        self.reader = self.writer = None

    async def open_connection(self):
        """Connects to ``host:port`` and stores the resulting streams."""

        streams = await asyncio.open_connection(self.host, self.port)
        self.reader, self.writer = streams

    def close(self):
        """Closes the writer stream.

        Raises
        ------
        RuntimeError
            If the connection has not been opened.

        """

        if not self.writer:
            raise RuntimeError("connection cannot be closed because it is not open.")

        self.writer.close()
[docs]
async def open_connection(host: str, port: int) -> TCPStreamClient:
    """Opens a TCP stream connection and returns its client container.

    This is a shortcut for ::

        >>> client = TCPStreamClient('127.0.0.1', 5555)
        >>> await client.open_connection()

    With this function the same is achieved with ::

        >>> client = await open_connection('127.0.0.1', 5555)
        >>> client.writer.write('Hi!\\n'.encode())

    Parameters
    ----------
    host : str
        The host of the TCP server.
    port : int
        The port of the TCP server.

    Returns
    -------
    client : `.TCPStreamClient`
        A container for the stream reader and writer.

    """

    stream_client = TCPStreamClient(host, port)
    await stream_client.open_connection()

    return stream_client
[docs]
class TCPStreamPeriodicServer(TCPStreamServer):
    """A TCP server that calls a function periodically.

    Parameters
    ----------
    host
        The server host.
    port
        The server port.
    periodic_callback
        Callback to run every iteration. It is called for each transport
        that is connected to the server and receives the transport object.
        Can be a coroutine function.
    sleep_time
        The delay between two calls to ``periodic_callback``.
    kwargs
        Parameters to pass to `TCPStreamServer`

    """

    def __init__(
        self,
        host: str,
        port: int,
        periodic_callback: Optional[Callable[[asyncio.Transport], Any]] = None,
        sleep_time: float = 1,
        **kwargs,
    ):
        self._periodic_callback = periodic_callback
        self.sleep_time = sleep_time

        # Task running ``_emit_periodic``; created by ``start``.
        self.periodic_task = None

        super().__init__(host, port, **kwargs)

    async def start(self) -> asyncio.AbstractServer:
        """Starts the server and returns a `~asyncio.Server` connection.

        Also schedules the task that emits the periodic callback.

        """

        self._server = await super().start()
        self.periodic_task = asyncio.create_task(self._emit_periodic())

        return self._server

    def stop(self):
        """Cancels the periodic task, if any, and stops the server."""

        if self.periodic_task:
            self.periodic_task.cancel()

        super().stop()

    @property
    def periodic_callback(self):
        """Returns the periodic callback."""

        return self._periodic_callback

    @periodic_callback.setter
    def periodic_callback(self, func: Callable[[asyncio.Transport], Any]):
        """Sets the periodic callback."""

        self._periodic_callback = func

    async def _emit_periodic(self):
        """Calls the periodic callback for each connected transport, forever."""

        while True:
            if self._server and self.periodic_callback:
                # Iterate over a snapshot of the keys: clients connecting or
                # disconnecting while a callback is awaited resize the dict,
                # which would otherwise raise and kill this task.
                for transport in list(self.transports):
                    # Skip clients that disconnected during an earlier callback.
                    if transport not in self.transports:
                        continue
                    await self._do_callback(self.periodic_callback, transport)

            await asyncio.sleep(self.sleep_time)
[docs]
class TopicListener(object):
    """A class to declare and listen to AMQP queues with topic conditions.

    Parameters
    ----------
    url
        RFC3986 formatted broker address. When used, the other keyword
        arguments are ignored.
    user
        The user to connect to the RabbitMQ broker.
    password
        The password for the user.
    host
        The host where the RabbitMQ message broker runs.
    virtualhost
        Virtualhost parameter. ``'/'`` by default.
    port
        The port on which the RabbitMQ message broker is running.
    ssl
        Whether to use TLS/SSL connection.

    """

    def __init__(
        self,
        url: str | None = None,
        user: str = "guest",
        password: str = "guest",
        host: str = "localhost",
        virtualhost: str = "/",
        port: int = 5672,
        ssl: bool = False,
    ):
        self.url = url

        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.virtualhost = virtualhost
        self.ssl = ssl

        # The connection, channel and exchange are set by ``connect``.
        self.connection: apika.abc.AbstractConnection | None = None
        self.channel: apika.abc.AbstractChannel
        self.exchange: apika.abc.AbstractExchange

        # Queues declared with ``add_queue`` and, for those that have a
        # callback, the tag of the consumer attached to them.
        self.queues: List[apika.abc.AbstractQueue] = []
        self._consumer_tag: Dict[apika.abc.AbstractQueue, apika.queue.ConsumerTag] = {}

    async def connect(
        self,
        exchange_name: str,
        exchange_type: apika.ExchangeType = apika.ExchangeType.TOPIC,
        on_return_raises=True,
    ) -> TopicListener:
        """Initialise the connection.

        Parameters
        ----------
        exchange_name
            The name of the exchange to create.
        exchange_type
            The type of exchange to create.
        on_return_raises
            Passed to the ``on_return_raises`` argument used when the
            channel is created.

        Returns
        -------
        listener
            This same object, once the connection, channel and exchange
            have been created.

        Raises
        ------
        ConnectionError
            If the connection to the AMQP server fails.

        """

        try:
            if self.url:
                self.connection = await apika.connect(self.url)
            else:
                self.connection = await apika.connect(
                    login=self.user,
                    host=self.host,
                    port=self.port,
                    password=self.password,
                    virtualhost=self.virtualhost,
                    ssl=self.ssl,
                )
        except ConnectionError as err:
            # Keep the original error as the explicit cause.
            raise ConnectionError(
                f"Failed conneting to the AMQP server: {err}."
            ) from err

        self.channel = await self.connection.channel(on_return_raises=on_return_raises)
        await self.channel.set_qos(prefetch_count=1)

        self.exchange = await self.channel.declare_exchange(
            exchange_name,
            type=exchange_type,
            auto_delete=True,
        )

        return self

    async def add_queue(
        self,
        queue_name: str,
        callback: Optional[Callable[[apika.abc.AbstractIncomingMessage], Any]] = None,
        bindings: Union[str, List[str]] = "*",
    ) -> apika.abc.AbstractQueue:
        """Adds a queue with bindings.

        Parameters
        ----------
        queue_name
            The name of the queue to create.
        callback
            A callable that will be called when a new message is received in
            the queue. Can be a coroutine.
        bindings
            The list of bindings for the queue. Can be a list of string or a
            single string in which the bindings are comma-separated.

        Returns
        -------
        queue
            The queue that was declared and bound to the exchange.

        Raises
        ------
        TypeError
            If ``bindings`` is not a string, list or tuple.
        CluError
            If the queue cannot be declared because it is locked.

        """

        if isinstance(bindings, str):
            bindings = bindings.split(",")
        elif isinstance(bindings, (list, tuple)):
            bindings = list(bindings)
        else:
            raise TypeError(f"invalid type for bindings {bindings!r}.")

        try:
            queue = await self.channel.declare_queue(queue_name, exclusive=True)
        except aiormq.exceptions.ChannelLockedResource as err:
            raise CluError(
                f"cannot create queue {queue_name}. "
                "This may indicate that another instance of the "
                "same actor is running."
            ) from err

        for binding in bindings:
            await queue.bind(self.exchange, routing_key=binding)

        if callback:
            self._consumer_tag[queue] = await queue.consume(callback)

        self.queues.append(queue)

        return queue

    async def stop(self):
        """Cancels queues and closes the connection."""

        for queue in self.queues:
            consumer_tag = self._consumer_tag.get(queue)
            if hasattr(queue, "consumer_tag") and consumer_tag is not None:
                await queue.cancel(consumer_tag)

        if self.connection:
            await self.connection.close()
[docs]
class ReconnectingTCPClientProtocol(asyncio.Protocol):
    """A reconnecting client modelled after Twisted ``ReconnectingClientFactory``.

    Taken from https://bit.ly/3yn6MWa.

    The positional and keyword arguments are stored and passed to the event
    loop's ``create_connection`` on every connection attempt. After a lost or
    failed connection a new attempt is scheduled with a delay that is
    multiplied by ``factor`` each time, capped at ``max_delay`` and then
    randomised according to ``jitter``.

    """

    # Upper bound for the delay, before the jitter is applied.
    max_delay = 3600

    # Starting value of the delay.
    initial_delay = 1.0

    # Factor by which the delay grows on every retry.
    factor = 2.7182818284590451

    # Standard deviation of the random delay, as a fraction of the delay.
    jitter = 0.119626565582

    # Number of retries after which no more attempts are scheduled.
    # `None` means no limit.
    max_retries = None

    def __init__(self, *args, **kwargs):
        self._args, self._kwargs = args, kwargs

        self._retries = 0
        self._delay = self.initial_delay
        self._continue_trying = True

        # Handle of the scheduled reconnection and task of the attempt
        # that is in progress, if any.
        self._call_handle = None
        self._connector = None

        self.connected = False

    def connection_lost(self, exc):
        """Marks the client as disconnected and schedules a retry."""

        self.connected = False

        if not self._continue_trying:
            return

        self.retry()

    def connection_failed(self, exc):
        """Marks the client as disconnected and schedules a retry."""

        self.connected = False

        if not self._continue_trying:
            return

        self.retry()

    def retry(self):
        """Schedules a new connection attempt after an increased delay."""

        if not self._continue_trying:
            return

        self._retries += 1

        retry_limit = self.max_retries
        if retry_limit is not None and self._retries > retry_limit:
            return

        self._delay = min(self._delay * self.factor, self.max_delay)
        if self.jitter:
            spread = self._delay * self.jitter
            self._delay = random.normalvariate(self._delay, spread)

        loop = asyncio.get_event_loop()
        self._call_handle = loop.call_later(self._delay, self.connect)

    def connect(self):
        """Starts a connection attempt unless one is already in progress."""

        if self._connector is not None:
            return

        self._connector = asyncio.create_task(self._connect())

    async def _connect(self):
        """Tries to connect, reporting a failure to ``connection_failed``."""

        loop = asyncio.get_event_loop()

        try:
            await loop.create_connection(
                lambda: self,
                *self._args,
                **self._kwargs,
            )
        except Exception as exc:
            loop.call_soon(self.connection_failed, exc)
        else:
            self.connected = True
        finally:
            self._connector = None

    def stop_trying(self):
        """Cancels any pending or ongoing attempt and disables retries."""

        pending_call = self._call_handle
        if pending_call:
            pending_call.cancel()
            self._call_handle = None

        self._continue_trying = False

        attempt = self._connector
        if attempt is not None:
            attempt.cancel()
            self._connector = None

        self.connected = False