Adding Standard Library Operations
To add operations to the Standard Library, define a subclass of
StandardOperation
and a function wrapper.
StandardOperation
function wrappers must take Variable
objects as
required unnamed arguments.
If keyword arguments (kwargs) are required, they must be compile-time
constants.
import numpy as np

from csdl.core.standard_operation import StandardOperation
class power_combination(StandardOperation):
    """Elementwise operation ``coeff * arg0**p0 * arg1**p1 * ...``.

    Takes a variable number of arguments, all of which must have the
    same shape, and produces a single output.

    Parameters
    ----------
    *args
        Arguments of the operation, forwarded to the superclass
        constructor, which establishes ``self.dependencies``.
    powers
        Compile-time constant. Either a single ``int``/``float`` applied
        to every argument, or a sequence paired with the arguments in
        order.
    coeff
        Compile-time constant multiplying the product. It is compared
        against zero with ``np.all``, so a scalar or an array is accepted
        here.
    **kwargs
        Forwarded unchanged to the superclass constructor.

    Raises
    ------
    ValueError
        If the arguments do not all have the same shape.
    """

    def __init__(self, *args, powers, coeff, **kwargs):
        # set the number of outputs and arguments before calling the
        # superclass constructor; nargs=None implies a variable number
        # of arguments
        self.nouts = 1
        self.nargs = None
        super().__init__(*args, **kwargs)

        # for this particular operation, all arguments must have the
        # same shape; note that the dependencies were established in
        # the superclass constructor
        expected_shape = self.dependencies[0].shape
        for dep in self.dependencies:
            if dep.shape != expected_shape:
                raise ValueError(
                    "Shapes of inputs to power_combination do not match")

        # store properties that the compiler can use for IR optimizations
        self.properties['elementwise'] = True

        # store compile-time constants that the back end will use
        self.literals['powers'] = powers
        self.literals['coeff'] = coeff

    def define_compute_strings(self):
        """Build ``self.compute_string`` for this elementwise operation.

        The string has the form ``<out>=<coeff>*<arg0>**<p0>*...`` and is
        the string to be evaluated when computing finite differences for
        a combined operation. When every entry of ``coeff`` is zero the
        string is simply ``<out>=0``.
        """
        out_name = self.outs[0].name
        args = self.dependencies
        powers = self.literals['powers']
        coeff = self.literals['coeff']

        # a scalar power applies to every argument
        if isinstance(powers, (int, float)):
            powers = [powers] * len(args)

        # the zero check does not depend on the argument, so do it once;
        # keep the assignment form so the output is still defined
        if np.all(coeff == 0):
            self.compute_string = '{}=0'.format(out_name)
            return

        factors = ''.join('*{}**{}'.format(arg.name, power)
                          for arg, power in zip(args, powers))
        self.compute_string = '{}={}'.format(out_name, coeff) + factors
The StandardOperation
subclasses are located in csdl/operations.
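To define the function wrapper that adds the operation to the IR, check the argument types, construct the operation, and construct the Output(s) that depend on it, as in the following example: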
def cos(var):
    """Add a ``cos`` operation on ``var`` to the IR and return its output.

    Parameters
    ----------
    var
        The argument of the operation; must be a ``Variable`` object.

    Returns
    -------
    The single ``Output`` that depends on the newly constructed
    operation. It has the same shape as the operation's first dependency.

    Raises
    ------
    TypeError
        If ``var`` is not a ``Variable`` object.
    """
    # perform some type checking
    if not isinstance(var, Variable):
        # one formatted message, not separate exception args
        raise TypeError("{!r} is not a Variable object".format(var))

    # construct operation
    op = ops.cos(var)

    # construct Output(s) that depend(s) on operation, establishing
    # dependency of Output(s) on operation; the operation keeps
    # references to all outputs so that compiler may be able to
    # detect dead code
    op.outs = (Output(
        None,
        op=op,
        shape=op.dependencies[0].shape,
    ), )

    # return Outputs
    return op.outs[0]