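We will extend the Numba frontend to support a class that it does not currently support, so as to allow:
- passing an instance of the class to a Numba function;
- accessing attributes of the class in a Numba function;
- constructing and returning a new instance of the class from a Numba function
(all the above in nopython mode).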
We will mix APIs from the high-level extension API and the low-level extension API, depending on what is available for a given task.
The starting point for our example is the following pure Python class:
class Interval(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        return self.hi - self.lo
As the Interval class is not known to Numba, we must create a new Numba type to represent instances of it. Numba does not deal with Python types directly: it has its own type system that allows a different level of granularity as well as various meta-information not available with regular Python types.
We first create a type class IntervalType and, since we don’t need the type to be parametric, we instantiate a single type instance interval_type:
from numba import types

class IntervalType(types.Type):
    def __init__(self):
        super(IntervalType, self).__init__(name='Interval')

interval_type = IntervalType()
In itself, creating a Numba type doesn’t do anything. We must teach Numba how to infer some Python values as instances of that type. In this example, it is trivial: any instance of the Interval class should be treated as belonging to the type interval_type:
from numba.extending import typeof_impl

@typeof_impl.register(Interval)
def typeof_interval(val, c):
    # Every Interval instance maps to the single interval_type instance.
    return interval_type
Function arguments and global values will thus be recognized as belonging to interval_type whenever they are instances of Interval.
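Registering a handler for a class so that every instance of it maps to one type object follows the same pattern as the standard library's functools.singledispatch. The toy registry below is a hypothetical, pure-Python analogue for intuition only: toy_typeof and the stand-in classes are illustrative names, not Numba APIs.

```python
from functools import singledispatch


class Interval(object):
    """Minimal stand-in for the Interval class above."""

    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi


class IntervalType(object):
    """Hypothetical stand-in for the Numba type class (not numba.types.Type)."""

    def __repr__(self):
        return 'Interval'


interval_type = IntervalType()  # singleton: the type is not parametric


@singledispatch
def toy_typeof(val):
    # Fallback for classes nobody registered: no type can be inferred.
    return None


@toy_typeof.register(Interval)
def _typeof_interval(val):
    # Every Interval instance (and subclass instance) maps to the same type.
    return interval_type


print(toy_typeof(Interval(1.0, 2.0)))  # Interval
print(toy_typeof("not an interval"))   # None
```

Because dispatch follows the class hierarchy, instances of a subclass of Interval would also be mapped to interval_type in this sketch.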
We want to be able to construct interval objects from Numba functions, so
we must teach Numba to recognize the two-argument Interval(lo, hi)
constructor. The arguments should be floating-point numbers:
from numba.extending import type_callable

@type_callable(Interval)
def type_interval(context):
    def typer(lo, hi):
        if isinstance(lo, types.Float) and isinstance(hi, types.Float):
            return interval_type
    return typer
The type_callable() decorator specifies that the decorated function should be invoked when running type inference for the given callable object (here the Interval class itself). The decorated function must simply return a typer function that will be called with the argument types. If the argument types don't match, the typer returns None (implicitly, in the example above), which tells Numba that the call cannot be typed with those arguments. This seemingly convoluted setup lets the typer function have exactly the same signature as the typed callable, which allows keyword arguments to be handled correctly.
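Why a matching signature helps with keyword arguments can be sketched in plain Python: if the typer declares the same parameters as the callable, the call's arguments can be bound to the right parameters with inspect.signature. This is a simplified illustration under our own assumptions (types are plain strings, and resolve is a hypothetical helper), not Numba's actual resolution code.

```python
import inspect


def typer(lo, hi):
    # Hypothetical typer: it receives argument *types* (plain strings here)
    # and returns the result type, or None when it cannot type the call.
    if lo == 'float64' and hi == 'float64':
        return 'interval_type'
    return None


def resolve(typer, *args, **kwargs):
    # Hypothetical helper: bind the call's positional and keyword argument
    # types to the typer's own parameters, then call it in declaration order.
    bound = inspect.signature(typer).bind(*args, **kwargs)
    return typer(*bound.args, **bound.kwargs)


print(resolve(typer, 'float64', 'float64'))        # interval_type
print(resolve(typer, hi='float64', lo='float64'))  # interval_type
print(resolve(typer, hi='int64', lo='float64'))    # None

try:
    resolve(typer, lo='float64', width='float64')  # no such parameter
except TypeError as exc:
    print('rejected:', exc)
```

Keyword order at the call site does not matter: binding puts each argument type under the parameter with the same name.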
The context argument received by the decorated function is useful in more sophisticated cases where computing the callable’s return type requires resolving other types.
We have finished teaching Numba about our type inference additions. We must now teach Numba how to actually generate code and data for the new operations.
As a general rule, nopython mode does not work on Python objects as they are generated by the CPython interpreter. The representations used by the interpreter are far too inefficient for fast native code. Each type supported in nopython mode therefore has to define a tailored native representation, also called a data model.
A common kind of data model is an immutable struct-like data model, akin to a C struct. Our interval datatype conveniently falls into that category, and here is a possible data model for it:
from numba.extending import models, register_model

@register_model(IntervalType)
class IntervalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('lo', types.float64),
            ('hi', types.float64),
        ]
        models.StructModel.__init__(self, dmm, fe_type, members)
This instructs Numba that values of type IntervalType (or any instance thereof) are represented as a structure of two fields, lo and hi, each of them a double-precision floating-point number (types.float64).
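For intuition, the layout described by IntervalModel corresponds to a C struct with two double fields. The ctypes sketch below is only an analogy (Numba builds an LLVM structure type, not a ctypes object, and IntervalStruct is a hypothetical name), but it shows the same two-field layout, 16 bytes on common platforms.

```python
import ctypes


class IntervalStruct(ctypes.Structure):
    # Two double-precision fields, mirroring the ('lo', float64) and
    # ('hi', float64) members declared in IntervalModel.
    _fields_ = [('lo', ctypes.c_double),
                ('hi', ctypes.c_double)]


s = IntervalStruct(1.0, 3.5)
print(s.lo, s.hi)                     # 1.0 3.5
print(ctypes.sizeof(IntervalStruct))  # 16 on platforms with 8-byte doubles
print(IntervalStruct.hi.offset)       # 8: 'hi' follows 'lo' directly
```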
Note
Mutable types need more sophisticated data models to be able to persist their values after modification. They typically cannot be stored and passed on the stack or in registers as immutable types can.
We want the data model attributes lo and hi to be exposed under the same names for use in Numba functions. Numba provides a convenience function to do exactly that:
from numba.extending import make_attribute_wrapper
make_attribute_wrapper(IntervalType, 'lo', 'lo')
make_attribute_wrapper(IntervalType, 'hi', 'hi')
This will expose the attributes in read-only mode. As mentioned above, writable attributes don’t fit in this model.
As the width property is computed rather than stored in the structure, we cannot simply expose it like we did for lo and hi. We have to re-implement it explicitly:
from numba.extending import overload_attribute

@overload_attribute(IntervalType, "width")
def get_width(interval):
    def getter(interval):
        return interval.hi - interval.lo
    return getter
You might ask why we didn’t need to expose a type inference hook for this attribute. The answer is that @overload_attribute is part of the high-level API: it combines type inference and code generation in a single API.
Now we want to implement the two-argument Interval constructor:
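from numba.extending import lower_builtin
from numba.core import cgutils  # older Numba releases: from numba import cgutils

@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    typ = sig.return_type
    lo, hi = args
    lo_ty, hi_ty = sig.args
    # Build the interval structure value and fill in its fields. The
    # arguments may be any Float type (e.g. float32), so convert them to
    # the float64 members declared in IntervalModel.
    interval = cgutils.create_struct_proxy(typ)(context, builder)
    interval.lo = context.cast(builder, lo, lo_ty, types.float64)
    interval.hi = context.cast(builder, hi, hi_ty, types.float64)
    return interval._getvalue()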
There is a bit more going on here. @lower_builtin decorates the implementation of the given callable or operation (here the Interval constructor) for some specific argument types. This allows defining type-specific implementations of a given operation, which is important for heavily overloaded functions such as len().
types.Float is the class of all floating-point types (types.float64 is an instance of types.Float). It is generally more future-proof to match argument types on their class rather than on specific instances (however, when returning a type, chiefly during the type inference phase, you must usually return a type instance).
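The benefit of matching on the class can be shown with a toy registry: an implementation registered for the Float class accepts float32 and float64 alike, whereas one keyed on float64 alone would miss float32. Every name below (Type, Float, Integer, registry, lookup) is a hypothetical stand-in; this is not how Numba's lowering registry is actually implemented.

```python
class Type(object):
    """Hypothetical base class of a toy type system (not numba.types.Type)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class Float(Type):
    """Class of all toy floating-point types, playing the role of types.Float."""


class Integer(Type):
    """Class of all toy integer types."""


float32 = Float('float32')
float64 = Float('float64')
int64 = Integer('int64')

# Hypothetical registry: (callable name, argument type classes) -> implementation.
registry = {('Interval', (Float, Float)): 'impl_interval'}


def lookup(fn_name, arg_types):
    # Match each argument *instance* against the registered type *classes*
    # with isinstance(), so every Float instance qualifies.
    for (name, classes), impl in registry.items():
        if (name == fn_name and len(classes) == len(arg_types)
                and all(isinstance(t, cls) for t, cls in zip(arg_types, classes))):
            return impl
    return None


print(lookup('Interval', (float64, float64)))  # impl_interval
print(lookup('Interval', (float32, float64)))  # impl_interval
print(lookup('Interval', (int64, float64)))    # None
```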
cgutils.create_struct_proxy() and interval._getvalue() are a bit of boilerplate due to how Numba passes values around. Values are passed as instances of llvmlite.ir.Value, which can be too limited: LLVM structure values especially are quite low-level. A struct proxy is a temporary wrapper around an LLVM structure value that makes it easy to get or set members of the structure. The _getvalue() call simply gets the LLVM value out of the wrapper.
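The struct proxy idea can be mimicked in plain Python. In the hypothetical StructProxy below, a tuple stands in for the immutable structure value: the wrapper exposes named members for reading and writing, setting a member produces a new tuple, and _getvalue() hands back the wrapped value. The optional value argument plays the role of the value=val keyword used when wrapping an existing structure. This is an analogy only; Numba's real proxy emits LLVM instructions rather than manipulating tuples.

```python
class StructProxy(object):
    """
    Hypothetical pure-Python analogue of a struct proxy. A tuple stands in
    for the immutable structure value; setting a member builds a new tuple.
    """

    def __init__(self, fields, value=None):
        fields = tuple(fields)
        if value is None:
            value = (None,) * len(fields)  # members not yet set
        object.__setattr__(self, '_fields', fields)
        object.__setattr__(self, '_value', tuple(value))

    def __getattr__(self, name):
        # Only reached for names not found normally, i.e. struct members.
        fields = object.__getattribute__(self, '_fields')
        if name not in fields:
            raise AttributeError(name)
        return object.__getattribute__(self, '_value')[fields.index(name)]

    def __setattr__(self, name, new):
        if name not in self._fields:
            raise AttributeError(name)
        value = list(self._value)
        value[self._fields.index(name)] = new
        object.__setattr__(self, '_value', tuple(value))

    def _getvalue(self):
        # Hand back the wrapped value, as the real proxy hands back the LLVM value.
        return self._value


interval = StructProxy(['lo', 'hi'])
interval.lo = 1.0
interval.hi = 3.5
print(interval._getvalue())  # (1.0, 3.5)

existing = StructProxy(['lo', 'hi'], value=(2.0, 5.0))
print(existing.hi)           # 5.0
```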
If you try to use an Interval instance at this point, you’ll certainly get the error “cannot convert Interval to native value”. This is because Numba doesn’t yet know how to make a native interval value from a Python Interval instance. Let’s teach it how to do it:
from numba.extending import unbox, NativeValue

@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert an Interval object to a native interval structure.
    """
    lo_obj = c.pyapi.object_getattr_string(obj, "lo")
    hi_obj = c.pyapi.object_getattr_string(obj, "hi")
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    interval.lo = c.pyapi.float_as_double(lo_obj)
    interval.hi = c.pyapi.float_as_double(hi_obj)
    # Release the new references returned by the attribute lookups.
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    # Flag an error if any of the C API calls above set a Python exception.
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(interval._getvalue(), is_error=is_error)
Unbox is the other name for “convert a Python object to a native value” (it fits the idea of a Python object as a sophisticated box containing a simple native value). The function returns a NativeValue object which gives its caller access to the computed native value, the error bit and possibly other information.
The snippet above makes abundant use of the c.pyapi object, which gives access to a subset of the Python interpreter’s C API. Note the use of c.pyapi.err_occurred() to detect any errors that may have happened when unboxing the object (try passing Interval('a', 'b') for example).
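The unboxing steps, including the error bit, can be mirrored in plain Python for intuition. In the sketch below, unbox_like is a hypothetical name, and ctypes.c_double stands in for float_as_double: as far as we know, both rely on CPython's PyFloat_AsDouble, so a str attribute is rejected with TypeError instead of being parsed the way float('1.5') would be. Rather than checking the interpreter's error indicator, the sketch catches the exception and returns an error flag alongside the value.

```python
import ctypes


class Interval(object):
    """Minimal stand-in for the Interval class above."""

    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi


def unbox_like(obj):
    """
    Hypothetical mirror of unbox_interval: read the lo/hi attributes,
    convert each to a C double, and return (native_value, is_error).
    """
    try:
        # Assumption: ctypes.c_double() performs the same C-level float
        # conversion as float_as_double, rejecting a str with TypeError.
        lo = ctypes.c_double(obj.lo).value
        hi = ctypes.c_double(obj.hi).value
    except TypeError:
        return None, True
    return (lo, hi), False


print(unbox_like(Interval(1.0, 2)))      # ((1.0, 2.0), False)
print(unbox_like(Interval('a', 'b')))    # (None, True)
```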
We also want to do the reverse operation, called boxing, so as to return interval values from Numba functions:
from numba.extending import box

@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval structure to an Interval object.
    """
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
    lo_obj = c.pyapi.float_from_double(interval.lo)
    hi_obj = c.pyapi.float_from_double(interval.hi)
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
    res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    c.pyapi.decref(class_obj)
    return res
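Boxing runs the same steps in reverse. The hypothetical box_like below mirrors box_interval in plain Python, with a ctypes structure standing in for the native interval value: each field becomes a Python float, and the class is called with them to build a new object. The explicit decref calls have no counterpart here because Python manages reference counts automatically at this level.

```python
import ctypes


class Interval(object):
    """Minimal stand-in for the Interval class above."""

    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi


class IntervalStruct(ctypes.Structure):
    # Stand-in for the native {double, double} interval value.
    _fields_ = [('lo', ctypes.c_double), ('hi', ctypes.c_double)]


def box_like(native, cls=Interval):
    """
    Hypothetical mirror of box_interval: turn each native field into a
    Python float, then call the class with them to build a new object.
    """
    lo_obj = float(native.lo)
    hi_obj = float(native.hi)
    return cls(lo_obj, hi_obj)


iv = box_like(IntervalStruct(1.0, 3.5))
print(type(iv).__name__, iv.lo, iv.hi)  # Interval 1.0 3.5
```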
Functions compiled in nopython mode can now use Interval objects and the various operations you have defined on them. For example, you can try the following functions:
from numba import jit

@jit(nopython=True)
def inside_interval(interval, x):
    return interval.lo <= x < interval.hi

@jit(nopython=True)
def interval_width(interval):
    return interval.width

@jit(nopython=True)
def sum_intervals(i, j):
    return Interval(i.lo + j.lo, i.hi + j.hi)
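Because these function bodies are ordinary Python, one way to sanity-check the compiled versions is to run the same bodies undecorated against the pure Python class and compare results. The sketch below does that with arbitrary sample values; with the extension above in place, the @jit versions would be expected to agree, though this sketch does not itself exercise Numba.

```python
class Interval(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        return self.hi - self.lo


# Same bodies as above, without @jit, giving reference results.
def inside_interval(interval, x):
    return interval.lo <= x < interval.hi


def interval_width(interval):
    return interval.width


def sum_intervals(i, j):
    return Interval(i.lo + j.lo, i.hi + j.hi)


a = Interval(1.0, 3.0)
b = Interval(0.5, 2.5)
print(inside_interval(a, 1.0))  # True  (lo is included)
print(inside_interval(a, 3.0))  # False (hi is excluded: half-open)
print(interval_width(a))        # 2.0
print(sum_intervals(a, b))      # Interval(1.500000, 5.500000)
```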
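We have shown how to do the following tasks:
- define a new Numba type by subclassing the Type class, with a singleton instance for a non-parametric type;
- teach Numba how to infer the Numba type of Python values of a certain class, using typeof_impl.register;
- define the data model for a Numba type using StructModel and register_model;
- implement a boxing function for a Numba type with the @box decorator;
- implement an unboxing function for a Numba type with the @unbox decorator and the NativeValue class;
- type and implement a callable using the @type_callable and @lower_builtin decorators;
- expose a struct attribute with the make_attribute_wrapper convenience function;
- implement a read-only property using the @overload_attribute decorator.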