Source code for pacman.utils.expressionfunction

import ast
import importlib.util
import sys
from collections.abc import Callable
from typing import Any, List, Set, Tuple

from pacman.utils.simple_repr import SimpleRepr, from_repr, simple_repr


class ExpressionFunction(Callable, SimpleRepr):
    """
    Callable object representing a function from a python string expression.

    Example::

        f = ExpressionFunction('a + b')
        f.variable_names -> ['a', 'b']
        f(a=1, b=3) -> 4
        f.expression -> 'a + b'

    Note: this callable only works with keyword arguments.
    """

    def __init__(self, expression: str, source_file=None, **fixed_vars) -> None:
        """
        Create a callable representing the expression.

        :param expression: a valid python expression (any builtin python
            function can be used, e.g. abs, round, etc.), for example
            "abs(a1 - b)". It may also be a multi-line piece of code
            containing at least one ``return`` statement, in which case it is
            used as the body of the generated function.
        :param source_file: optional path to a python file. When given, the
            file is loaded as a module and made available to the expression
            under the name ``source`` (e.g. ``source.my_func(a)``).
        :param fixed_vars: extra keyword parameters are interpreted as fixed
            parameters for the expression: the produced callable represents
            a partial evaluation of the expression with these parameters
            already fixed.
        :raises SyntaxError: if the expression is not valid python code.
        :raises ImportError: if no module spec can be built for
            ``source_file``.
        :raises ValueError: if the name of a fixed parameter does not match
            any of the variables found in the expression.
        """
        self._expression = expression.lstrip()
        self._fixed_vars = fixed_vars
        self._source_file = source_file

        has_return, self.exp_vars = _analyse_ast(self._expression)

        # Build the function definition code from the expression. Parameter
        # order is irrelevant (only keyword calls are made) but sorting keeps
        # the generated code deterministic.
        f_def = f"def f({', '.join(sorted(self.exp_vars))}):\n"
        if not has_return:
            f_def += f"    return {self._expression}"
        else:
            # Multi-line body: make sure it starts on its own line, then
            # indent every line so it becomes the body of `f`.
            if not self._expression.startswith("\n"):
                self._expression = f"\n{self._expression}"
            f_def += self._expression.replace("\n", "\n    ")

        try:
            f_compiled = compile(f_def, "<string>", "exec")
        except SyntaxError as err:
            raise SyntaxError(
                f"Syntax error in string expression: '{self._expression}'"
            ) from err

        # Work on a copy of the module globals so that the generated function
        # can use everything imported here without polluting this module.
        g = dict(globals())
        if source_file is not None:
            # Import the module that contains the constraint definition and
            # expose it as `source` to the expression.
            spec = importlib.util.spec_from_file_location("source", source_file)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load python source file '{source_file}'")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            g["source"] = module

        # Executing the compiled `def` statement only binds the function
        # object (syntax errors were already caught by `compile` above).
        local = {}
        exec(f_compiled, g, local)
        self.exp_func = local["f"]

        for v in fixed_vars:
            if v not in self.exp_vars:
                raise ValueError(
                    'Cannot fix variable "{}" which is not '
                    'present in the expression "{}"'.format(v, expression)
                )

    @property
    def expression(self):
        """The (left-stripped) source of the expression."""
        return self._expression

    @property
    def __name__(self):
        return self._expression

    @property
    def variable_names(self) -> List[str]:
        """
        :return: the sorted list of variable names that must be given when
            calling the function, i.e. the expression's variables that are
            not fixed.
        """
        return sorted(v for v in self.exp_vars if v not in self._fixed_vars)

    def partial(self, **kwargs) -> "ExpressionFunction":
        """
        Return a new ExpressionFunction with additional variables fixed.

        Variables already fixed on this function stay fixed (values given in
        ``kwargs`` override them) and the source file is preserved.
        """
        fixed = {**self._fixed_vars, **kwargs}
        return ExpressionFunction(
            self.expression, source_file=self._source_file, **fixed
        )

    def __call__(self, **kwargs):
        # Note that we only accept named arguments !
        received = set(kwargs)
        expected = set(self.variable_names)
        missing = expected - received
        if missing:
            raise TypeError("Missing named argument(s) " + str(missing))
        unexpected = received - expected
        if unexpected:
            raise TypeError("Unexpected argument(s) " + str(unexpected))
        call_args = {**kwargs, **self._fixed_vars}
        return self.exp_func(**call_args)

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        # Fixed variables are part of the identity of the function: they are
        # also part of the hash, so they must be compared here as well.
        return (
            self._expression == other._expression
            and self._source_file == other._source_file
            and self._fixed_vars == other._fixed_vars
        )

    def __str__(self):
        return f"ExpressionFunction({self._expression})"

    def __repr__(self):
        return f"ExpressionFunction({self._expression}, {self.exp_vars})"

    def __hash__(self):
        # Sort fixed vars by name so that equal dicts built in a different
        # insertion order produce the same hash (consistent with __eq__).
        fixed = tuple(sorted(self._fixed_vars.items(), key=lambda kv: kv[0]))
        return hash((self._expression, self._source_file, fixed))

    def _simple_repr(self):
        r = super()._simple_repr()
        r["fixed_vars"] = simple_repr(self._fixed_vars)
        return r

    @classmethod
    def _from_repr(cls, r):
        # Read without mutating the caller's dict.
        fixed_vars = r["fixed_vars"]
        args = {
            k: from_repr(v)
            for k, v in r.items()
            if k not in ("__qualname__", "__module__", "fixed_vars")
        }
        return cls(**args, **fixed_vars)
class VarCounterVisitor(ast.NodeVisitor):
    """A simple visitor to count variables in an AST tree.

    While walking the tree it collects:

    * ``loaded``: names that are read,
    * ``stored``: names that are bound inside the code (assignments, function
      / lambda arguments, ``def`` / ``class`` names, ``except ... as`` names),
    * ``imported``: names bound by import statements,

    and records in ``has_return`` whether a ``return`` statement was found.
    """

    def __init__(self):
        self.loaded = set()
        self.stored = set()
        self.has_return = False
        self.imported = set()

    def visit(self, node) -> Any:
        if isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                self.loaded.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                self.stored.add(node.id)
        elif isinstance(node, ast.Return):
            self.has_return = True
        # We must keep track of imported names in order to avoid considering
        # them as variable names. Record the name actually bound in the
        # namespace: the alias for `import x as y`, and the top-level package
        # for `import a.b`.
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            self.imported.update(
                (alias.asname or alias.name).split(".")[0] for alias in node.names
            )
        # Names bound by lambda / function arguments, by definitions and by
        # exception handlers are local to the code, not free variables.
        elif isinstance(node, ast.arg):
            self.stored.add(node.arg)
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            self.stored.add(node.name)
        elif isinstance(node, ast.ExceptHandler) and node.name:
            self.stored.add(node.name)
        self.generic_visit(node)

    def get_vars(self):
        """Return the set of free variable names used by the visited code."""
        names = (self.loaded - self.stored) - self.imported
        # We want to allow using builtin function like abs, round, etc.
        # We must filter them out from the list of variables.
        # We also filter out any name that starts with 'source', as this is the
        # syntax used in yaml when referring to constraints defined in separate
        # python files.
        builtins = set(dir(sys.modules["builtins"]))
        return {n for n in names if not n.startswith("source") and n not in builtins}
def _analyse_ast(str_code: str) -> Tuple[bool, Set[str]]:
    """
    Parse ``str_code`` and inspect the resulting AST.

    Parameters
    ----------
    str_code: str
        A piece of python code: an expression, a statement, or the body of a
        function definition (without the ``def ...`` line).

    Returns
    -------
    has_return: bool
        True if the code contains at least one ``return`` statement.
    variables: Set of str
        The identifiers of all variables that are used in ``str_code`` but
        not bound in it (builtins and imported names excluded).
    """
    counter = VarCounterVisitor()
    counter.visit(ast.parse(str_code))
    found_vars = counter.get_vars()
    return counter.has_return, found_vars