Source code for taxcalc.decorators

"""
Implement numba JIT decorators used to speed-up the execution
of Tax-Calculator functions in the calcfunctions.py module.
"""
# CODING-STYLE CHECKS:
# pycodestyle decorators.py
# pylint --disable=locally-disabled decorators.py

import os
import io
import ast
import inspect
import numba
from taxcalc.policy import Policy


DO_JIT = True
# One way to use the Python debugger is to do these two things:
#  (a) change the line immediately above this comment from
#      "DO_JIT = True" to "DO_JIT = False", and
#  (b) import pdb package and call pdb.set_trace() in either the
#      calculator.py or calcfunctions.py file.


[docs] def id_wrapper(*dec_args, **dec_kwargs): # pylint: disable=unused-argument """ Function wrapper when numba package is not being used during debugging. """ def wrap(fnc): """ wrap function nested in id_wrapper function. """ def wrapped_f(*args, **kwargs): """ wrapped_f function nested in wrap function. """ return fnc(*args, **kwargs) return wrapped_f return wrap
if DO_JIT is False or 'NOTAXCALCJIT' in os.environ: JIT = id_wrapper else: JIT = numba.jit
[docs] class GetReturnNode(ast.NodeVisitor): """ A NodeVisitor to get the return tuple names from a calc-style function. """
[docs] def visit_Return(self, node): # pylint: disable=invalid-name,no-self-use """ visit_Return is used by NodeVisitor.visit method. """ if isinstance(node.value, ast.Tuple): return [e.id for e in node.value.elts] return [node.value.id]
[docs] def create_apply_function_string(sigout, sigin, parameters): """ Create a string for a function of the form:: def ap_fuc(x_0, x_1, x_2, ...): for i in range(len(x_0)): x_0[i], ... = jitted_f(x_j[i], ...) return x_0[i], ... where the specific args to jitted_f and the number of values to return is determined by sigout and sigin. Parameters ---------- sigout: iterable of the out arguments sigin: iterable of the in arguments parameters: iterable of which of the args (from in_args) are parameter variables (as opposed to column records). This influences how we construct the apply-style function Returns ------- a String representing the function """ fstr = io.StringIO() total_len = len(sigout) + len(sigin) out_args = ["x_" + str(i) for i in range(0, len(sigout))] in_args = ["x_" + str(i) for i in range(len(sigout), total_len)] fstr.write("def ap_func({0}):\n".format(",".join(out_args + in_args))) fstr.write(" for i in range(len(x_0)):\n") out_index = [x + "[i]" for x in out_args] in_index = [] for arg, _var in zip(in_args, sigin): in_index.append(arg + "[i]" if _var not in parameters else arg) fstr.write(" " + ",".join(out_index) + " = ") fstr.write("jitted_f(" + ",".join(in_index) + ")\n") fstr.write(" return " + ",".join(out_args) + "\n") return fstr.getvalue()
[docs] def create_toplevel_function_string(args_out, args_in, pm_or_pf): """ Create a string for a function of the form: def hl_func(x_0, x_1, x_2, ...): outputs = (...) = calc_func(...) header = [...] return DataFrame(data, columns=header) Parameters ---------- args_out: iterable of the out arguments args_in: iterable of the in arguments pm_or_pf: iterable of strings for object that holds each arg Returns ------- a String representing the function """ fstr = io.StringIO() fstr.write("def hl_func(pm, pf") fstr.write("):\n") fstr.write(" from pandas import DataFrame\n") fstr.write(" import numpy as np\n") fstr.write(" import pandas as pd\n") fstr.write(" def get_values(x):\n") fstr.write(" if isinstance(x, pd.Series):\n") fstr.write(" return x.values\n") fstr.write(" else:\n") fstr.write(" return x\n") fstr.write(" outputs = \\\n") outs = [] for ppp, attr in zip(pm_or_pf, args_out + args_in): outs.append(ppp + "." + attr + ", ") outs = [m_or_f + "." + arg for m_or_f, arg in zip(pm_or_pf, args_out)] fstr.write(" (" + ", ".join(outs) + ") = \\\n") fstr.write(" " + "applied_f(") for ppp, attr in zip(pm_or_pf, args_out + args_in): # Bring Policy parameter values down a dimension. if ppp == "pm": attr += "[0]" fstr.write("get_values(" + ppp + "." + attr + ")" + ", ") fstr.write(")\n") fstr.write(" header = [") col_headers = ["'" + out + "'" for out in args_out] fstr.write(", ".join(col_headers)) fstr.write("]\n") if len(args_out) == 1: fstr.write(" return DataFrame(data=outputs," "columns=header)") else: fstr.write(" return DataFrame(data=np.column_stack(" "outputs),columns=header)") return fstr.getvalue()
[docs] def make_apply_function(func, out_args, in_args, parameters, do_jit=DO_JIT, **kwargs): """ Takes a calc-style function and creates the necessary Python code for an apply-style function. Will also jit the function if desired. Parameters ---------- func: the calc-style function out_args: list of out arguments for the apply-style function in_args: list of in arguments for the apply-style function parameters: iterable of which of the args (from in_args) are parameter variables (as opposed to column records). This influences how we construct the apply-style function. do_jit: Bool, if True, jit the resulting apply-style function Returns ------- apply-style function """ if do_jit: jitted_f = JIT(**kwargs)(func) else: jitted_f = func apfunc = create_apply_function_string(out_args, in_args, parameters) func_code = compile(apfunc, "<string>", "exec") fakeglobals = {} eval(func_code, # pylint: disable=eval-used {"jitted_f": jitted_f}, fakeglobals) if do_jit: return JIT(**kwargs)(fakeglobals['ap_func']) return fakeglobals['ap_func']
[docs] def apply_jit(dtype_sig_out, dtype_sig_in, parameters=None, **kwargs): """ Make a decorator that takes in a calc-style function, handle apply step. """ if not parameters: parameters = [] def make_wrapper(func): """ make_wrapper function nested in apply_jit function. """ theargs = inspect.getfullargspec(func).args jitted_apply = make_apply_function(func, dtype_sig_out, dtype_sig_in, parameters, **kwargs) def wrapper(*args): """ wrapper function nested in make_wrapper function. """ in_arrays = [] out_arrays = [] for farg in theargs: if hasattr(args[0], farg): in_arrays.append(getattr(args[0], farg)) else: in_arrays.append(getattr(args[1], farg)) for farg in dtype_sig_out: if hasattr(args[0], farg): out_arrays.append(getattr(args[0], farg)) else: out_arrays.append(getattr(args[1], farg)) final_array = out_arrays + in_arrays ans = jitted_apply(*final_array) return ans return wrapper return make_wrapper
[docs] def iterate_jit(parameters=None, **kwargs): """ Public decorator for a calc-style function (see calcfunctions.py) that transforms the calc-style function into an apply-style function that can be called by Calculator class methods (see calculator.py). """ if not parameters: parameters = [] def make_wrapper(func): """ make_wrapper function nested in iterate_jit decorator wraps specified func using apply_jit. """ # pylint: disable=too-many-locals # Get the input arguments from the function in_args = inspect.getfullargspec(func).args # Get the numba.jit arguments jit_args_list = inspect.getfullargspec(JIT).args + ['nopython'] kwargs_for_jit = dict() for key, val in kwargs.items(): if key in jit_args_list: kwargs_for_jit[key] = val # Any name that is a parameter # Boolean flag is given special treatment. # Identify those names here param_list = Policy.parameter_list() allowed_parameters = param_list allowed_parameters += list(arg[1:] for arg in param_list) additional_parameters = [arg for arg in in_args if arg in allowed_parameters] additional_parameters += parameters # Remote duplicates all_parameters = list(set(additional_parameters)) src = inspect.getsourcelines(func)[0] # Discover the return arguments by walking # the AST of the function grn = GetReturnNode() all_out_args = None for node in ast.walk(ast.parse(''.join(src))): all_out_args = grn.visit(node) if all_out_args: break if not all_out_args: raise ValueError("Can't find return statement in function!") # Now create the apply-style possibly-jitted function applied_jitted_f = make_apply_function(func, list(reversed(all_out_args)), in_args, parameters=all_parameters, do_jit=DO_JIT, **kwargs_for_jit) def wrapper(*args, **kwargs): """ wrapper function nested in make_wrapper function nested in iterate_jit decorator. """ # os TESTING environment only accepts string arguments if os.getenv('TESTING') == 'True': return func(*args, **kwargs) in_arrays = [] pm_or_pf = [] for farg in all_out_args + in_args: if hasattr(args[0], farg): in_arrays.append(getattr(args[0], farg)) pm_or_pf.append("pm") elif hasattr(args[1], farg): in_arrays.append(getattr(args[1], farg)) pm_or_pf.append("pf") # Create the high level function high_level_func = create_toplevel_function_string(all_out_args, list(in_args), pm_or_pf) func_code = compile(high_level_func, "<string>", "exec") fakeglobals = {} eval(func_code, # pylint: disable=eval-used {"applied_f": applied_jitted_f}, fakeglobals) high_level_fn = fakeglobals['hl_func'] ans = high_level_fn(*args, **kwargs) return ans return wrapper return make_wrapper