@overload
As mentioned in the high-level extension API, you can use the @overload decorator to create a Numba implementation of a function that can be used in nopython mode functions. A common use case is to re-implement NumPy functions so that they can be called in @jit decorated code. This section discusses how and when to use the @overload decorator and what contributing such a function to the Numba code base might entail. This should help you get started when you need to use the @overload decorator or when you attempt to contribute new functions to Numba itself.
The @overload decorator and its variants are useful when you have a third-party library that you do not control and you wish to provide Numba-compatible implementations for specific functions from that library.
Let’s assume that you are working on a minimization algorithm that makes use of scipy.linalg.norm to find different vector norms and the Frobenius norm for matrices.
You know that only integer and real numbers will be involved. (While this may sound like an artificial example, especially because a Numba implementation of numpy.linalg.norm exists, it is largely pedagogical and serves to illustrate how and when to use @overload.)
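For reference, these are the quantities involved, following NumPy's conventions for the `ord` argument of `numpy.linalg.norm`: `ord=None` is treated as the 2-norm for vectors, and `ord=0` counts the non-zero entries, so strictly speaking it is not a norm.

```latex
\begin{aligned}
\|v\|_p &= \Big(\sum_i |v_i|^p\Big)^{1/p} && \texttt{ord}=p \ (\texttt{None} \text{ means } p=2) \\
\|v\|_{\infty} &= \max_i |v_i| && \texttt{ord}=\infty \\
\|v\|_{-\infty} &= \min_i |v_i| && \texttt{ord}=-\infty \\
\|v\|_{0} &= \#\{\, i : v_i \neq 0 \,\} && \texttt{ord}=0 \\
\|A\|_F &= \Big(\sum_{i,j} |a_{ij}|^2\Big)^{1/2} && \text{Frobenius norm of a matrix}
\end{aligned}
```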
The skeleton might look something like this:
import scipy.linalg
def algorithm():
    # setup
    v = ...
    while True:
        # take a step
        d = scipy.linalg.norm(v)
        if d < tolerance:
            break
Now, let’s further assume that you have heard of Numba and you now wish to use it to accelerate your function. However, after adding the jit(nopython=True) decorator, Numba complains that scipy.linalg.norm isn’t supported. From looking at the documentation, you realize that a norm is probably fairly easy to implement using NumPy. A good starting point is the following template.
from numba.extending import overload
def myfunc(arg0, arg1, arg2):
    ...  # placeholder for the existing function you want to overload
# Declare that function `myfunc` is going to be overloaded (have a
# substitutable Numba implementation)
@overload(myfunc)
# Define the overload function with formal arguments
# these arguments must be matched in the inner function implementation
def jit_myfunc(arg0, arg1, arg2):  # use as many arguments as myfunc takes
    # This scope is for typing, access is available to the *type* of all
    # arguments. This information can be used to change the behaviour of the
    # implementing function and check that the types are actually supported
    # by the implementation.
    print(arg0)  # this will show the Numba type of arg0
    # This is the definition of the function that implements the `myfunc` work.
    # It does whatever algorithm is needed to implement myfunc.
    def myfunc_impl(arg0, arg1, arg2):  # match arguments to jit_myfunc
        # < Implementation goes here >
        return  # whatever needs to be returned by the algorithm
    # return the implementation
    return myfunc_impl
After some deliberation and tinkering, you end up with the following code:
import numpy as np
from numba import njit, types
from numba.extending import overload, register_jitable
from numba.core.errors import TypingError
import scipy.linalg

@register_jitable
def _oneD_norm_2(a):
    # re-usable implementation of the 2-norm
    val = np.abs(a)
    return np.sqrt(np.sum(val * val))

@overload(scipy.linalg.norm)
def jit_norm(a, ord=None):
    if isinstance(ord, types.Optional):
        ord = ord.type
    # Reject non integer, floating-point or None types for ord
    if not isinstance(ord, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("'ord' must be either integer or floating-point")
    # Reject non-ndarray types
    if not isinstance(a, types.Array):
        raise TypingError("Only accepts NumPy ndarray")
    # Reject ndarrays with non integer or floating-point dtype
    if not isinstance(a.dtype, (types.Integer, types.Float)):
        raise TypingError("Only integer and floating point types accepted")
    # Reject ndarrays with unsupported dimensionality
    if not (0 <= a.ndim <= 2):
        raise TypingError('3D and beyond are not allowed')
    # Implementation for scalars/0d-arrays
    elif a.ndim == 0:
        # `a` is a Numba type in this scope, so an implementation function must
        # be returned; the norm of a scalar is its absolute value
        def _zeroD_norm(a, ord=None):
            return abs(a.item())
        return _zeroD_norm
    # Implementation for vectors
    elif a.ndim == 1:
        def _oneD_norm_x(a, ord=None):
            if ord == 2 or ord is None:
                return _oneD_norm_2(a)
            elif ord == np.inf:
                return np.max(np.abs(a))
            elif ord == -np.inf:
                return np.min(np.abs(a))
            elif ord == 0:
                return np.sum(a != 0)
            elif ord == 1:
                return np.sum(np.abs(a))
            else:
                return np.sum(np.abs(a)**ord)**(1. / ord)
        return _oneD_norm_x
    # Implementation for matrices
    elif a.ndim == 2:
        def _two_D_norm_2(a, ord=None):
            return _oneD_norm_2(a.ravel())
        return _two_D_norm_2

if __name__ == "__main__":
    @njit
    def use(a, ord=None):
        # simple test function to check that the overload works
        return scipy.linalg.norm(a, ord)

    # spot check for vectors
    a = np.arange(10)
    print(use(a))
    print(scipy.linalg.norm(a))

    # spot check for matrices
    b = np.arange(9).reshape((3, 3))
    print(use(b))
    print(scipy.linalg.norm(b))
As you can see, the implementation only supports what you need right now: integer and floating-point inputs, all of the vector norms, only the Frobenius norm for matrices, and code sharing between the vector and matrix implementations through @register_jitable.
So what actually happens here? The @overload decorator registers a suitable implementation for scipy.linalg.norm in case a call to it is encountered in code that is being JIT-compiled, for example when you decorate your algorithm function with @jit(nopython=True). In that case, the function jit_norm will be called with the types (not the values) of the arguments encountered at that call site and will then return the matching implementation, for example _oneD_norm_x in the vector case or _two_D_norm_2 in the matrix case.
You can download the example code here: mynorm.py
@overload for NumPy functions
Numba supports NumPy through the provision of @jit compatible re-implementations of NumPy functions. In such cases @overload is a very convenient option for writing such implementations; however, there are a few additional things to watch out for.
For example, np.corrcoef may return an array or a scalar depending on its inputs, so a Numba implementation has to select the appropriate return type from the input types at compile time.
If you are implementing a new function, you should always update the documentation. The sources can be found in docs/source/reference/numpysupported.rst. Be sure to mention any limitations that your implementation has, e.g. no support for the axis keyword.
When writing tests for the functionality itself, it’s useful to include handling of non-finite values, arrays with different shapes and layouts, complex inputs, scalar inputs, inputs with types for which support is not documented (e.g. a function which the NumPy docs say requires a float or int input might also ‘work’ if given a bool or complex input).
When writing tests for exceptions, for example if adding tests to numba/tests/test_np_functions.py, you may encounter the following error message:
======================================================================
FAIL: test_foo (numba.tests.test_np_functions.TestNPFunctions)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<path>/numba/numba/tests/support.py", line 645, in tearDown
    self.memory_leak_teardown()
  File "<path>/numba/numba/tests/support.py", line 619, in memory_leak_teardown
    self.assert_no_memory_leak()
  File "<path>/numba/numba/tests/support.py", line 628, in assert_no_memory_leak
    self.assertEqual(total_alloc, total_free)
AssertionError: 36 != 35
This occurs because raising exceptions from jitted code leads to reference leaks. Ideally, you will place all exception testing in a separate test method and then add a call to self.disable_leak_check() in each such test to disable the leak check (inherit from numba.tests.support.TestCase to make that available).
For many of the functions that are available in NumPy, there are corresponding methods defined on the NumPy ndarray type. For example, the function repeat is available both as a NumPy module-level function and as a member function on the ndarray class.
import numpy as np
a = np.arange(10)
# function
np.repeat(a, 10)
# method
a.repeat(10)
Once you have written the function implementation, you can easily use @overload_method and reuse it. Just be sure to check that NumPy doesn’t diverge in the implementations of its function and method. As an example, the repeat function/method:
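import numpy as np
from numba import extending, types
@extending.overload_method(types.Array, 'repeat')
def array_repeat(a, repeats):
    # the implementation's arguments must match array_repeat's arguments
    def array_repeat_impl(a, repeats):
        # np.repeat has already been overloaded
        return np.repeat(a, repeats)
    return array_repeat_impl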
If you need to create ancillary functions, for example to re-use a small utility function or to split your implementation across functions for the sake of readability, you can make use of the @register_jitable decorator. This will make those functions available from within your @jit and @overload decorated functions.
The Numba continuous integration (CI) setup tests a wide variety of NumPy versions, so you’ll sometimes be alerted to a change in behaviour from some previous NumPy version. If you can find supporting evidence in the NumPy change log or repository, then you’ll need to decide whether to create branches and attempt to replicate the logic across versions, or to use a version gate (with associated wording in the documentation) to advertise that Numba replicates NumPy from some particular version onwards.
You can look at the Numba source code for inspiration; many of the overloaded NumPy functions and methods are in numba/targets/arrayobj.py. Below, you will find a list of implementations to look at that are well implemented in terms of accepted types and test coverage.
np.repeat