#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2020-07-30
# @Filename: client.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
from __future__ import annotations

import asyncio
import inspect
import json
import logging
import pathlib
import uuid
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Union

import aio_pika as apika
from sdsstools.logger import SDSSLogger

from clu.exceptions import CluWarning

from .base import BaseClient, Reply
from .command import Command
from .model import ModelSet
from .protocol import TopicListener
from .tools import CommandStatus

if TYPE_CHECKING:
    from aio_pika.abc import HeadersType
# Public API of this module.
__all__ = [
    "AMQPClient",
    "AMQPReply",
]

# A filesystem location given either as a plain string or a `pathlib.Path`.
PathLike = Union[str, pathlib.Path]

# Signature of a reply callback: it receives the `AMQPReply` and either
# returns nothing (synchronous callback) or an awaitable (asynchronous one).
ReplyCallbackType = Callable[
    ["AMQPReply"],
    Union[None, Awaitable[None]],
]
[docs]
class AMQPReply(object):
    """Wrapper for an `~aio_pika.IncomingMessage` that expands and decodes it.

    Parameters
    ----------
    message
        The message that contains the reply.
    log
        A message logger.

    Attributes
    ----------
    is_valid
        Whether the message is valid and correctly parsed.
    body
        The body of the message, as decoded from JSON. An empty dictionary
        if the message is invalid or its body could not be decoded.
    info
        The info dictionary.
    headers
        The headers of the message, decoded if they are bytes.
    message_code
        The message code.
    internal
        Whether this reply was marked internal.
    sender
        The name of the actor that sends the reply.
    command_id
        The command ID.

    """

    def __init__(
        self,
        message: apika.abc.AbstractIncomingMessage,
        log: Optional[logging.Logger] = None,
    ):
        # Initialise every documented attribute up front so that all of them
        # exist even when the message is rejected early as invalid.
        self.command_id: str | None = None
        self.message_code: str = ""
        self.internal: bool = False
        self.sender: str | None = None
        self.body: dict[str, Any] = {}

        self.message = message
        self._log = log

        self.is_valid = True

        self.info = message.info()

        # The headers entry may be missing or None. Header values may also
        # arrive as bytes, which must be decoded so that they compare equal
        # to plain strings (message code, sender, command ID).
        raw_headers = self.info.get("headers") or {}
        self.headers = {
            key: self._decode_header(value) for key, value in dict(raw_headers).items()
        }

        self.message_code = str(self.headers.get("message_code", ""))
        if self.message_code == "":
            self.is_valid = False
            if self._log:
                self._log.warning(f"received message without message_code: {message}")
            return

        self.internal = bool(self.headers.get("internal", False))

        self.sender = self.headers.get("sender", None)
        if self.sender is None and self._log:
            self._log.warning(f"received message without sender: {message}")

        self.command_id = message.correlation_id

        # If present, the command_id header must agree with the AMQP
        # correlation ID; otherwise the reply cannot be trusted.
        command_id_header = self.headers.get("command_id", None)
        if command_id_header and command_id_header != self.command_id:
            if self._log:
                self._log.error(
                    f"mismatch between message "
                    f"correlation_id={self.command_id} "
                    f"and header command_id={command_id_header} "
                    f"in message {message}"
                )
            self.is_valid = False
            return

        # A malformed body marks the reply invalid instead of raising, so that
        # callers (which acknowledge the message) are not interrupted.
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        try:
            self.body = json.loads(self.message.body)
        except ValueError:
            self.is_valid = False
            if self._log:
                self._log.error(f"failed decoding the JSON body of message {message}")
            return

    @staticmethod
    def _decode_header(value: Any) -> Any:
        """Returns ``value`` decoded as UTF-8 if it is `bytes`, else unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return value
[docs]
class AMQPClient(BaseClient):
    """Defines a new client based on the AMQP standard.

    To start a new client first instantiate the class and then run `.start` as
    a coroutine. Note that `.start` does not block so you will need to use
    asyncio's ``run_forever`` or a similar system ::

        >>> loop = asyncio.get_event_loop()
        >>> client = await AMQPClient('my_client', host='localhost').start()
        >>> loop.run_forever()

    The client can also be used as an asynchronous context manager, which
    starts it on entry (if it is not already connected) and stops it on exit ::

        >>> async with AMQPClient('my_client', host='localhost') as client:
        ...     await client.send_command('my_actor', 'do_something')

    Parameters
    ----------
    name
        The name of the client. If `None`, a name of the form
        ``amqp-client-<random hex>`` is generated.
    url
        RFC3986 formatted broker address. When used, the other connection
        keyword arguments are ignored.
    user
        The user to connect to the AMQP broker. Defaults to ``guest``.
    password
        The password for the user. Defaults to ``guest``.
    host
        The host where the AMQP message broker runs. Defaults to ``localhost``.
    virtualhost
        Virtualhost parameter. ``'/'`` by default.
    port
        The port on which the AMQP broker is running. Defaults to 5672.
    ssl
        Whether to use TLS/SSL connection.
    version
        The version of the client.
    loop
        The event loop. If `None`, the current event loop will be used.
    log_dir
        The directory where to store the logs. Defaults to
        ``$HOME/logs/<name>`` where ``<name>`` is the name of the actor.
    log
        A `~logging.Logger` instance to be used for logging instead of
        creating a new one.
    parser
        A click command parser that is a subclass of `~clu.parser.CluGroup`.
        If `None`, the active parser will be used.
    models
        A list of actor models whose schemas will be monitored.

    """

    __EXCHANGE_NAME__ = "sdss_exchange"

    def __init__(
        self,
        name: str | None = None,
        url: Optional[str] = None,
        user: str = "guest",
        password: str = "guest",
        host: str = "localhost",
        port: int = 5672,
        virtualhost: str = "/",
        ssl: bool = False,
        version: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        log_dir: Optional[PathLike] = None,
        log: Optional[SDSSLogger] = None,
        models: Optional[List[str]] = None,
        **kwargs,
    ):
        if name is None:
            name = "amqp-client-" + str(uuid.uuid4()).split("-")[-1]

        super().__init__(
            name,
            version=version,
            loop=loop,
            log_dir=log_dir,
            log=log,
            **kwargs,
        )

        self.replies_queue = None

        # Creates the connection to the AMQP broker
        self.connection = TopicListener(
            url=url,
            user=user,
            password=password,
            host=host,
            port=port,
            ssl=ssl,
            virtualhost=virtualhost,
        )

        #: dict: External commands currently running.
        self.running_commands: Dict[str, Command] = {}

        # Copy the list so the model set never aliases the caller's list.
        self.models = ModelSet(self, actors=list(models or []), raise_exception=False)

        self._callbacks: list[ReplyCallbackType] = []

        # Strong references to tasks created for asynchronous reply callbacks.
        # The event loop only keeps weak references to tasks, so without this
        # a running callback task could be garbage-collected mid-execution.
        self._callback_tasks: set[asyncio.Future] = set()

    def __repr__(self):
        if self.connection.connection is None:
            url = "disconnected"
        else:
            assert isinstance(self.connection.connection, apika.Connection)
            url = str(self.connection.connection.url)

        return f"<{str(self)} (name={self.name!r}, {url})>"

    async def start(self, exchange_name: str = __EXCHANGE_NAME__):
        """Starts the connection to the AMQP broker.

        Connects to the exchange, binds the ``<name>_replies`` queue to the
        ``reply.#`` routing keys with `.handle_reply` as its callback, and
        loads the schemas of the monitored models.

        Parameters
        ----------
        exchange_name
            The name of the exchange to connect to.

        Returns
        -------
        client
            The client itself.

        """
        self.set_loop_exception_handler()

        # Starts the connection and creates the exchange
        await self.connection.connect(exchange_name)

        # Binds the replies queue.
        self.replies_queue = await self.connection.add_queue(
            f"{self.name}_replies",
            callback=self.handle_reply,
            bindings=["reply.#"],
        )

        if self.connection.connection:
            assert isinstance(self.connection.connection, apika.Connection)
            url = self.connection.connection.url
        else:
            url = "???"

        self.log.debug(f"Replies queue {self.replies_queue.name!r} bound to {url!s}")

        # Initialises the models.
        await self.models.load_schemas()

        return self

    async def stop(self):
        """Cancels queues and closes the connection, if it is open."""
        if self.connection.connection and not self.connection.connection.is_closed:
            await self.connection.stop()

    def is_connected(self) -> bool:
        """Is the client connected to the exchange?

        Returns `True` only if a connection exists and it is not closed.
        """
        connection = self.connection.connection
        return bool(connection and not connection.is_closed)

    async def run_forever(self):
        """Waits, polling once per second, until the connection is closed."""
        assert self.connection.connection

        while not self.connection.connection.is_closed:
            await asyncio.sleep(1)

    async def __aenter__(self):
        """Starts the client inside a context manager.

        Returns the client so that ``async with AMQPClient(...) as client``
        binds it. If starting fails, the client is stopped and the exception
        is re-raised.
        """
        try:
            if not self.is_connected():
                await self.start()
        except Exception:
            await self.stop()
            raise

        return self

    async def __aexit__(self, *_):
        """Exits the context manager, stopping the client."""
        await self.stop()

    async def handle_reply(
        self,
        message: apika.abc.AbstractIncomingMessage,
    ) -> AMQPReply:
        """Handles a reply received from the exchange.

        Creates a new instance of `.AMQPReply` from the ``message``. If the
        reply is valid it updates the sender's model and any running command.
        Finally it invokes the registered reply callbacks.

        Parameters
        ----------
        message
            The message received.

        Returns
        -------
        reply
            The `.AMQPReply` object created from the message.

        """
        # Acknowledge before parsing so that a message that fails to parse is
        # not left unacknowledged in the replies queue.
        await message.ack()

        reply = AMQPReply(message, log=self.log)

        if not reply.is_valid:
            self.log.error("Invalid message received.")
            return reply

        # Update the models
        if self.models and reply.sender in self.models:
            self.models[reply.sender].update_model(reply.body)

        # If the command is running we check if the message code indicates
        # the command is done and, if so, sets the result in the Future.
        # Also, add the reply to the command list of replies.
        if reply.command_id and reply.command_id in self.running_commands:
            command = self.running_commands[reply.command_id]

            command.replies.append(
                Reply(
                    message=reply.body,
                    message_code=reply.message_code,
                    command=command,
                    internal=reply.internal,
                    validated=True,
                )
            )

            # A failing user callback must not prevent the status update below.
            if command._reply_callback is not None:
                self._invoke_reply_callback(command._reply_callback, reply)

            status = CommandStatus.code_to_status(reply.message_code)
            if command.status != status:
                command.set_status(status)

            if status.is_done:
                if not command.done():
                    command.set_result(command)
                self.running_commands.pop(reply.command_id, None)

        # Handle reply callbacks. Iterate over a copy because a callback may
        # remove itself (or others) while we iterate.
        for cb in list(self._callbacks):
            self._invoke_reply_callback(cb, reply)

        return reply

    def _invoke_reply_callback(self, callback: ReplyCallbackType, reply: AMQPReply):
        """Calls a reply callback, scheduling any returned awaitable as a task.

        Exceptions raised synchronously by the callback are logged rather than
        propagated so that one faulty callback cannot interrupt reply handling.
        """
        try:
            result = callback(reply)
        except Exception:
            self.log.exception(f"Exception raised by reply callback {callback!r}.")
            return

        if inspect.isawaitable(result):
            task = asyncio.ensure_future(result)
            self._callback_tasks.add(task)
            task.add_done_callback(self._callback_tasks.discard)

    async def send_command(
        self,
        consumer: str,
        command_string: str,
        *args,
        command_id: str | None = None,
        callback: Optional[Callable[[AMQPReply], None]] = None,
        internal: bool = False,
        command: Optional[Command] = None,
        time_limit: Optional[float] = None,
        await_command: bool = True,
    ):
        """Commands another actor over its RPC queue.

        Parameters
        ----------
        consumer
            The actor we are commanding.
        command_string
            The command string that will be parsed by the remote actor.
        args
            Arguments to concatenate to the command string.
        command_id
            The command ID associated with this command. If `None`, a unique
            identifier will be attached. Ignored if ``command`` has a command
            ID, in which case that one is reused.
        callback
            A callback to invoke with each reply received from the actor.
        internal
            Whether to mark the command as internal, in which case replies will
            also be considered internal. Overridden by ``command.internal`` if
            ``command`` is passed.
        command
            The `.Command` that initiated the new command. Only relevant for
            actors.
        time_limit
            A delay after which the command is marked as timed out and done.
        await_command
            If `True`, awaits the command until it finishes.

        Returns
        -------
        command
            The `.Command` tracking the remote command (already awaited if
            ``await_command=True``). If publishing the command raises, the
            command is unregistered and the exception is propagated.

        Examples
        --------
        These two are equivalent ::

            >>> await client.send_command('my_actor', 'do_something --now')
            >>> await client.send_command('my_actor', 'do_something', '--now')

        """
        if command and command.command_id:
            command_id = str(command.command_id)
        else:
            command_id = command_id or str(uuid.uuid4())

        if len(args) > 0:
            command_string += " " + " ".join(map(str, args))

        if command and isinstance(command.commander_id, str):
            commander_id = command.commander_id + f".{consumer}"
        else:
            commander_id = f"{self.name}.{consumer}"

        internal = command.internal if command else internal

        # Creates and registers a command.
        command = Command(
            command_string=command_string,
            command_id=command_id,
            commander_id=commander_id,
            consumer_id=consumer,
            internal=internal,
            actor=None,
            reply_callback=callback,
            time_limit=time_limit,
        )

        self.running_commands[command_id] = command

        headers = {
            "command_id": command_id,
            "commander_id": commander_id,
            "internal": internal,
        }

        message_body = {"command_string": command_string}

        try:
            await self._publish_message(
                consumer,
                headers=headers,
                body=message_body,
                correlation_id=command_id,
            )
        except BaseException:
            # The caller never receives this command, so do not leave it
            # registered as running forever.
            self.running_commands.pop(command_id, None)
            raise

        if await_command:
            await command

        return command

    async def _publish_message(
        self,
        consumer: str,
        headers: HeadersType | None = None,
        body: dict[str, Any] | None = None,
        correlation_id: str | None = None,
    ):
        """Publishes a message to an exchange.

        The message is published with routing key ``command.<consumer>`` and
        ``reply_to`` set to the replies queue. If publishing raises a
        delivery or publish error, a fake failure reply (message code ``f``)
        from ``consumer`` is published to ``reply.<name>`` so that
        `.handle_reply` fails the corresponding command.

        Parameters
        ----------
        consumer
            The name of the actor the message is addressed to.
        headers
            The message headers. The mapping is copied, never modified.
        body
            A JSON-serialisable dictionary with the message body.
        correlation_id
            The correlation ID of the message.

        """
        # Work on a copy so neither the caller's headers nor a shared default
        # are mutated when the failure-reply headers are built below.
        headers = dict(headers) if headers is not None else {}
        body = body if body is not None else {}

        if not hasattr(self.connection, "exchange"):
            warnings.warn(
                f"Exchange is not ready to output message: {body}",
                CluWarning,
            )
            return

        assert self.replies_queue

        # The routing key has the topic command and the name of the commanded actor.
        routing_key = f"command.{consumer}"

        try:
            await self.connection.exchange.publish(
                apika.Message(
                    json.dumps(body).encode(),
                    content_type="text/json",
                    headers=headers,
                    correlation_id=correlation_id,
                    reply_to=self.replies_queue.name,
                ),
                routing_key=routing_key,
            )
        except (apika.exceptions.DeliveryError, apika.exceptions.PublishError):
            # The consumer (actor) did not reply. This usually means that the actor
            # is not connected. We fake a reply from that actor saying so. That will
            # be received by handle_reply which will fail the current command.
            error_msg = dict(error=f"Failed routing message to consumer {consumer!r}.")
            error_headers = {**headers, "message_code": "f", "sender": consumer}
            await self.connection.exchange.publish(
                apika.Message(
                    json.dumps(error_msg).encode(),
                    content_type="text/json",
                    headers=error_headers,
                    correlation_id=correlation_id,
                ),
                routing_key=f"reply.{self.name}",
            )

    async def send_task(
        self,
        consumer: str,
        task_name: str,
        payload: dict[str, Any] | None = None,
        **kwargs,
    ):
        """Sends a task to a remote actor.

        Parameters
        ----------
        consumer
            The actor we are commanding.
        task_name
            The task to execute in the remote actor
        payload
            A serialisable dictionary with the payload to pass to the task.
            It is deep-copied and never modified.
        kwargs
            Additional arguments used to update the payload dictionary.

        """
        assert self.replies_queue

        payload = deepcopy(payload) if payload is not None else {}
        payload.update(kwargs)

        headers = {"commander_id": f"{self.name}.{consumer}", "task": True}

        message_body = {"task": task_name}
        message_body.update(payload)

        correlation_id = str(uuid.uuid4())

        await self._publish_message(
            consumer,
            headers=headers,
            body=message_body,
            correlation_id=correlation_id,
        )

    def add_reply_callback(self, callback_func: ReplyCallbackType):
        """Adds a callback that is called when a new reply is received.

        Adding a callback that is already registered has no effect.
        """
        if callback_func not in self._callbacks:
            self._callbacks.append(callback_func)

    def remove_reply_callback(self, callback_func: ReplyCallbackType):
        """Removes a reply callback.

        Raises
        ------
        ValueError
            If the callback is not registered.
        """
        if callback_func not in self._callbacks:
            raise ValueError("Callback not registered.")

        self._callbacks.remove(callback_func)