tensor.rewriting.math – Tensor Rewrites for Math Operations
Rewrites for the Ops in pytensor.tensor.math.
- class pytensor.tensor.rewriting.math.AlgebraicCanonizer(main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True)
A Rewriter that rewrites algebraic expressions.
The variable is a node_rewriter. It is best used with a WalkingGraphRewriter in in-to-out order.
Usage: AlgebraicCanonizer(main, inverse, reciprocal, calculate)
- Parameters:
main – A suitable Op class that is commutative, associative and takes one to an arbitrary number of inputs, e.g. add or mul.
inverse – An Op class such that inverse(main(x, y), y) == x (e.g. sub or true_div).
reciprocal – A function such that main(x, reciprocal(y)) == inverse(x, y) (e.g. neg or reciprocal).
calculate – A function that takes a list of numpy.ndarray instances for the numerator, another list for the denominator, and calculates inverse(main(*num), main(*denum)). It takes a keyword argument, aslist. If True, the value should be returned as a list of one element, unless the value is such that value = main(). In that case, the return value should be an empty list.
Examples
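>>> import pytensor.tensor as pt
>>> from numpy import prod  # assumed source of prod; calculate receives numpy.ndarray values
>>> from pytensor.tensor.math import add, mul, neg, reciprocal, sub, true_div
>>> from pytensor.tensor.rewriting.math import AlgebraicCanonizer
>>> add_canonizer = AlgebraicCanonizer(add, sub, neg,
...                                    lambda n, d: sum(n) - sum(d))
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, reciprocal,
...                                    lambda n, d: prod(n) / prod(d))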
Examples of rewrites
mul_canonizer can perform:
    x / x -> 1
    (x * y) / x -> y
    x / y / x -> 1 / y
    x / y / z -> x / (y * z)
    x / (y / z) -> (x * z) / y
    (a / b) * (b / c) * (c / d) -> a / d
    (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
    2 * x / 2 -> x
    x * y * z -> Elemwise(mul){x,y,z}  (only one pass over the memory; not Elemwise(mul){x,Elemwise(mul){y,z}})
- get_num_denum(inp)
Extracts two lists, num and denum, such that the input is: self.inverse(self.main(*num), self.main(*denum)). It returns the two lists in a (num, denum) pair.
For example, for main, inverse and reciprocal = *, / and inv():
    input -> returned value (num, denum)
    x*y -> ([x, y], [])
    inv(x) -> ([], [x])
    inv(x) * inv(y) -> ([], [x, y])
    x*y/z -> ([x, y], [z])
    log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
    (((a / b) * c) / d) -> ([a, c], [b, d])
    a / (b / c) -> ([a, c], [b])
    log(x) -> ([log(x)], [])
    x**y -> ([x**y], [])
    x * y * z -> ([x, y, z], [])
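The table above can be reproduced with a minimal pure-Python sketch. This is not PyTensor's implementation: it assumes a hypothetical toy encoding in which expressions are nested tuples such as ("mul", a, b), ("div", a, b) and ("inv", a), and anything else is an opaque leaf. The name toy_get_num_denum is illustrative only.

```python
def toy_get_num_denum(expr):
    """Return a (num, denum) pair for a toy expression (hypothetical encoding).

    ("mul", a, b, ...), ("div", a, b) and ("inv", a) are decomposed;
    every other value is kept as an opaque leaf in the numerator.
    """
    if isinstance(expr, tuple) and expr and expr[0] in ("mul", "div", "inv"):
        op, *args = expr
        pairs = [toy_get_num_denum(arg) for arg in args]
        if op == "mul":
            # Concatenate the numerators and the denominators of all inputs.
            num = [x for n, _ in pairs for x in n]
            denum = [x for _, d in pairs for x in d]
        elif op == "div":
            # Dividing swaps the roles of the second input's two lists.
            (n0, d0), (n1, d1) = pairs
            num, denum = n0 + d1, d0 + n1
        else:  # "inv"
            ((n0, d0),) = pairs
            num, denum = d0, n0
        return num, denum
    return [expr], []


# a / (b / c) -> ([a, c], [b])
print(toy_get_num_denum(("div", "a", ("div", "b", "c"))))
```

Treating unknown operations such as log as leaves mirrors the last rows of the table, where log(x) and x**y are returned unchanged in the numerator.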
- merge_num_denum(num, denum)
Utility function which takes two lists, num and denum, and returns something which is equivalent to inverse(main(*num), main(*denum)), but depends on the length of num and the length of denum (in order to minimize the number of operations).
Let n = len(num) and d = len(denum):
    n=0, d=0: neutral element (given by self.calculate([], [])); for example, this would be 0 if main is addition and 1 if main is multiplication
    n=1, d=0: num[0]
    n=0, d=1: reciprocal(denum[0])
    n=1, d=1: inverse(num[0], denum[0])
    n=0, d>1: reciprocal(main(*denum))
    n>1, d=0: main(*num)
    n=1, d>1: inverse(num[0], main(*denum))
    n>1, d=1: inverse(main(*num), denum[0])
    n>1, d>1: inverse(main(*num), main(*denum))
Given the values of n and d to which they are associated, all of the above are equivalent to: inverse(main(*num), main(*denum))
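The case analysis above can be sketched in plain Python. This is an illustration only, not the library's code: toy_merge_num_denum and its explicit main, inverse, reciprocal and neutral arguments are hypothetical, and the example builds strings instead of graph variables so the chosen form is visible.

```python
def toy_merge_num_denum(num, denum, main, inverse, reciprocal, neutral):
    """Combine num and denum using as few operations as possible (sketch)."""
    n, d = len(num), len(denum)
    if n == 0 and d == 0:
        return neutral
    # Only call main when there is more than one term to combine.
    top = num[0] if n == 1 else (main(*num) if n > 1 else None)
    bottom = denum[0] if d == 1 else (main(*denum) if d > 1 else None)
    if d == 0:
        return top
    if n == 0:
        return reciprocal(bottom)
    return inverse(top, bottom)


# String-building stand-ins for mul, true_div and reciprocal (assumptions).
def mul(*args):
    return "mul(" + ", ".join(args) + ")"


def div(a, b):
    return f"div({a}, {b})"


def inv(a):
    return f"inv({a})"


print(toy_merge_num_denum(["x", "y"], ["z"], mul, div, inv, "1"))
```

Passing the neutral element explicitly keeps the sketch independent of calculate, which supplies it in the documented class.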
- simplify(num, denum, out_type)
Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum))
- simplify_constants(orig_num, orig_denum, out_type=None)
Find all constants and put them together into a single constant.
Finds all constants in orig_num and orig_denum and puts them together into a single constant. The constant is inserted as the first element of the numerator. If the constant is the neutral element, it is removed from the numerator.
Examples
Let main be multiplication:
    [2, 3, x], [] -> [6, x], []
    [x, y, 2], [4, z] -> [0.5, x, y], [z]
    [x, 2, y], [z, 2] -> [x, y], [z]
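The three examples above can be checked with a small stand-alone sketch for the multiplicative case. It is not the library's implementation: toy_simplify_constants is a hypothetical name, Python numbers stand in for constants, strings stand in for symbolic variables, and out_type handling is left out.

```python
import math
from numbers import Number


def toy_simplify_constants(num, denum):
    """Fold numeric constants of a product/quotient into one leading constant."""
    const = math.prod(v for v in num if isinstance(v, Number)) / math.prod(
        v for v in denum if isinstance(v, Number)
    )
    rest_num = [v for v in num if not isinstance(v, Number)]
    rest_denum = [v for v in denum if not isinstance(v, Number)]
    if const == 1:
        # The neutral element of multiplication is dropped entirely.
        return rest_num, rest_denum
    return [const, *rest_num], rest_denum


print(toy_simplify_constants(["x", "y", 2], [4, "z"]))
```

The sketch always divides in floating point, so the folded constant for [2, 3, x] is 6.0 rather than the integer 6; a constant 0 in the denominator is not handled.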
- simplify_factors(num, denum)
For any Variable r which is both in num and denum, removes it from both lists. Modifies the lists in place. Returns the modified lists. For example:
    [x], [x] -> [], []
    [x, y], [x] -> [y], []
    [a, b], [c, d] -> [a, b], [c, d]
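A minimal sketch of the cancellation described above, using plain Python lists of strings in place of Variables. The name toy_simplify_factors is hypothetical and the code is an illustration under that assumption, not the method's actual body.

```python
def toy_simplify_factors(num, denum):
    """Cancel every term that appears in both lists, modifying them in place."""
    for term in list(num):  # iterate over a copy because num is mutated
        if term in denum:
            num.remove(term)
            denum.remove(term)
    return num, denum


num, denum = ["x", "y"], ["x"]
print(toy_simplify_factors(num, denum))
```

Each occurrence is cancelled once, so a term repeated in the numerator only disappears as many times as it appears in the denominator.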
- tracks()
Return the list of Op classes to which this rewrite applies.
Returns None when the rewrite applies to all nodes.
- transform(fgraph, node)
Rewrite the sub-graph given by node.
Subclasses should implement this function so that it returns one of the following:
  - False to indicate that this rewrite cannot be applied to node
  - A list of Variables to use in place of the node’s current outputs
  - A dict mapping old Variables to Variables, or the key "remove" mapping to a list of Variables to be removed.
- Parameters:
fgraph – A FunctionGraph containing node.
node – An Apply node to be rewritten.
- pytensor.tensor.rewriting.math.attempt_distribution(factor, num, denum, out_type)
Try to insert each num and each denum in the factor?
- Returns:
If there are changes, new_num and new_denum contain all the numerators and denominators that could not be distributed in the factor.
- Return type:
changes?, new_factor, new_num, new_denum
- pytensor.tensor.rewriting.math.check_for_x_over_absX(numerators, denominators)
Convert x/abs(x) into sign(x).
- pytensor.tensor.rewriting.math.compute_mul(tree)
Compute the Variable that is the output of a multiplication tree.
This is the inverse of the operation performed by parse_mul_tree, i.e. compute_mul(parse_mul_tree(tree)) == tree.
- Parameters:
tree – A multiplication tree (as output by parse_mul_tree).
- Returns:
A Variable that computes the multiplication represented by the tree.
- Return type:
object
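To make the evaluation of a multiplication tree concrete, here is a hedged sketch that works on plain numbers instead of Variables. It assumes the [neg, inputs] pair format described under parse_mul_tree; toy_compute_mul is a hypothetical name, and raising ValueError for an empty leaf is an assumption of this sketch.

```python
import math


def toy_compute_mul(tree):
    """Evaluate a multiplication tree whose leaves are plain numbers (sketch)."""
    neg, inputs = tree
    if inputs is None:
        # Assumption: an empty leaf cannot be evaluated on its own.
        raise ValueError("cannot compute a leaf with no variable")
    if isinstance(inputs, list):
        value = math.prod(toy_compute_mul(child) for child in inputs)
    else:
        value = inputs
    return -value if neg else value


# (-2) * 3
print(toy_compute_mul([False, [[True, 2], [False, 3]]]))
```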
- pytensor.tensor.rewriting.math.is_1pexp(t, only_process_constants=True)
- Returns:
If ‘t’ is of the form (1+exp(x)), return (False, x). Else return None.
- Return type:
object
- pytensor.tensor.rewriting.math.is_exp(var)
Match a variable with either of the exp(x) or -exp(x) patterns.
- Parameters:
var – The Variable to analyze.
- Returns:
A pair (b, x) with b a boolean set to True if var is of the form -exp(x) and False if var is of the form exp(x). If var cannot be cast into either form, then return None.
- Return type:
tuple
- pytensor.tensor.rewriting.math.is_inverse_pair(node_op, prev_op, inv_pair)
Given two consecutive operations, check if they are the provided pair of inverse functions.
- pytensor.tensor.rewriting.math.is_mul(var)
Match a variable with x * y * z * ....
- Parameters:
var – The Variable to analyze.
- Returns:
A list [x, y, z, …] if var is of the form x * y * z * ..., or None if var cannot be cast into this form.
- Return type:
object
- pytensor.tensor.rewriting.math.is_neg(var)
Match a variable with the -x pattern.
- Parameters:
var – The Variable to analyze.
- Returns:
x if var is of the form -x, or None otherwise.
- Return type:
object
- pytensor.tensor.rewriting.math.parse_mul_tree(root)
Parse a tree of multiplications starting at the given root.
- Parameters:
root – The variable at the root of the tree.
- Returns:
A tree where each non-leaf node corresponds to a multiplication in the computation of root, represented by the list of its inputs. Each input is a pair [n, x] with n a boolean value indicating whether sub-tree x should be negated.
- Return type:
object
Examples
    x * y -> [False, [[False, x], [False, y]]]
    -(x * y) -> [True, [[False, x], [False, y]]]
    -x * y -> [False, [[True, x], [False, y]]]
    -x -> [True, x]
    (x * y) * -z -> [False, [[False, [[False, x], [False, y]]], [True, z]]]
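The examples above can be reproduced with a short pure-Python sketch. It is not the library function: it assumes a hypothetical toy encoding where ("mul", a, b, ...) is a product, ("neg", a) is a negation, and any other value is a leaf; toy_parse_mul_tree is an illustrative name.

```python
def toy_parse_mul_tree(root):
    """Build a [neg, inputs] multiplication tree from a toy expression."""
    if isinstance(root, tuple) and root and root[0] == "mul":
        # A product becomes a non-negated node listing its parsed inputs.
        return [False, [toy_parse_mul_tree(arg) for arg in root[1:]]]
    if isinstance(root, tuple) and root and root[0] == "neg":
        # A negation flips the sign flag of whatever it wraps.
        neg, sub_tree = toy_parse_mul_tree(root[1])
        return [not neg, sub_tree]
    return [False, root]


# (x * y) * -z
print(toy_parse_mul_tree(("mul", ("mul", "x", "y"), ("neg", "z"))))
```

Folding the negation into a boolean flag means a double negation cancels without adding a node to the tree.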
- pytensor.tensor.rewriting.math.perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, sigm_minus_x=None, parent=None, child_idx=None, full_tree=None)
Core processing of the local_sigm_times_exp rewrite.
This recursive function operates on a multiplication tree as output by parse_mul_tree. It walks through the tree and modifies it in-place by replacing matching pairs (exp, sigmoid) with the desired version.
- Parameters:
tree – The sub-tree to operate on.
exp_x – List of arguments x so that exp(x) exists somewhere in the whole multiplication tree. Each argument is a pair (x, leaf) with x the argument of the exponential, and leaf the corresponding leaf in the multiplication tree (of the form [n, exp(x)] – see parse_mul_tree). If None, this argument is initialized to an empty list.
exp_minus_x – Similar to exp_x, but for exp(-x).
sigm_x – Similar to exp_x, but for sigmoid(x).
sigm_minus_x – Similar to exp_x, but for sigmoid(-x).
parent – Parent of tree (None if tree is the global root).
child_idx – Index of tree in its parent’s inputs (None if tree is the global root).
full_tree – The global multiplication tree (should not be set except by recursive calls to this function). Used for debugging only.
- Returns:
True if a modification was performed somewhere in the whole multiplication tree, or False otherwise.
- Return type:
bool
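The text above does not spell out which (exp, sigmoid) products are matched. Assuming they are exp(x) * sigmoid(-x) and exp(-x) * sigmoid(x), which is an assumption of this sketch, the corresponding identities can be checked numerically with the standard library. The helper names are hypothetical.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def check_sigm_times_exp(x):
    """Return whether both assumed identities hold numerically at x."""
    first = math.isclose(math.exp(x) * sigmoid(-x), sigmoid(x))
    second = math.isclose(math.exp(-x) * sigmoid(x), sigmoid(-x))
    return first, second


for value in (-3.0, -0.5, 0.0, 1.25, 4.0):
    print(value, check_sigm_times_exp(value))
```

Both follow from the definition: exp(x) * sigmoid(-x) = exp(x) / (1 + exp(x)) = 1 / (1 + exp(-x)) = sigmoid(x), and the second identity is the same statement with x replaced by -x.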
- pytensor.tensor.rewriting.math.replace_leaf(arg, leaves, new_leaves, op, neg)
Attempt to replace a leaf of a multiplication tree.
We search for a leaf in leaves whose argument is arg, and if we find one, we remove it from leaves and add to new_leaves a leaf with argument arg and variable op(arg).
- Parameters:
arg – The argument of the leaf we are looking for.
leaves – List of leaves to look into. Each leaf should be a pair (x, l) with x the argument of the Op found in the leaf, and l the actual leaf as found in a multiplication tree output by parse_mul_tree (i.e. a pair [boolean, variable]).
new_leaves – If a replacement occurred, then the leaf is removed from leaves and added to the list new_leaves (after being modified by op).
op – A function that, when applied to arg, returns the Variable we want to replace the original leaf variable with.
neg (bool) – If True, then the boolean value associated to the leaf should be swapped. If False, then this value should remain unchanged.
- Returns:
True if a replacement occurred, or False otherwise.
- Return type:
bool
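The behaviour described above can be sketched on plain Python data. This is a hedged illustration that follows the parameter descriptions, not the library's code: toy_replace_leaf is a hypothetical name and strings stand in for Variables.

```python
def toy_replace_leaf(arg, leaves, new_leaves, op, neg):
    """Move the first leaf whose argument is arg from leaves to new_leaves."""
    for idx, (leaf_arg, leaf) in enumerate(leaves):
        if leaf_arg == arg:
            leaf[0] ^= neg      # swap the sign flag only when neg is True
            leaf[1] = op(arg)   # replace the leaf variable with op(arg)
            new_leaves.append(leaves.pop(idx))
            return True
    return False


leaf = [False, "exp(x)"]
leaves, new_leaves = [("x", leaf)], []
changed = toy_replace_leaf("x", leaves, new_leaves, lambda a: f"sigmoid({a})", neg=False)
print(changed, leaf, leaves, new_leaves)
```

Because the leaf is the same list object that sits inside the multiplication tree, mutating it here also updates the tree.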
- pytensor.tensor.rewriting.math.scalarconsts_rest(inputs, elemwise=True, only_process_constants=False)
Partition a list of variables into two kinds: scalar constants, and the rest.
- pytensor.tensor.rewriting.math.simplify_mul(tree)
Simplify a multiplication tree.
- Parameters:
tree – A multiplication tree (as output by parse_mul_tree).
- Returns:
A multiplication tree computing the same output as tree but without useless multiplications by 1 or -1 (identified by leaves of the form [False, None] or [True, None] respectively). Useless multiplications (with fewer than two inputs) are also removed from the tree.
- Return type:
object
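The simplification described above can be sketched as follows. This is an illustration written from the description, not the library source: toy_simplify_mul is a hypothetical name, strings stand in for Variables, and the sketch returns new lists instead of modifying the tree.

```python
def toy_simplify_mul(tree):
    """Drop multiplications by 1 or -1 and collapse single-input products."""
    neg, inputs = tree
    if not isinstance(inputs, list):
        return [neg, inputs]  # a leaf is already as simple as it gets
    kept = []
    for child_neg, child_inputs in map(toy_simplify_mul, inputs):
        if child_inputs is None:
            # [False, None] is *1 and [True, None] is *(-1): fold the sign.
            neg ^= child_neg
        else:
            kept.append([child_neg, child_inputs])
    if not kept:
        return [neg, None]  # the whole product reduces to 1 or -1
    if len(kept) == 1:
        # A product with one input is replaced by that input.
        return [kept[0][0] ^ neg, kept[0][1]]
    return [neg, kept]


# x * (-1) simplifies to -x
print(toy_simplify_mul([False, [[False, "x"], [True, None]]]))
```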