Source code for firexkit.chain

import inspect
from inspect import signature, getfullargspec
from functools import wraps
from typing import Union

from celery.canvas import chain, Signature
from celery.local import PromiseProxy
from celery.result import AsyncResult
from celery.utils.log import get_task_logger

from firexkit.result import wait_on_async_results_and_maybe_raise
from firexkit.bag_of_goodies import BagOfGoodies
from firexkit.task import parse_signature, FireXTask, undecorate, ReturnsCodingException, get_attr_unwrapped

logger = get_task_logger(__name__)

[docs]def returns(*args): """ The decorator is used to allow us to specify the keys of the dict that the task returns. This is used only to signal to the user the inputs and outputs of a task, and deduce what arguments are required for a chain. """ if not args or len(args) != len(set(args)): raise ReturnsCodingException("@returns cannot contain duplicate keys") def decorator(func): if type(func) is PromiseProxy: raise ReturnsCodingException("@returns must be applied to a function (before @app.task)") # Store the arguments of the decorator as a function attribute undecorated = func while hasattr(undecorated, '__wrapped__'): undecorated = undecorated.__wrapped__ undecorated._decorated_return_keys = args # Return a dictionary mapping the tuple returned by the original function # to the keys specified by the arguments to the @returns decorator @wraps(func) def returns_wrapper(*orig_args, **orig_kwargs) -> dict: result = func(*orig_args, **orig_kwargs) return FireXTask.convert_returns_to_dict(undecorated._decorated_return_keys, result) return returns_wrapper return decorator
[docs]class InjectArgs(object): def __init__(self, **kwargs): self.injectArgs = kwargs def __or__(self, other) -> Union[chain, Signature]: if isinstance(other, InjectArgs): self.injectArgs.update(other.injectArgs) else: # chains and signatures are both handled by this _inject_args_into_signature(other, **self.injectArgs) return other def __ior__(self, other: Union[chain, Signature]) -> Union[chain, Signature]: return self | other
def _inject_args_into_signature(sig, **kwargs): try: # This might be a chain task = sig.tasks[0] except AttributeError: # This might be a Task task = sig existing = task.kwargs.keys() subset = {k: v for k, v in kwargs.items() if k not in existing} task.kwargs.update(subset) Signature.injectArgs = _inject_args_into_signature
[docs]def verify_chain_arguments(sig: Signature): """ Verifies that the chain is not missing any parameters. Asserts if any parameters are missing, or if a reference parameter (@something) has not provider """ try: tasks = sig.tasks except AttributeError: tasks = [sig] missing = {} previous = set() ref_args = {} undefined_indirect = {} for task in tasks: task_obj =[task.task] partial_bound = set(signature(*task.args).arguments.keys()) kwargs_keys = set(task.kwargs.keys()) | {'args', 'kwargs'} if isinstance(task_obj, FireXTask): required_args = task_obj.required_args else: required_args, _ = parse_signature(inspect.signature( missing_params = set(required_args) - (partial_bound | kwargs_keys | previous) if missing_params: missing[] = missing_params previous |= (partial_bound | kwargs_keys) # get @returns for next loop try: current_task_returns = set(task_obj.return_keys) previous |= current_task_returns except AttributeError: try: current_task_returns = set(get_attr_unwrapped(task_obj, '_decorated_return_keys')) previous |= current_task_returns except AttributeError: current_task_returns = set() # If any of the previous keys has a dynamic return, then we can't do any validation if any(FireXTask.is_dynamic_return(k) for k in current_task_returns): break # check for validity of reference values (@ arguments) that are consumed by this microservice necessary_args = getfullargspec(undecorate(task_obj)).args new_ref = {k: v[1:] for k, v in task.kwargs.items() if hasattr(v, 'startswith') and v.startswith(BagOfGoodies.INDIRECT_ARG_CHAR)} ref_args.update(new_ref) for needed in necessary_args: if needed in ref_args and ref_args[needed] not in previous: undefined_indirect[needed] = ref_args[needed] if missing: txt = '' for k, v in missing.items(): for arg in v: if txt: txt += '\n' txt += ' ' + arg + '\t: required by "%s"' % k.split('.')[-1] raise InvalidChainArgsException('Missing mandatory arguments: \n%s' % txt, missing) if undefined_indirect: txt = "\n".join([k + ": " + v for k, v in undefined_indirect.items()]) raise InvalidChainArgsException('Chain indirectly references the following unavailable parameters: \n%s' % txt, undefined_indirect) return True
[docs]class InvalidChainArgsException(Exception): def __init__(self, msg, wrong_args: dict = None): super(InvalidChainArgsException, self).__init__(msg) self.wrong_args = wrong_args if wrong_args else {}
def _enqueue(self: Signature, block: bool = False, raise_exception_on_failure: bool = True, caller_task: Signature = None, queue: str = None, priority: int = None) -> AsyncResult: verify_chain_arguments(self) if queue: self.set_queue(queue) if priority: self.set_priority(priority) result_promise = self.delay() if block: wait_on_async_results_and_maybe_raise(results=result_promise, raise_exception_on_failure=raise_exception_on_failure, caller_task=caller_task) return result_promise
[docs]def set_attr(sig: Signature, **options): """Set arbitrary executions options in every task in the :attr:`sig`""" try: [task.set(**options) for task in sig.tasks] except AttributeError: sig.set(**options)
[docs]def set_priority(sig: Signature, priority: int): """Set the :attr:`priority` execution option in every task in :attr:`sig`""" set_attr(sig, priority=priority)
[docs]def set_queue(sig: Signature, queue): """Set the :attr:`queue` execution option in every task in :attr:`sig`""" set_attr(sig, queue=queue)
[docs]def set_label(sig: Signature, label): sig.set(label=label)
[docs]def get_label(sig: Signature): try: return sig.options['label'] except KeyError: try: return '|'.join([ for task in sig.tasks]) except AttributeError: return
Signature.set_priority = set_priority Signature.set_queue = set_queue Signature.set_label = set_label Signature.get_label = get_label Signature.enqueue = _enqueue