Adding Type Inference for External Code

Users can add type inference for non-numba functions, to specify to numba what the result of an operation will look like. This can make numba code much more efficient by knowing the lower-level types.

Type Functions

Users can write type functions to tell Numba what the return type of a certain function call in numba code will be. The return type may be always the same, or it may depend on the input types in a simple or complicated way.

Below we will describe ways to handle all these cases.

Simple Signatures

Simple signatures can be easily handled by registering a callable with a static signature:

from numba import register_callable
from numba import *

@register_callable(double(double))
def square(arg):
    return arg * 2

This specifies that the function takes only a double and returns a double.

If a function is statically polymorphic, a typeset can be used to represent a variety of inputs:

import numba as nb
from numba import *

@nb.register_callable(nb.numeric(nb.numeric))
def square(...):
    ...

Here numba.numeric represents the set of numeric types. The set is then bound in a similar way to Cython’s fused types, in that the argument types instantiate the sets to concrete types. This means the signature above allows signatures like double(double) and int_(int_), but not for instance double(int).

Type sets can be created explicitly:

signatures = numba.typeset(object_(object_), numeric(numeric))

@numba.register_callable(signatures)
def square(...):
    ...

Parametric Polymorphism

Type sets are fine for small sets, but not general enough for a large or infinite type universe. Instead, we allow parametric polymorphism through simple templates and type functions (functions that return the type at compile time given the input types). For an introduction to templates we refer the reader to templates_.

Templates can be used as follows:

# create a type variable that will be bound by an argument type
a = numba.template("a")

@numba.register_callable(a(a[:]))
def sum_1d(...):
    ...

More general type expressions are listed under templates_.

Note

Templates for user-based type inference is not yet implemented.

For full control a user can write an explicit function that performs the type inference based on the input types. The input types are passed in as objects and inferred from the function’s signature:

from numba import typesystem
from numba.typesystem import get_type

def infer_reduce(a, axis, dtype, out):
    if out is not None:
        return out

    dtype_type = get_dtype(a, dtype, static_dtype)

    if axis is None:
        # Return the scalar type
        return dtype_type

    if dtype_type:
        # Handle the axis parameter
        if axis.is_tuple and axis.is_sized:
            # axis=(tuple with a constant size)
            return typesystem.array(dtype_type, a.ndim - axis.size)
        elif axis.is_int:
            # axis=1
            return typesystem.array(dtype_type, a.ndim - 1)
        else:
            # axis=(something unknown)
            return object_

register_inferer(np, 'sum', infer_reduce)    # Register type inference for np.sum
register_inferer(np, 'prod', infer_reduce)   # Register type inference for np.prod

The above works through introspection of the function using the inspect module. A call in the user code to np.sum or np.prod will now ask the above function to resolve its type at Numba compile time, passing in the types representing the arguments, or None when absent:

# Numba code
np.sum(a)              # => infer_reduce(a=double[:, :]), axis=None, dtype=None, out=None)
np.sum(a, axis=(1, 2)) # => infer_reduce(a=double[:, :, :]), axis=tuple(base_type=int_, size=2),
                       #                 dtype=None, out=None)

A shorthand function to register type functions is provided by numba.register:

@numba.register(np)
def sum(a, axis, dtype, out):
    ...

This retrieves np.sum based on the name of the type inferring function (hence it must be called sum).

Registering Unbound Methods

Unbound methods are transient, and hence can not be registered by value. Instead we register a dotted path starting at a value, e.g.:

numba.register_unbound(np, "add", "reduce", infer_reduce)

To allow type inference for np.add.reduce(). The first string specifies the module (np), the second the object ("add"), the third the dotted path ("reduce") and the last the type function (infer_reduce).

Future Directions

The code above is clearly very verbose, which is partly due to the generality of the sum signature. In the future we hope to expose a more declarative way to specify parametrically polymorphic signatures. Perhaps something like:

sum(a)                                              => a
sum(array(dtype, ndim), axis=integral)              => array(dtype, ndim - 1)
sum(array(dtype, ndim), axis=tuple(integral, size)) => array(dtype, ndim - size)
sum(in, axis=axis, out=out)                         => sum(out, axis=axis)
sum(_, _)                                           => object_