This section introduces intermediate representation (IR) rewrites, and how they can be used to implement optimizations.
As discussed earlier in “Stage 6a: Rewrite typed IR”, rewriting the Numba IR allows us to perform optimizations that would be much more difficult to perform at the lower LLVM level. Similar to the Numba type and lowering subsystems, the rewrite subsystem is user extensible. This extensibility allows Numba to support a wide variety of domain-specific optimizations (DSOs).
The remaining subsections detail the mechanics of implementing a rewrite and registering it with the rewrite registry, and provide examples of adding new rewrites, as well as the internals of the array expression optimization pass. We conclude by reviewing some use cases exposed in the examples, and any points where developers should take care.
Rewriting passes have a simple match() and apply() interface. The division between matching and rewriting follows how one would define a term rewrite in a declarative domain-specific language (DSL). In such DSLs, one may write a rewrite as follows:
<match> => <replacement>
The <match> and <replacement> symbols represent IR term expressions, where the left-hand side presents a pattern to match, and the right-hand side an IR term constructor to build upon matching. Whenever the rewrite matches an IR pattern, any free variables in the left-hand side are bound within a custom environment. When applied, the rewrite uses the pattern matching environment to bind any free variables in the right-hand side.
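The binding behaviour described above can be illustrated with a small, self-contained sketch. It does not use Numba's IR: terms are nested tuples, and the convention that strings starting with "?" are free variables is an assumption made only for this example.

```python
# Toy terms are nested tuples such as ("add", "x", ("const", 0)).
# Strings starting with "?" are pattern variables (an assumption of this sketch).

def match(pattern, term, env):
    """Match term against pattern, binding pattern variables in env."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in env:               # already bound: both uses must agree
            return env[pattern] == term
        env[pattern] = term
        return True
    if isinstance(pattern, tuple) and isinstance(term, tuple):
        if len(pattern) != len(term):
            return False
        return all(match(p, t, env) for p, t in zip(pattern, term))
    return pattern == term               # literal: must be equal


def build(template, env):
    """Instantiate the right-hand side using the bindings found by match()."""
    if isinstance(template, str) and template.startswith("?"):
        return env[template]
    if isinstance(template, tuple):
        return tuple(build(t, env) for t in template)
    return template


# Rule:  ?a + 0  =>  ?a
lhs = ("add", "?a", ("const", 0))
rhs = "?a"

env = {}
term = ("add", ("mul", "x", "y"), ("const", 0))
result = build(rhs, env) if match(lhs, term, env) else term
```

Here the environment is an explicit dictionary passed between the two steps, which is the role that object state plays in a Numba rewrite.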
As Python is not commonly used in a declarative capacity, Numba uses object state to handle the transfer of information between the matching and application steps.
The Rewrite Base Class
class Rewrite
The Rewrite class simply defines an abstract base class for Numba rewrites. Developers should define rewrites as subclasses of this base type, overloading the match() and apply() methods.
pipeline
The pipeline attribute contains the numba.compiler.Pipeline instance that is currently compiling the function under consideration for rewriting.
__init__(self, pipeline, *args, **kws)
The base constructor for rewrites simply stashes its arguments into attributes of the same name. Unless being used in debugging or testing, rewrites should only be constructed by the RewriteRegistry in the RewriteRegistry.apply() method, and the construction interface should remain stable (though the pipeline will commonly contain just about everything there is to know).
match(self, func_ir, block, typemap, calltypes)
The match() method takes four arguments other than self:
- func_ir: an instance of numba.ir.FunctionIR for the function being rewritten.
- block: an instance of numba.ir.Block. The matching method should iterate over the instructions contained in the numba.ir.Block.body member.
- typemap: a Python dict instance mapping from symbol names in the IR, represented as strings, to Numba types.
- calltypes: another dict instance mapping from calls, represented as numba.ir.Expr instances, to their corresponding call site type signatures, represented as a numba.typing.templates.Signature instance.
The match() method should return a bool result. A True result indicates that one or more matches were found, and that the apply() method will return a new replacement numba.ir.Block instance. A False result indicates that no matches were found, and that subsequent calls to apply() will return undefined or invalid results.
apply(self)
The apply() method should only be invoked following a successful call to match(). This method takes no additional parameters other than self, and should return a replacement numba.ir.Block instance.
As mentioned above, the behavior of calling apply() is undefined unless match() has already been called and returned True.
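The match()/apply() contract can be sketched without Numba as follows. The class FoldAddZero, the helper run(), and the tuple-based "block" are all hypothetical stand-ins chosen for illustration; a real rewrite would subclass Rewrite and operate on numba.ir.Block instances.

```python
class FoldAddZero:
    """Hypothetical rewrite over a toy block: a list of
    (target, op, lhs, rhs) tuples standing in for IR assignments."""

    def match(self, block):
        # Record what was found; apply() reads this object state later.
        self.block = block
        self.hits = [i for i, (_, op, _, rhs) in enumerate(block)
                     if op == "add" and rhs == 0]
        return bool(self.hits)

    def apply(self):
        # Only meaningful after match() has returned True.
        new_block = list(self.block)
        for i in self.hits:
            target, _, lhs, _ = new_block[i]
            new_block[i] = (target, "copy", lhs, None)
        return new_block


def run(rewrite, block):
    """Call apply() only when match() succeeded; otherwise keep the block."""
    return rewrite.apply() if rewrite.match(block) else block
```

For example, run(FoldAddZero(), [("a", "add", "x", 0)]) returns a new list containing ("a", "copy", "x", None), while a block with no "add x, 0" instruction is returned unchanged and apply() is never called.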
Subclassing Rewrite
Before going into the expectations for the overloaded methods any Rewrite subclass must have, let’s step back a minute to review what is taking place here. By providing an extensible compiler, Numba opens itself to user-defined code generators which may be incomplete, or worse, incorrect. When a code generator goes awry, it can cause abnormal program behavior or early termination. User-defined rewrites add a new level of complexity because they must not only generate correct code, but the code they generate should ensure that the compiler does not get stuck in a match/apply loop. Non-termination by the compiler will directly lead to non-termination of user function calls.
There are several ways to help ensure that a rewrite terminates:
In the “Case study: Array Expressions” subsection, below, we’ll see how the array expression rewriter uses both of these techniques.
Rewrite.match()
Every rewrite developer should seek to have their implementation of match() return a False value as quickly as possible. Numba is a just-in-time compiler, and adding compilation time ultimately adds to the user’s run time. When a rewrite returns False for a given block, the registry will no longer process that block with that rewrite, and the compiler is that much closer to proceeding to lowering.
This need for timeliness has to be balanced against collecting the necessary information to make a match for a rewrite. Rewrite developers should be comfortable adding dynamic attributes to their subclasses, and then having these new attributes guide construction of the replacement basic block.
Rewrite.apply()
The apply() method should return a replacement numba.ir.Block instance to replace the basic block that contained a match for the rewrite. As mentioned above, the IR built by apply() methods should preserve the semantics of the user’s code, but also seek to avoid generating another match for the same rewrite or set of rewrites.
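To make the termination concern concrete, here is a minimal sketch of a driver that re-runs one rewrite until it stops matching. Everything in it is hypothetical: the driver rewrite_to_fixpoint(), its max_rounds safety limit, the FuseMulAdd class, and the (op, args) instruction format are illustration only and are not how Numba's registry is implemented.

```python
def rewrite_to_fixpoint(block, rewrite, max_rounds=100):
    """Apply a rewrite until match() fails. max_rounds is a safety net
    added for this sketch; it is not something Numba is known to provide."""
    for rounds in range(max_rounds):
        if not rewrite.match(block):
            return block, rounds
        block = rewrite.apply()
    raise RuntimeError("rewrite did not settle within max_rounds")


class FuseMulAdd:
    """Hypothetical rewrite over toy (op, args) instructions.

    The replacement op "fused" is outside the set of ops that match()
    looks for, so every application removes one candidate and the
    match/apply loop must end."""

    def match(self, block):
        self.block = block
        self.index = next(
            (i for i, (op, _) in enumerate(block) if op == "mul_add"), None)
        return self.index is not None

    def apply(self):
        new_block = list(self.block)
        _, args = new_block[self.index]
        new_block[self.index] = ("fused", args)
        return new_block
```

A rewrite whose apply() output still satisfies its own match() would spin in this driver until the limit is hit, which is the match/apply loop that rewrite authors need to avoid.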
When you want to include a rewrite in the rewrite pass, you should register it with the rewrite registry. The numba.rewrites module provides both the abstract base class and a class decorator for hooking into the Numba rewrite subsystem. The following illustrates a stub definition of a new rewrite:
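    from numba import rewrites


    @rewrites.register_rewrite
    class MyRewrite(rewrites.Rewrite):

        def match(self, func_ir, block, typemap, calltypes):
            # Return True only if this block contains something to rewrite.
            raise NotImplementedError("FIXME")

        def apply(self):
            # Return the replacement numba.ir.Block for the last match.
            raise NotImplementedError("FIXME")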
Developers should note that using the class decorator as shown above will register a rewrite at import time. It is the developer’s responsibility to ensure their extensions are loaded before compilation starts.
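The import-time behaviour of a registering class decorator can be seen in a Numba-free sketch. ToyRegistry and DropNops are hypothetical names, and the registry logic is a simplified assumption rather than a copy of RewriteRegistry.

```python
class ToyRegistry:
    """Hypothetical stand-in for a rewrite registry (not Numba's code)."""

    def __init__(self):
        self.rewrites = []

    def register(self, cls):
        # Runs when the decorated class statement executes, i.e. at import time.
        self.rewrites.append(cls)
        return cls  # hand the class back so its name stays bound to it

    def apply(self, block):
        # The registry, not user code, constructs the rewrite objects.
        for cls in self.rewrites:
            rewrite = cls()
            while rewrite.match(block):
                block = rewrite.apply()
        return block


registry = ToyRegistry()


@registry.register
class DropNops:
    def match(self, block):
        self.block = block
        return "nop" in block

    def apply(self):
        return [instr for instr in self.block if instr != "nop"]
```

Because registration happens as a side effect of executing the class definition, a module that defines rewrites but is never imported contributes nothing to the registry.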
Case study: Array Expressions
This subsection looks at the array expression rewriter in more depth. The array expression rewriter, and most of its support functionality, are found in the numba.npyufunc.array_exprs module. The rewriting pass itself is implemented in the RewriteArrayExprs class. In addition to the rewriter, the array_exprs module includes a function for lowering array expressions, _lower_array_expr(). The overall optimization process is as follows:
1. RewriteArrayExprs.match(): The rewrite pass looks for two or more array operations that form an array expression.
2. RewriteArrayExprs.apply(): Once an array expression is found, the rewriter replaces the individual array operations with a new kind of IR expression, the arrayexpr.
3. numba.npyufunc.array_exprs._lower_array_expr(): During lowering, the code generator calls _lower_array_expr() whenever it finds an arrayexpr IR expression.
More details on each step of the optimization are given below.
The RewriteArrayExprs.match() method
The array expression optimization pass starts by looking for array operations, including calls to supported ufuncs and user-defined DUFuncs. Numba IR follows the conventions of a static single assignment (SSA) language, meaning that the search for array operators begins with looking for assignment instructions.
When the rewriting pass calls the RewriteArrayExprs.match() method, it first checks to see if it can trivially reject the basic block. If the method determines the block to be a candidate for matching, it sets up the following state variables in the rewrite object:
At this point, the match method iterates over the assignment instructions in the input basic block. For each assignment instruction, the matcher looks for one of two things:
The end of the matching method simply checks for a non-empty matches list, returning True if there were one or more matches, and False when matches is empty.
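The shape of this matching step can be sketched on a toy SSA block. The instruction format (target, op, operands), the ARRAY_OPS set, and the array_assigns dictionary are assumptions made for this example; only the matches list is named in the text above, and the real matcher also consults types.

```python
ARRAY_OPS = {"add", "sub", "mul"}  # toy stand-in for supported array operations


def find_array_exprs(block):
    """Toy matcher over (target, op, operands) assignments in SSA form."""
    array_assigns = {}  # target name -> instruction that defines it
    matches = []        # targets rooting an expression of two or more ops
    for instr in block:
        target, op, operands = instr
        if op not in ARRAY_OPS:
            continue
        array_assigns[target] = instr
        # An operand defined by an earlier array op means the two can be fused.
        if any(name in array_assigns for name in operands):
            matches.append(target)
    return matches, array_assigns


block = [
    ("t0", "add", ("a", "b")),
    ("t1", "mul", ("t0", "c")),
    ("u", "load", ("p",)),
]
matches, array_assigns = find_array_exprs(block)
matched = bool(matches)  # what a match() method would return
```

Since each name is assigned once in SSA form, a dictionary keyed by target name is enough to recognise that an operand was itself produced by an array operation.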
The RewriteArrayExprs.apply() method
When one or more matching array expressions are found by RewriteArrayExprs.match(), the rewriting pass will call RewriteArrayExprs.apply(). The apply method works in two passes. The first pass iterates over the matches found, and builds a map from instructions in the old basic block to new instructions in the new basic block. The second pass iterates over the instructions in the old basic block, copying instructions that are not changed by the rewrite, and replacing or deleting instructions that were identified by the first pass.
The RewriteArrayExprs._handle_matches() method implements the first pass of the code generation portion of the rewrite. For each match, this method builds a special IR expression that contains an expression tree for the array expression. To compute the leaves of the expression tree, the _handle_matches() method iterates over the operands of the identified root operation. If the operand is another array operation, it is translated into an expression sub-tree. If the operand is a constant, _handle_matches() copies the constant value. Otherwise, the operand is marked as being used by an array expression. As the method builds array expression nodes, it builds a map from old instructions to new instructions (replace_map), as well as sets of variables that may have moved (used_vars), and variables that should be removed altogether (dead_vars). These three data structures are returned to the calling RewriteArrayExprs.apply() method.
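A simplified sketch of this first pass follows. It reuses a toy (target, op, operands) instruction format, and the handle_matches() function, its arguments, and the tuple-shaped expression tree are assumptions for illustration; only the names replace_map, used_vars, and dead_vars come from the description above.

```python
def handle_matches(matches, array_assigns, const_assigns):
    """Toy first pass: build expression trees and the bookkeeping structures."""
    replace_map, used_vars, dead_vars = {}, set(), set()

    def build_tree(target):
        _, op, operands = array_assigns[target]
        children = []
        for name in operands:
            if name in array_assigns:
                # Another array operation: inline it as a sub-tree and
                # mark its old instruction and temporary for removal.
                replace_map[array_assigns[name]] = None
                dead_vars.add(name)
                children.append(build_tree(name))
            elif name in const_assigns:
                # Constant operand: copy the value into the tree.
                children.append(const_assigns[name])
            else:
                # Plain variable: note that an array expression uses it.
                used_vars.add(name)
                children.append(name)
        return (op,) + tuple(children)

    for target in matches:
        old_instr = array_assigns[target]
        replace_map[old_instr] = (target, "arrayexpr", build_tree(target))
    return replace_map, used_vars, dead_vars


array_assigns = {
    "t0": ("t0", "add", ("a", "k")),
    "t1": ("t1", "mul", ("t0", "c")),
}
const_assigns = {"k": 2.0}
replace_map, used_vars, dead_vars = handle_matches(
    ["t1"], array_assigns, const_assigns)
```

In this run the instruction defining t0 maps to None, the instruction defining t1 maps to a single arrayexpr node holding the tree ("mul", ("add", "a", 2.0), "c"), and t0 ends up in dead_vars because nothing refers to it any more.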
The remaining part of the RewriteArrayExprs.apply() method iterates over the instructions in the old basic block. For each instruction, this method either replaces, deletes, or duplicates that instruction based on the results of RewriteArrayExprs._handle_matches(). The following list describes how the optimization handles individual instructions:
- For an assignment instruction, apply() checks to see if it is in the replacement instruction map. When an assignment instruction is found in the instruction map, apply() must then check to see if the replacement instruction is also in the replacement map. The optimizer continues this check until it either arrives at a None value or an instruction that isn’t in the replacement map. Instructions that have a replacement that is None are deleted. Instructions that have a non-None replacement are replaced. Assignment instructions not in the replacement map are appended to the new basic block with no changes made.
- For a delete instruction, the handling depends on the variable being deleted. Delete instructions for variables used by an array expression are held back so that apply() can move them past any uses of that variable. The loop copies delete instructions for non-dead variables, and ignores delete instructions for dead variables (effectively removing them from the basic block).
Finally, the apply() method returns the new basic block for lowering.
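The second pass can be approximated by the sketch below. It is a simplification under stated assumptions: instructions are toy tuples, ("assign", target, operands) or ("del", name), used_vars is treated as a plain set, the block has no terminator, and the function name rebuild_block() is hypothetical.

```python
def rebuild_block(block, replace_map, used_vars, dead_vars):
    """Toy second pass: copy, replace, delete, or postpone instructions."""
    new_block = []
    deferred = {}  # variable name -> postponed delete instruction
    for instr in block:
        if instr[0] == "assign":
            # Follow the replacement chain until None or an unmapped instruction.
            final = instr
            while final is not None and final in replace_map:
                final = replace_map[final]
            if final is None:
                continue  # removed by the rewrite
            new_block.append(final)
            # A postponed delete can be emitted once its variable has been read.
            for name in final[2]:
                if name in deferred:
                    new_block.append(deferred.pop(name))
        elif instr[0] == "del":
            name = instr[1]
            if name in dead_vars:
                continue  # the variable no longer exists
            if name in used_vars:
                deferred[name] = instr  # may need to move past a later use
            else:
                new_block.append(instr)
        else:
            new_block.append(instr)
    # Simplification: flush leftovers at the end (this toy block has no terminator).
    new_block.extend(deferred.values())
    return new_block
```

The while loop is the chain-following step: a replacement can itself have been replaced, so the lookup repeats until it reaches either None or an instruction that no longer appears as a key in the map.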
The _lower_array_expr() function
If we left things at just the rewrite, then the lowering stage of the compiler would fail, complaining it doesn’t know how to lower arrayexpr operations. We start by hooking a lowering function into the target context whenever the RewriteArrayExprs class is instantiated by the compiler. This hook causes the lowering pass to call _lower_array_expr() whenever it encounters an arrayexpr operator.
This function has two steps:
1. Synthesize a Python function that implements the array expression. The function behaves like a ufunc, returning the result of the expression on scalar values in the broadcasted array arguments. The lowering function accomplishes this by translating from the array expression tree into a Python AST.
2. Compile the synthetic function into a kernel by calling numba.targets.numpyimpl.numpy_ufunc_kernel() after defining how to lower calls to the synthetic function.
The end result is similar to loop lifting in Numba’s object mode.
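The first of these steps, turning an expression tree into a Python AST and then into a callable, can be sketched with the standard ast module alone. The tree format (op, left, right), the helper names, and the generated function name are assumptions for this example; the sketch stops at a plain Python function and does not attempt the kernel compilation step.

```python
import ast

_BINOPS = {"add": ast.Add, "sub": ast.Sub, "mul": ast.Mult}


def tree_to_ast(node):
    """Translate a toy expression tree into a Python AST expression."""
    if isinstance(node, tuple):                 # (op, left, right)
        op, left, right = node
        return ast.BinOp(left=tree_to_ast(left), op=_BINOPS[op](),
                         right=tree_to_ast(right))
    if isinstance(node, str):                   # variable operand
        return ast.Name(id=node, ctx=ast.Load())
    return ast.Constant(value=node)             # constant operand


def synthesize_kernel(tree, arg_names, name="toy_array_expr"):
    """Build and compile a scalar function that evaluates the tree."""
    # Parse a stub to get a well-formed function, then swap in the real body.
    stub = "def {}({}):\n    return None".format(name, ", ".join(arg_names))
    module = ast.parse(stub)
    module.body[0].body[0].value = tree_to_ast(tree)
    ast.fix_missing_locations(module)
    namespace = {}
    exec(compile(module, "<toy_array_expr>", "exec"), namespace)
    return namespace[name]


kernel = synthesize_kernel(("mul", ("add", "a", 2.0), "c"), ["a", "c"])
```

The resulting kernel works on scalars, so calling kernel(1.0, 3.0) evaluates (1.0 + 2.0) * 3.0; applying it across broadcast array elements is the part that the ufunc kernel machinery takes over in the real lowering function.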
We have seen how to implement rewrites in Numba, starting with the interface, and ending with an actual optimization. The key points of this section are: