Skip to main content

Adding Standard Library Operations


To add operations to the Standard Library, define a subclass of StandardOperation and a function wrapper. StandardOperation function wrappers must take Variable objects as required unnamed arguments. If keyword arguments (kwargs) are required, they must be compile-time constants.

from csdl.core.standard_operation import StandardOperation
class power_combination(StandardOperation):
    """Elementwise operation computing ``coeff * arg0**p0 * arg1**p1 * ...``.

    The operation takes a variable number of positional arguments, all
    of which must have the same shape, and produces a single output.

    Keyword arguments (compile-time constants):
        powers: a single int/float applied to every argument, or a
            sequence with one power per argument.
        coeff: coefficient multiplying the product of powers.

    Raises:
        ValueError: if the arguments do not all share the same shape.
    """

    def __init__(self, *args, powers, coeff, **kwargs):
        # set the number of outputs and arguments before calling
        # superclass constructor; nargs=None implies variable number of
        # arguments
        self.nouts = 1
        self.nargs = None
        super().__init__(*args, **kwargs)

        # for this particular operation, all arguments must have the
        # same shape; note that the dependencies were established in
        # the superclass constructor
        dep0 = self.dependencies[0]
        for dep in self.dependencies:
            if dep0.shape != dep.shape:
                raise ValueError(
                    "Shapes of inputs to power_combination do not match")

        # store properties that the compiler can use for IR optimizations
        self.properties['elementwise'] = True

        # store compile time constants that the back end will use
        self.literals['powers'] = powers
        self.literals['coeff'] = coeff

    # if operation is elementwise, store string to be evaluated when
    # computing finite difference for a combined operation
    def define_compute_strings(self):
        """Build ``self.compute_string`` as ``<out>=<coeff>*<arg>**<power>...``.

        When every entry of ``coeff`` is zero the string collapses to
        ``<out>=0``.
        """
        out_name = self.outs[0].name
        args = self.dependencies
        powers = self.literals['powers']
        coeff = self.literals['coeff']

        # NOTE(review): a non-scalar coeff is inserted into the string via
        # str.format; confirm whether array coefficients are meant to be
        # supported here.

        # a scalar power applies to every argument
        if isinstance(powers, (int, float)):
            powers = [powers] * len(args)

        # the zero test does not depend on the argument, so evaluate it
        # once instead of once per argument
        if np.all(coeff == 0):
            # keep the "<out>=" assignment so the string has the same
            # form as in the nonzero case
            self.compute_string = '{}=0'.format(out_name)
            return

        self.compute_string = '{}={}'.format(out_name, coeff)
        for arg, power in zip(args, powers):
            self.compute_string += '*{}**{}'.format(arg.name, power)

The StandardOperation subclasses are located in csdl/operations. To define the function wrapper that adds the operation to the IR, follow the pattern below (shown here for `cos`):

def cos(var):
    """Construct an ``ops.cos`` operation on ``var`` and return its output.

    Args:
        var: the ``Variable`` the operation depends on.

    Returns:
        The single ``Output`` created for the operation; it has the same
        shape as the operation's first dependency.

    Raises:
        TypeError: if ``var`` is not a ``Variable``.
    """
    # perform some type checking
    if not isinstance(var, Variable):
        # build one formatted message; passing several positional
        # arguments to TypeError would render the message as a tuple
        raise TypeError("{} is not a Variable object".format(var))

    # construct operation
    op = ops.cos(var)

    # construct Output(s) that depend(s) on operation, establishing
    # dependency of Output(s) on operation; the operation stores
    # references to all of its outputs so that the compiler may be able
    # to detect dead code
    op.outs = (Output(
        None,
        op=op,
        shape=op.dependencies[0].shape,
    ), )

    # return Outputs
    return op.outs[0]