#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2018-09-07
# @Filename: tools.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
from __future__ import annotations
import asyncio
import contextlib
import enum
import functools
import inspect
import json
import logging
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
)
if TYPE_CHECKING:
from clu.base import MessageCode
# Names exported by ``from <module> import *``. Helpers defined further down
# such as ``Maskbit``, ``escape`` and ``dict_depth`` are not part of this list.
__all__ = [
    "CommandStatus",
    "StatusMixIn",
    "format_value",
    "CallbackMixIn",
    "CaseInsensitiveDict",
    "cli_coro",
    "as_complete_failer",
    "log_reply",
    "ActorHandler",
]
# Numeric logging level used by ``log_reply`` for actor replies. It sits below
# ``logging.DEBUG`` (10) and is only used if that level has been registered.
REPLY = 5  # REPLY logging level
# Matches a log message of the form ``<prefix> SomeWarning: text`` and captures
# the warning category name (group 1) and the warning text (group 2).
WARNING_REGEX = r"^.*?\s*?(\w*?Warning): (.*)"
[docs]
class Maskbit(enum.Flag):
    """Base class for maskbit enumerations; meant to be subclassed."""

    @property
    def active_bits(self) -> List["Maskbit"]:
        """Return the single-bit flags of the enumeration set in this value.

        Members that combine several bits are never part of the result.
        """

        single_bits = []
        for member in type(self):  # type: ignore
            is_set = bool(member.value & self.value)
            if is_set and bin(member.value).count("1") == 1:
                single_bits.append(member)
        return single_bits
# Message code associated with each single-bit ``CommandStatus`` member,
# keyed by the upper-case member name.
COMMAND_STATUS_TO_CODE: Dict[str, str] = {
    # Finished states
    "DONE": ":",
    "CANCELLED": "f",
    "FAILED": "f",
    "TIMEDOUT": "f",
    # Not yet finished
    "READY": "i",
    "RUNNING": ">",
    "CANCELLING": "w",
    "FAILING": "w",
    "DEBUG": "d",
}
[docs]
class CommandStatus(Maskbit):
    """Maskbit with the possible states of a command.

    Every member receives a ``code`` attribute on creation: the message code
    listed for its name in ``COMMAND_STATUS_TO_CODE``, or `None` when the name
    is not listed there (as is the case for the named combinations below).
    """

    DONE = enum.auto()
    CANCELLED = enum.auto()
    FAILED = enum.auto()
    TIMEDOUT = enum.auto()
    READY = enum.auto()
    RUNNING = enum.auto()
    CANCELLING = enum.auto()
    FAILING = enum.auto()
    DEBUG = enum.auto()

    # Named combinations used by the convenience properties.
    ACTIVE_STATES = RUNNING | CANCELLING | FAILING
    FAILED_STATES = CANCELLED | FAILED | TIMEDOUT
    FAILING_STATES = CANCELLING | FAILING
    DONE_STATES = DONE | FAILED_STATES
    ALL_STATES = READY | ACTIVE_STATES | DONE_STATES

    def __init__(self, *args):
        # The enum machinery sets ``name`` before calling ``__init__``.
        name = self.name
        self.code: Optional[str] = (
            COMMAND_STATUS_TO_CODE.get(name.upper()) if name else None
        )

    @property
    def is_combination(self) -> bool:
        """Returns True if a flag is a combination."""

        assert isinstance(self.value, int)
        return bin(self.value).count("1") > 1

    @property
    def did_fail(self) -> bool:
        """Command failed or was cancelled."""

        return self in self.FAILED_STATES

    @property
    def did_succeed(self) -> bool:
        """Command finished with DONE status."""

        return self == self.DONE

    @property
    def is_active(self) -> bool:
        """Command is running, cancelling or failing."""

        return self in self.ACTIVE_STATES

    @property
    def is_done(self) -> bool:
        """Command is done (whether successfully or not)."""

        return self in self.DONE_STATES

    @property
    def is_failing(self) -> bool:
        """Command is being cancelled or is failing."""

        return self in self.FAILING_STATES

    @staticmethod
    def code_to_status(
        code,
        default: Optional["CommandStatus"] = None,
    ) -> "CommandStatus":
        """Returns the status associated with a code.

        If the code doesn't have an associated status, returns ``default``.
        ``default`` defaults to `.CommandStatus.RUNNING`.

        """

        statuses = {
            ":": CommandStatus.DONE,
            "f": CommandStatus.FAILED,
            "!": CommandStatus.FAILED,
            ">": CommandStatus.RUNNING,
        }

        # Compare against None explicitly: a flag with value zero is falsy
        # and would otherwise be silently replaced with RUNNING.
        if default is None:
            default = CommandStatus.RUNNING

        return statuses.get(code, default)
# Type variable for the concrete ``Maskbit`` subclass handled by ``StatusMixIn``.
MaskbitType = TypeVar("MaskbitType", bound=Maskbit)
[docs]
class StatusMixIn(Generic[MaskbitType]):
    """A mixin that provides status tracking with callbacks.

    Provides a status property that executes a list of callbacks when
    the status changes.

    Parameters
    ----------
    maskbit_flags
        A class containing the available statuses as a series of maskbit
        flags. Usually as subclass of `enum.Flag`.
    initial_status
        The initial status.
    callback_func
        The function to call if the status changes. It receives the status.
        A list or tuple of functions is also accepted.
    call_now
        Whether the callback function should be called when initialising.

    Attributes
    ----------
    callbacks
        A list of the callback functions to call.
    watcher
        The `asyncio.Event` shared by the coroutines currently blocked in
        `wait_for_status`, or `None` when nothing is waiting.

    """

    def __init__(
        self,
        maskbit_flags: Type[MaskbitType],
        initial_status: Optional[MaskbitType] = None,
        callback_func: Optional[Callable[[MaskbitType], Any]] = None,
        call_now: bool = False,
    ):
        self.flags = maskbit_flags
        self.callbacks: List[Callable[[MaskbitType], Any]] = []
        self._status: MaskbitType | None = initial_status
        self.watcher: Optional[asyncio.Event] = None

        # Number of coroutines currently inside ``wait_for_status``.
        self._n_waiters = 0

        if callback_func is not None:
            if isinstance(callback_func, (list, tuple)):
                # Copy so that a tuple still yields an appendable list and
                # the caller's sequence is not aliased.
                self.callbacks = list(callback_func)
            else:
                self.callbacks.append(callback_func)

        if call_now is True:
            self.do_callbacks()

    def do_callbacks(self):
        """Calls functions in ``callbacks``.

        Each callback is scheduled with ``loop.call_soon`` and receives the
        current status. Nothing is scheduled while the status is falsy.

        """

        assert hasattr(self, "callbacks"), "missing callbacks attribute."

        status = self.status
        if not status or not self.callbacks:
            # Nothing to schedule; do not touch the event loop at all.
            return

        loop = asyncio.get_event_loop()
        for func in self.callbacks:
            loop.call_soon(func, status)

    @property
    def status(self):
        """Returns the status."""

        return self._status

    @status.setter
    def status(self, value):
        """Sets the status, runs the callbacks and wakes up any waiter."""

        if value != self._status:
            self._status = value
            self.do_callbacks()
            if self.watcher is not None:
                self.watcher.set()

    async def wait_for_status(self, value):
        """Awaits until the status matches ``value``.

        Several coroutines may wait at the same time, also for different
        values: they share a single event, which is discarded once the last
        of them stops waiting.

        """

        if self.status == value:
            return

        # Reuse the event of a waiter that is already blocked. Replacing it
        # would leave that waiter on an event that is never set again.
        if self.watcher is None:
            self.watcher = asyncio.Event()
        watcher = self.watcher

        self._n_waiters += 1
        try:
            while self.status != value:
                await watcher.wait()
                # ``set()`` has already woken every waiter, so clearing here
                # only re-arms the event for the next status change.
                watcher.clear()
        finally:
            self._n_waiters -= 1
            if self._n_waiters == 0 and self.watcher is watcher:
                self.watcher = None
[docs]
class CallbackMixIn(object):
    """A mixin for executing callbacks.

    Parameters
    ----------
    callbacks
        A list of functions or coroutines to be called.
    loop
        Accepted for backwards compatibility; it is not used.

    """

    def __init__(
        self,
        callbacks: Optional[List[Callable[[Any], Any]]] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self._callbacks = []
        for cb in callbacks or []:
            self.register_callback(cb)

        self._running = []  # Running callbacks

    async def stop_callbacks(self):
        """Cancels any running callback task and waits for them to finish."""

        # Work on a snapshot: the done-callback added in ``notify`` removes
        # tasks from ``self._running`` while we are awaiting them.
        tasks = list(self._running)

        for task in tasks:
            if not task.done():
                task.cancel()

        for task in tasks:
            # Suppress per task so that one cancelled task does not prevent
            # the remaining ones from being awaited.
            with contextlib.suppress(asyncio.CancelledError):
                await task

        self._running = []

    def register_callback(self, callback_func: Callable[..., Any]):
        """Adds a callback function or coroutine function."""

        assert callable(callback_func), "callback_func must be a callable."
        self._callbacks.append(callback_func)

    def remove_callback(self, callback_func: Callable[..., Any]):
        """Removes a callback function."""

        assert (
            callback_func in self._callbacks
        ), "callback_func is not in the list of callbacks."
        self._callbacks.remove(callback_func)

    @staticmethod
    def _select_args(callback: Callable[..., Any], args: tuple) -> tuple:
        """Returns the leading ``args`` that ``callback`` can take positionally."""

        try:
            parameters = inspect.signature(callback).parameters.values()
        except (TypeError, ValueError):
            # No introspectable signature; hand over everything.
            return args

        n_positional = 0
        for param in parameters:
            if param.kind is param.VAR_POSITIONAL:
                return args
            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                n_positional += 1

        return args[:n_positional]

    def notify(self, *args):
        """Calls the callback functions with some arguments.

        Coroutine callbacks are scheduled as a task. Synchronous callbacks
        are called immediately. Each callback receives only as many of the
        arguments as it accepts positionally; a callback declaring ``*args``
        receives all of them. The ``self`` of a bound method is not counted.

        """

        if self._callbacks is None:
            return

        for cb in self._callbacks:
            cb_args = self._select_args(cb, args)
            if asyncio.iscoroutinefunction(cb):
                task = asyncio.create_task(cb(*cb_args))
                self._running.append(task)
                # Auto-dispose of the task once it completes
                task.add_done_callback(self._running.remove)
            else:
                cb(*cb_args)
def dict_depth(d: dict) -> int:
    """Return the nesting depth of a dictionary.

    Anything that is not a dictionary has depth 0 and an empty dictionary
    has depth 1.
    """

    if not isinstance(d, dict):
        return 0
    if not d:
        return 1
    deepest_child = max(dict_depth(child) for child in d.values())
    return deepest_child + 1
def format_value(value: Any) -> str:
    """Formats messages in a way that is compatible with the parser.

    Strings containing a space are escaped unless they already start with a
    quote, booleans become ``T``/``F``, sequences are joined with commas, and
    a flat dictionary is formatted as the sequence of its values. Any other
    value is converted with `str`.

    Parameters
    ----------
    value
        The data to be formatted.

    Returns
    -------
    formatted_text
        A string with the escaped text.

    Raises
    ------
    ValueError
        If ``value`` is a dictionary that contains another dictionary.

    """

    if isinstance(value, str):
        already_quoted = value.startswith("'") or value.startswith('"')
        if " " in value and not already_quoted:
            return escape(value)
        return value

    # bool must be tested explicitly so that it is not rendered as True/False.
    if isinstance(value, bool):
        return "T" if value else "F"

    if isinstance(value, (tuple, list)):
        formatted_items = [format_value(item) for item in value]
        return ",".join(formatted_items)

    if isinstance(value, dict):
        if dict_depth(value) > 1:
            raise ValueError("Cannot format a dictionary with depth > 1.")
        return format_value(list(value.values()))

    return str(value)
[docs]
def escape(value: Any):
    """Return ``value`` serialised with `json.dumps`.

    For a string this yields the text wrapped in double quotes, with any
    embedded double quotes backslash-escaped.
    """

    escaped = json.dumps(value)
    return escaped
T = TypeVar("T")


class CaseInsensitiveDict(Dict[str, T]):
    """A dictionary that performs case-insensitive operations.

    Keys keep the casing with which they were first stored, while lookups,
    assignments, deletions and comparisons ignore the case. Keys that are
    not strings are accepted and matched as they are.

    Parameters
    ----------
    values
        Anything accepted by `dict`: a mapping or an iterable of key-value
        pairs. Keys that differ only in their case are rejected.

    """

    def __init__(self, values: Any = ()):
        # Maps the lowercase form of each key to the key actually stored.
        # Iterating or testing membership on it yields the lowercase keys.
        self._lc: Dict[Any, Any] = {}

        dict.__init__(self, values)

        # Index the stored keys rather than ``values`` so that an iterable
        # of pairs works as well as a mapping.
        for key in self:
            self._lc.setdefault(self._fold(key), key)

        assert len(self._lc) == len(
            self
        ), "the are duplicated items in the dict."

    @staticmethod
    def _fold(key):
        """Returns the lowercase form of string keys; other keys unchanged."""

        return key.lower() if isinstance(key, str) else key

    def __get_key__(self, key):
        """Returns the correct value of the key, regardless of its case."""

        return self._lc.get(self._fold(key), key)

    def __getitem__(self, key):
        return dict.__getitem__(self, self.__get_key__(key))

    def __setitem__(self, key, value):
        # The first casing seen for a key is the one that stays stored.
        stored_key = self._lc.setdefault(self._fold(key), key)
        dict.__setitem__(self, stored_key, value)

    def __delitem__(self, key):
        dict.__delitem__(self, self.__get_key__(key))
        self._lc.pop(self._fold(key), None)

    def __contains__(self, key):
        return dict.__contains__(self, self.__get_key__(key))

    def __eq__(self, other):
        """Compares with another dictionary ignoring the case of the keys."""

        if not isinstance(other, dict):
            return NotImplemented

        other_folded = {self._fold(key): value for key, value in other.items()}
        if len(other_folded) != len(other):
            # ``other`` holds keys that differ only by case; it cannot be
            # equal to a dictionary in which such keys are a single entry.
            return False

        self_folded = {
            folded: dict.__getitem__(self, stored_key)
            for folded, stored_key in self._lc.items()
        }
        return self_folded == other_folded

    def __ne__(self, other):
        # dict defines its own ``__ne__``; override it to stay consistent
        # with the case-insensitive ``__eq__``.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def get(self, key, default=None):
        """Returns the value for ``key`` (any case) or ``default``."""

        return dict.get(self, self.__get_key__(key), default)

    def pop(self, key, *default):
        """Removes ``key`` (any case) and returns its value."""

        value = dict.pop(self, self.__get_key__(key), *default)
        self._lc.pop(self._fold(key), None)
        return value

    def popitem(self):
        """Removes and returns the last inserted ``(key, value)`` pair."""

        key, value = dict.popitem(self)
        self._lc.pop(self._fold(key), None)
        return key, value

    def clear(self):
        """Removes all the items."""

        dict.clear(self)
        self._lc.clear()

    def setdefault(self, key, default=None):
        """Returns the value for ``key``, storing ``default`` if it is missing."""

        if key in self:
            return self[key]
        self[key] = default
        return default

    def update(self, *args, **kwargs):
        """Updates the dictionary, matching existing keys regardless of case."""

        for key, value in dict(*args, **kwargs).items():
            self[key] = value
def cli_coro(f):
    """Decorator function that allows defining coroutines with click.

    The returned synchronous function runs the coroutine to completion in a
    new event loop and returns its result. The loop is closed afterwards,
    also when the coroutine raises.

    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(f(*args, **kwargs))
        finally:
            # Release the selector and file descriptors held by the loop.
            loop.close()

    return wrapper
async def as_complete_failer(
    aws: List[Coroutine],
    on_fail_callback: Optional[Callable] = None,
    **kwargs,
) -> Tuple[bool, Optional[str]]:
    """Similar to `~asyncio.as_completed` but cancels all the tasks
    if any of them returns `False`.

    A task counts as failed when it raises an exception or when its result
    is falsy.

    Parameters
    ----------
    aws
        A list of awaitable objects. If not a list, it will be wrapped in one.
    on_fail_callback
        A function or coroutine to call if any of the tasks failed.
    kwargs
        A dictionary of keywords to be passed to `~asyncio.as_completed`.
        A ``loop`` keyword is accepted for backwards compatibility; it is
        used to schedule the awaitables and is not forwarded.

    Returns
    -------
    result_tuple
        A tuple in which the first element is `True` if all the tasks
        completed, `False` if any of them failed and the rest were cancelled.
        If `False`, the second element is `None` if no exceptions were caught
        during the execution of the tasks, otherwise it contains the error
        message. If `True`, the second element is always `None`.

    """

    if not isinstance(aws, (list, tuple)):
        aws = [aws]

    # ``asyncio.as_completed`` no longer takes ``loop``, so it must not be
    # forwarded with the rest of the keywords.
    loop = kwargs.pop("loop", None)
    if loop is None:
        loop = asyncio.get_running_loop()

    # ensure_future accepts coroutines as well as existing tasks and futures.
    tasks = [asyncio.ensure_future(aw, loop=loop) for aw in aws]

    failed = False
    error_message = None

    for next_completed in asyncio.as_completed(tasks, **kwargs):
        try:
            result = await next_completed
        except Exception as ee:
            error_message = str(ee)
            result = False

        if not result:
            failed = True
            break

    if not failed:
        return (True, None)

    for task in tasks:
        task.cancel()

    # Collect the outcome of every task so that no exception is left
    # unretrieved. Cancellation of this coroutine itself still propagates.
    await asyncio.gather(*tasks, return_exceptions=True)

    if on_fail_callback:
        outcome = on_fail_callback()
        if inspect.isawaitable(outcome):
            await outcome

    return (False, error_message)
def log_reply(
    log: logging.Logger,
    message_code: "MessageCode",
    message: str,
    use_message_code: bool = False,
):
    """Logs an actor message with the correct code.

    Parameters
    ----------
    log
        The logger to which the message is sent.
    message_code
        The code of the message. Only its ``value`` attribute is read, and
        only when ``use_message_code`` is set.
    message
        The text to log.
    use_message_code
        If `True`, the logging level is derived from ``message_code``; codes
        without a specific level are logged as ``INFO``. Otherwise the
        message is logged with the ``REPLY`` level when that level has been
        registered with `logging`, and with ``DEBUG`` if it has not.

    """

    code_dict = {
        "f": logging.ERROR,
        "e": logging.ERROR,
        "!": logging.ERROR,
        "w": logging.WARNING,
        "i": logging.INFO,
        ":": logging.INFO,
        ">": logging.INFO,
        "d": logging.DEBUG,
    }

    if use_message_code:
        # Fall back to INFO rather than raising KeyError from a logging call.
        log_level = code_dict.get(message_code.value, logging.INFO)
    elif REPLY in logging._levelToName:
        # The REPLY level has been registered by whoever set up logging.
        log_level = REPLY
    else:
        log_level = logging.DEBUG

    log.log(log_level, message)
# Type variable presumably meant for actor classes. NOTE(review): nothing in
# the visible part of this module references it; confirm it is still needed.
_ActorClass = TypeVar("_ActorClass")
[docs]
class ActorHandler(logging.Handler):
    """A handler that outputs log messages as actor keywords.

    Parameters
    ----------
    actor
        The actor instance. Its ``write`` method is called once per line of
        the log message and may be a regular function or a coroutine function.
    level
        The level above which records will be output in the actor.
    keyword
        The keyword around which the messages will be output.
    code_mapping
        A mapping of logging levels to actor codes. The values provided
        override the default mapping. For example, to make input log messages
        with info level be output as debug,
        ``code_mapping={logging.INFO: 'd'}``.
    filter_warnings
        A list of warning classes that will be issued to the actor. Subclasses
        of the filter warning are accepted, any other warnings will be ignored.

    """

    def __init__(
        self,
        actor,
        level: int = logging.ERROR,
        keyword: str = "text",
        code_mapping: Optional[Dict[int, str]] = None,
        filter_warnings: Optional[List[Type[Warning]]] = None,
    ):
        super().__init__(level=level)

        self.actor = actor
        self.keyword = keyword

        self.code_mapping = {
            logging.DEBUG: "d",
            logging.INFO: "i",
            logging.WARNING: "w",
            logging.ERROR: "e",
        }
        if code_mapping:
            self.code_mapping.update(code_mapping)

        self.filter_warnings = filter_warnings

        # Strong references to the tasks created for coroutine writes; the
        # event loop alone only keeps weak references to them.
        self._tasks = set()

    def emit(self, record: logging.LogRecord):
        """Emits the record.

        The message is written to the actor line by line. Records logged
        with ``exc_info`` get an extra line with the exception. Warning
        records that look like a Python warning are reduced to a single
        line, or dropped, according to ``filter_warnings``. Errors raised
        while emitting are passed to `~logging.Handler.handleError`.

        """

        try:
            message = record.getMessage()
            message_lines = message.splitlines()

            exc_info = record.exc_info
            if exc_info is not None and exc_info[0] is not None:
                message_lines.append(f"{exc_info[0].__name__}: {exc_info[1]}")

            levelno = record.levelno
            if levelno <= logging.DEBUG:
                code = self.code_mapping[logging.DEBUG]
            elif levelno <= logging.INFO:
                code = self.code_mapping[logging.INFO]
            elif levelno <= logging.WARNING:
                code = self.code_mapping[logging.WARNING]
                warning_category_groups = re.match(WARNING_REGEX, message)
                if warning_category_groups is not None:
                    message_lines = self._filter_warning(warning_category_groups)
            elif levelno >= logging.ERROR:
                code = self.code_mapping[logging.ERROR]
            else:
                # Custom levels between WARNING and ERROR.
                code = "w"

            for line in message_lines:
                result = self.actor.write(code, message={self.keyword: line})
                if asyncio.iscoroutine(result):
                    self._schedule(result)
        except Exception:
            self.handleError(record)

    def _schedule(self, coro):
        """Runs ``coro`` as a task, keeping a reference until it is done."""

        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            # No running event loop. Close the coroutine so that it does not
            # trigger a "never awaited" warning on top of the error.
            coro.close()
            raise

        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    @staticmethod
    def _is_category_of(category_name: str, warning_class: type) -> bool:
        """Checks if ``category_name`` names ``warning_class`` or a subclass."""

        pending = [warning_class]
        visited = set()
        while pending:
            candidate = pending.pop()
            if candidate in visited:
                continue
            visited.add(candidate)
            if candidate.__name__ == category_name:
                return True
            pending.extend(candidate.__subclasses__())
        return False

    def _filter_warning(self, warning_category_groups):
        """Returns the lines to output for a warning, or an empty list.

        ``warning_category_groups`` is the match of ``WARNING_REGEX`` on the
        log message. The warning is output only if its category is the name
        of one of the ``filter_warnings`` classes or of one of their
        subclasses.

        """

        warning_category, warning_text = warning_category_groups.groups()

        for warning_filter in self.filter_warnings or ():
            if self._is_category_of(warning_category, warning_filter):
                return [f"{warning_text} ({warning_category})"]

        return []