Warning
The custom pipeline feature is for expert use only. Modifying the compiler behavior can invalidate internal assumptions in the Numba source code.
Library developers who want to extend or modify the compiler's behavior can do so by defining a custom compiler that inherits from numba.compiler.CompilerBase. The default Numba compiler is numba.compiler.Compiler, which implements the .define_pipelines() method to add the nopython-mode, object-mode and interpreted-mode pipelines. For convenience, these three pipelines are defined in numba.compiler.DefaultPassBuilder by the following methods, respectively:
.define_nopython_pipeline()
.define_objectmode_pipeline()
.define_interpreted_pipeline()
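As a rough mental model of what define_pipelines() is for, the stdlib-only sketch below has a compiler class return an ordered list of pipelines and try them in turn, moving on to the next one when a pipeline fails. Everything in it (ToyPipeline, ToyCompilerBase, ToyCompiler and the fallback loop) is hypothetical and assumed for illustration; it is not Numba's implementation.

```python
from typing import Callable, List


class ToyPipeline:
    """Hypothetical stand-in for a pipeline: a name plus a callable that either
    returns a result or raises to signal that this mode cannot compile."""

    def __init__(self, name: str, run: Callable[[str], str]):
        self.name = name
        self.run = run


class ToyCompilerBase:
    """Hypothetical base class: subclasses describe their pipelines in
    define_pipelines(); compile() tries them in order (an assumed fallback rule)."""

    def define_pipelines(self) -> List[ToyPipeline]:
        raise NotImplementedError

    def compile(self, source: str) -> str:
        last_error = None
        for pipeline in self.define_pipelines():
            try:
                return pipeline.run(source)
            except Exception as exc:  # remember the failure, try the next pipeline
                last_error = exc
        raise RuntimeError("all pipelines failed") from last_error


def nopython(source: str) -> str:
    # Pretend nopython mode rejects anything that uses an unsupported feature.
    if "unsupported" in source:
        raise TypeError("cannot compile in nopython mode")
    return "nopython:" + source


def objectmode(source: str) -> str:
    return "objectmode:" + source


class ToyCompiler(ToyCompilerBase):
    def define_pipelines(self) -> List[ToyPipeline]:
        return [ToyPipeline("nopython", nopython),
                ToyPipeline("objectmode", objectmode)]


print(ToyCompiler().compile("f"))              # nopython:f
print(ToyCompiler().compile("unsupported g"))  # objectmode:unsupported g
```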
To use a custom subclass of CompilerBase, supply it as the pipeline_class keyword argument to the @jit and @generated_jit decorators. By doing so, the effect of the custom pipeline is limited to the function being decorated.
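Because pipeline_class is passed per decorator, each decorated function carries its own compiler class while other functions keep the default. The stdlib-only sketch below models that scoping with a hypothetical toy_jit decorator; toy_jit, DefaultToyCompiler and LoudToyCompiler are illustrative assumptions, not Numba APIs.

```python
import functools


class DefaultToyCompiler:
    """Hypothetical default compiler: leaves the function unchanged."""

    def compile(self, func):
        return func


class LoudToyCompiler(DefaultToyCompiler):
    """Hypothetical custom compiler: wraps the function to tag its result."""

    def compile(self, func):
        @functools.wraps(func)
        def wrapper(*args):
            return ("custom", func(*args))
        return wrapper


def toy_jit(func=None, *, pipeline_class=DefaultToyCompiler):
    """Hypothetical decorator usable bare (@toy_jit) or as @toy_jit(pipeline_class=...)."""
    def decorate(f):
        compiled = pipeline_class().compile(f)
        compiled.pipeline_class = pipeline_class  # record which compiler was used
        return compiled
    return decorate(func) if func is not None else decorate


@toy_jit(pipeline_class=LoudToyCompiler)
def add(a, b):
    return a + b


@toy_jit
def sub(a, b):
    return a - b


print(add(1, 2))  # ('custom', 3): only this function uses the custom compiler
print(sub(5, 3))  # 2: unaffected, still uses the default
```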
Numba makes it possible to implement a new compiler pass through an API similar to that of LLVM. The following demonstrates the basic process involved.
All passes must inherit from numba.compiler_machinery.CompilerPass. Commonly used subclasses are:
numba.compiler_machinery.FunctionPass for describing a pass that operates on a whole function at once and may mutate the IR state.
numba.compiler_machinery.AnalysisPass for describing a pass that performs analysis only.
numba.compiler_machinery.LoweringPass for describing a pass that performs lowering only.
In this example, a new compiler pass will be implemented that rewrites all ir.Const(x) nodes, where x is an instance of a subclass of numbers.Number, such that the value of x is incremented by one. This pass has no use other than to serve as a pedagogical vehicle!
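The intended rewrite can be pinned down on a toy representation before touching Numba's IR. In the sketch below, Const and the list-of-assignments "block" are hypothetical stand-ins for ir.Const and an IR block; only numbers.Number comes from the standard library.

```python
from dataclasses import dataclass
from numbers import Number
from typing import Any, List, Tuple


@dataclass
class Const:
    """Hypothetical stand-in for ir.Const: it just wraps a Python value."""
    value: Any


def add_one_to_numeric_consts(block: List[Tuple[str, Any]]) -> bool:
    """Increment every numeric Const assigned in `block`; return True if anything changed."""
    mutated = False
    for _target, value in block:
        if isinstance(value, Const) and isinstance(value.value, Number):
            value.value += 1
            mutated = True
    return mutated


# Hypothetical block: two numeric constants, a string constant and a non-constant value.
block = [("a", Const(10)), ("b", Const(20.2)), ("s", Const("text")), ("c", "a + b")]
changed = add_one_to_numeric_consts(block)
print(changed, [v.value for _, v in block if isinstance(v, Const)])
# True [11, 21.2, 'text']
```

Because bool subclasses int, isinstance(True, Number) is also true, so a rewrite written this way increments boolean constants as well; an explicit bool check is needed if that is not wanted.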
numba.compiler_machinery.FunctionPass is appropriate for the suggested pass behavior, so it is the base class of the new pass. A run_pass method is also defined to do the work (this method is abstract; all compiler passes must implement it).
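The requirement that every pass implement run_pass follows the usual abstract-method pattern. The sketch below models it with Python's abc module; ToyCompilerPass, ToyFunctionPass and NoOpPass are hypothetical stand-ins and are not meant to reproduce Numba's class definitions.

```python
from abc import ABC, abstractmethod


class ToyCompilerPass(ABC):
    """Hypothetical base class: every concrete pass must implement run_pass."""
    _name = "toy_pass"

    @abstractmethod
    def run_pass(self, state) -> bool:
        """Do the work; return True if the IR was mutated."""


class ToyFunctionPass(ToyCompilerPass):
    """Still abstract: it inherits run_pass without implementing it."""


class NoOpPass(ToyFunctionPass):
    _name = "no_op"

    def run_pass(self, state) -> bool:
        return False  # inspects nothing, mutates nothing


try:
    ToyFunctionPass()  # abstract classes cannot be instantiated
except TypeError as exc:
    print("cannot instantiate:", type(exc).__name__)
print(NoOpPass().run_pass(state=None))  # False
```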
First, the new class:
from numba import njit
from numba import ir
from numba.compiler import CompilerBase, DefaultPassBuilder
from numba.compiler_machinery import FunctionPass, register_pass
from numba.untyped_passes import IRProcessing
from numbers import Number

# Register this pass with the compiler framework, declare that it will not
# mutate the control flow graph and that it is not an analysis_only pass (it
# potentially mutates the IR).
@register_pass(mutates_CFG=False, analysis_only=False)
class ConstsAddOne(FunctionPass):
    _name = "consts_add_one"  # the common name for the pass

    def __init__(self):
        FunctionPass.__init__(self)

    # implement method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        func_ir = state.func_ir  # get the FunctionIR object
        mutated = False  # used to record whether this pass mutates the IR
        # walk the blocks
        for blk in func_ir.blocks.values():
            # find the assignment nodes in the block and walk them
            for assgn in blk.find_insts(ir.Assign):
                # if an assignment value is an ir.Const
                if isinstance(assgn.value, ir.Const):
                    const_val = assgn.value
                    # if the value of the ir.Const is a Number
                    if isinstance(const_val.value, Number):
                        # then add one!
                        const_val.value += 1
                        mutated = True
        return mutated  # return True if the IR was mutated, False if not.
Note also that the class must be registered with Numba's compiler machinery using @register_pass. This is partly to allow the declaration of whether the pass mutates the control flow graph and whether it is an analysis-only pass.
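A class decorator that takes keyword arguments is a natural way to attach such declarations to a pass. The sketch below is a hypothetical, stdlib-only model of the idea; toy_register_pass, PASS_REGISTRY and the _mutates_CFG/_analysis_only attribute names are assumptions for illustration, not Numba's implementation of register_pass.

```python
# Hypothetical global registry: pass name -> pass class.
PASS_REGISTRY = {}


def toy_register_pass(mutates_CFG, analysis_only):
    """Hypothetical decorator factory that records a pass's declared properties."""
    def decorator(cls):
        cls._mutates_CFG = mutates_CFG
        cls._analysis_only = analysis_only
        PASS_REGISTRY[cls._name] = cls  # look the pass up later by its common name
        return cls
    return decorator


@toy_register_pass(mutates_CFG=False, analysis_only=False)
class ToyConstsAddOne:
    _name = "consts_add_one"


info = PASS_REGISTRY["consts_add_one"]
print(info._mutates_CFG, info._analysis_only)  # False False
```

A pass manager could then consult these flags before or after running a pass, without having to inspect the pass's code.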
Next, define a new compiler based on the existing numba.compiler.CompilerBase. The compiler pipeline is built from an existing pipeline, and the new pass declared above is added to run after the IRProcessing pass.
class MyCompiler(CompilerBase):  # custom compiler extends from CompilerBase

    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for ease
        # base it on an existing pipeline from the DefaultPassBuilder,
        # namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # Add the new pass to run after IRProcessing
        pm.add_pass_after(ConstsAddOne, IRProcessing)
        # finalize
        pm.finalize()
        # return as an iterable, any number of pipelines may be defined!
        return [pm]
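The effect of add_pass_after followed by finalize can be pictured as inserting into an ordered sequence of passes and then freezing it. The sketch below is hypothetical: ToyPassManager, its method names (chosen to echo the text) and the abbreviated pass list are assumptions for illustration, and it does not attempt to reproduce Numba's PassManager.

```python
from typing import List


class ToyPassManager:
    """Hypothetical ordered list of pass names with insert-after support."""

    def __init__(self, passes: List[str]):
        self.passes = list(passes)
        self.finalized = False

    def add_pass_after(self, new_pass: str, location: str) -> None:
        if self.finalized:
            raise RuntimeError("pipeline already finalized")
        idx = self.passes.index(location)  # raises ValueError if `location` is absent
        self.passes.insert(idx + 1, new_pass)

    def finalize(self) -> None:
        # Minimal sanity check standing in for real validation: no duplicate passes.
        if len(set(self.passes)) != len(self.passes):
            raise ValueError("duplicate pass in pipeline")
        self.finalized = True


# Hypothetical, heavily abbreviated pass order; not the real nopython pipeline.
pm = ToyPassManager(["translate_bytecode", "ir_processing", "nopython_type_inference"])
pm.add_pass_after("consts_add_one", "ir_processing")
pm.finalize()
print(pm.passes)
# ['translate_bytecode', 'ir_processing', 'consts_add_one', 'nopython_type_inference']
```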
Finally, update the @njit decorator on the function definition to make use of the newly defined compilation pipeline.
@njit(pipeline_class=MyCompiler)  # JIT compile using the custom compiler
def foo(x):
    a = 10
    b = 20.2
    c = x + a + b
    return c

print(foo(100))  # 100 + 10 + 20.2 (+ 1 + 1), extra + 1 + 1 from the rewrite!
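The printed value can be checked by hand. The plain-Python function below (foo_reference is a hypothetical name, not part of the example) mirrors foo, with both constants optionally bumped by one in the way the ConstsAddOne pass rewrites them.

```python
def foo_reference(x, bump=0):
    """Plain-Python mirror of foo; bump=1 mimics the ConstsAddOne rewrite."""
    a = 10 + bump
    b = 20.2 + bump
    return x + a + b


print(foo_reference(100))     # 130.2, without the rewrite
print(foo_reference(100, 1))  # 132.2, what foo(100) should print with MyCompiler
```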
It is often useful to see the changes a pass makes to the IR. Numba permits this through the environment variable NUMBA_DEBUG_PRINT_AFTER, which takes a comma-separated list of pass names. In the case of the above pass, running the example code with NUMBA_DEBUG_PRINT_AFTER="ir_processing,consts_add_one" gives:
----------------------------nopython: ir_processing-----------------------------
label 0:
x = arg(0, name=x) ['x']
$const0.1 = const(int, 10) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
del $const0.1 []
$const0.2 = const(float, 20.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']
del $const0.2 []
$0.5 = x + a ['$0.5', 'a', 'x']
del x []
del a []
$0.7 = $0.5 + b ['$0.5', '$0.7', 'b']
del b []
del $0.5 []
c = $0.7 ['$0.7', 'c']
del $0.7 []
$0.9 = cast(value=c) ['$0.9', 'c']
del c []
return $0.9 ['$0.9']
----------------------------nopython: consts_add_one----------------------------
label 0:
x = arg(0, name=x) ['x']
$const0.1 = const(int, 11) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
del $const0.1 []
$const0.2 = const(float, 21.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']
del $const0.2 []
$0.5 = x + a ['$0.5', 'a', 'x']
del x []
del a []
$0.7 = $0.5 + b ['$0.5', '$0.7', 'b']
del b []
del $0.5 []
c = $0.7 ['$0.7', 'c']
del $0.7 []
$0.9 = cast(value=c) ['$0.9', 'c']
del c []
return $0.9 ['$0.9']
Note the change in the values in the const nodes.
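For larger functions the changes are easier to spot by diffing the two dumps. A minimal sketch using the standard library's difflib, applied to an abbreviated copy of the dumps above pasted in as strings (in practice the full output would first be captured, for example to a file):

```python
import difflib

# Abbreviated copies of the two dumps shown above.
before = """\
$const0.1 = const(int, 10) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
$const0.2 = const(float, 20.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']"""

after = """\
$const0.1 = const(int, 11) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
$const0.2 = const(float, 21.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']"""

# Keep only removed/added lines, skipping the '---'/'+++' file headers.
changes = [
    line
    for line in difflib.unified_diff(before.splitlines(), after.splitlines(), lineterm="")
    if line.startswith(("-", "+")) and not line.startswith(("---", "+++"))
]
print("\n".join(changes))
```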
Numba has built-in support for timing all compiler passes; the execution times are stored in the metadata associated with a compilation result. The following demonstrates one way of accessing this information for the previously defined function, foo:
compile_result = foo.overloads[foo.signatures[0]]
nopython_times = compile_result.metadata['pipeline_times']['nopython']
for k in nopython_times.keys():
    if ConstsAddOne._name in k:
        print(nopython_times[k])
The output of this is, for example:
pass_timings(init=1.914000677061267e-06, run=4.308700044930447e-05, finalize=1.7400006981915794e-06)
This shows the pass initialization, run and finalization times in seconds.
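Since each entry has init, run and finalize fields, per-pass totals and the slowest pass are straightforward to derive. The sketch below builds a small stand-in for the pipeline_times['nopython'] mapping so that it runs without Numba. The pass_timings namedtuple is modelled on the printed output above and assumed to match the real entries' fields; the dictionary keys and every value other than those shown for consts_add_one are invented for illustration.

```python
from collections import namedtuple

# Modelled on the printed repr above (assumed to match the real entries' fields).
pass_timings = namedtuple("pass_timings", "init run finalize")

# Hypothetical stand-in for compile_result.metadata['pipeline_times']['nopython'];
# the keys and the values for passes other than consts_add_one are invented.
nopython_times = {
    "ir_processing": pass_timings(2.0e-06, 3.0e-05, 1.5e-06),
    "consts_add_one": pass_timings(1.914000677061267e-06, 4.308700044930447e-05,
                                   1.7400006981915794e-06),
    "nopython_type_inference": pass_timings(3.0e-06, 9.0e-04, 2.0e-06),
}

# Total time per pass is the sum of its init, run and finalize phases.
totals = {name: sum(t) for name, t in nopython_times.items()}
slowest = max(totals, key=totals.get)
print(f"slowest pass: {slowest} ({totals[slowest]:.2e} s)")
print(f"consts_add_one total: {totals['consts_add_one']:.2e} s")
```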