Examples

A Simple Function

Suppose we want to write an image-processing function in Python. Here’s how it might look.

import numpy

def filter2d(image, filt):
    """Convolve the 2-D ``image`` with the 2-D kernel ``filt`` using plain loops.

    The kernel is read back-to-front along both axes, i.e. this is a true
    convolution rather than a correlation. Only pixels where the whole
    kernel fits inside the image are computed. A border half the kernel size
    wide stays zero. The result has the same shape and dtype as ``image``,
    so the accumulated sums are stored in ``image``'s dtype.
    """
    n_rows, n_cols = image.shape
    k_rows, k_cols = filt.shape
    half_r = k_rows // 2
    half_c = k_cols // 2
    out = numpy.zeros_like(image)
    for row in range(half_r, n_rows - half_r):
        top = row - half_r
        for col in range(half_c, n_cols - half_c):
            left = col - half_c
            acc = 0.0
            for dr in range(k_rows):
                for dc in range(k_cols):
                    # Indexing the kernel from its far corner flips it.
                    acc += (filt[k_rows - 1 - dr, k_cols - 1 - dc]
                            * image[top + dr, left + dc])
            out[row, col] = acc
    return out

# This kind of quadruply-nested for-loop is going to be quite slow in
# pure Python. Using Numba we can compile this code to LLVM IR, which
# LLVM then compiles to native machine code:

from numba import double, jit

# The explicit signature double[:,:](double[:,:], double[:,:]) declares a
# function that takes two 2-D double arrays and returns a 2-D double array.
# Giving a signature presumably makes Numba compile here rather than on the
# first call — confirm for the Numba version in use.
fastfilter_2d = jit(double[:,:](double[:,:], double[:,:]))(filter2d)

# Now fastfilter_2d runs at speeds as if you had first translated
# it to C, compiled the code and wrapped it with Python.
image = numpy.random.random((100, 100))
filt = numpy.random.random((10, 10))
res = fastfilter_2d(image, filt)

Numba actually produces two functions. The first is the low-level compiled version of filter2d. The second is a Python wrapper around that low-level function, so that it can be called from Python. The first function can be called directly from other Numba functions, eliminating all Python overhead in the call.

Classes

# -*- coding: utf-8 -*-
"""
Example for extension classes.

Things that work:

    - overriding Numba methods in Numba (all methods are virtual)
    - inheritance
    - instance attributes
    - subclassing in python and calling overridden methods in Python
    - arbitrary new attributes on extension classes and objects
    - weakrefs to extension objects

Things that do NOT (yet) work:

    - overriding methods in Python and calling the method from Numba
    - multiple inheritance of Numba classes
        (multiple inheritance with Python classes should work)
    - subclassing variable sized objects like 'str' or 'tuple'
"""
from __future__ import print_function, division, absolute_import

from numba import jit, void, int_, double

# All methods must be given signatures

@jit
class Shrubbery(object):
    # Numba extension class. Each method carries an explicit Numba signature
    # decorator of the form return_type(arg_types...).
    # Attributes: width and height (presumably typed int_ from the
    # __init__ signature) and some_attr (double, via the explicit cast).
    @void(int_, int_)
    def __init__(self, w, h):
        # All instance attributes must be defined in the initializer
        self.width = w
        self.height = h

        # Types can be explicitly specified through casts
        self.some_attr = double(1.0)

    # Return width * height as an int_.
    @int_()
    def area(self):
        return self.width * self.height

    # Print the shrubbery's dimensions.
    # NOTE(review): "This shrubbery is " already ends in a space and print()
    # inserts a separator, so Python-style printing shows a double space.
    @void()
    def describe(self):
        print("This shrubbery is ", self.width,
              "by", self.height, "cubits.")
 
# Exercise the compiled methods and attributes of the extension class.
shrub = Shrubbery(10, 20)
print(shrub.area())
shrub.describe()
print(shrub.width, shrub.height)
# Attributes are writable from Python; area() sees the new value.
shrub.width = 30
print(shrub.area())
print(shrub._numba_attrs._fields_) # This is an internal attribute subject to change!

# A plain Python subclass of a Numba extension class: it inherits the
# compiled methods and can add ordinary Python methods.
class MyClass(Shrubbery):
    def newmethod(self):
        print("This is a new method.")

shrub2 = MyClass(30,40)
shrub2.describe()
shrub2.newmethod()
# Inspect the subclass instance's attribute layout (internal API, may change).
print(shrub2._numba_attrs._fields_)

Closures

# -*- coding: utf-8 -*-
"""
Example for closures. Closures may be nested to arbitrary depth, and they keep
the scope alive as long as the closure is alive. Only variables that are
closed over (cell variables in the defining function, free variables in the
closure) are kept alive. See also numba/tests/closures/test_closure.py
"""
from __future__ import print_function, division, absolute_import

from numba import autojit, jit, float_
from numpy import linspace

@autojit
def generate_power_func(n):
    # Build and return a Numba-compiled closure nth_power(x) == x ** n,
    # where x is a float_ and n is captured from this call.
    @jit(float_(float_))
    def nth_power(x):
        # ``n`` is a free variable, read from the enclosing call's scope.
        return x ** n

    # This is a native call; it prints 10 ** n computed by the closure.
    print(nth_power(10))

    # Return closure and keep all cell variables alive
    return nth_power

# Build x**2, x**3 and x**4, and evaluate each one at 10 evenly spaced
# points in [1, 2].
for n in range(2, 5):
    func = generate_power_func(n)
    # linspace's ``num`` (the number of samples) must be an integer; passing
    # the float 10. raises TypeError on current NumPy.
    print([func(x) for x in linspace(1., 2., 10)])

Structs

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import struct, jit, double
import numpy as np

# A Numba record type with two double fields, 'x' and 'y', plus the matching
# NumPy dtype (used below as ``dtype=``) so NumPy arrays can hold such records.
record_type = struct([('x', double), ('y', double)])
record_dtype = record_type.get_dtype()
# Two sample records: (x=1.0, y=2.0) and (x=3.0, y=4.0).
a = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=record_dtype)

@jit(argtypes=[record_type[:]])
def hypot(data):
    # Compute sqrt(x*x + y*y) for every record of ``data`` (a 1-D array of
    # record_type, per argtypes). Return the values as a new float64 array
    # with the same shape as ``data``.
    # return types of numpy functions are inferred
    result = np.empty_like(data, dtype=np.float64) 
    # notice access to structure elements 'x' and 'y' via attribute access
    # You can also index by field name or field index:
    #       data[i].x == data[i]['x'] == data[i][0]
    for i in range(data.shape[0]):
        result[i] = np.sqrt(data[i].x * data[i].x + data[i].y * data[i].y)

    return result

# For the two sample records this prints sqrt(5) (~2.236) and 5.0.
print(hypot(a))

# Notice inferred return type
print(hypot.signature)
# Notice native sqrt calls and for.body direct access to memory...
# Uncomment to dump what is presumably the compiled LLVM function object
# (legacy Numba attribute):
#print(hypot.lfunc)

Pointers

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
import numba
from numba import *
from numba.tests.test_support import autojit_py3doc
import numpy as np

# Pointer types used below: int32* and void*.
int32p = int32.pointer()
voidp = void.pointer()

@autojit_py3doc
def test_pointer_arithmetic():
    """
    >>> test_pointer_arithmetic()
    48L
    """
    # Arithmetic on an int32* moves in units of the element size (4 bytes),
    # so moving 10 + 2 elements from address 0 yields byte address 48.
    p = int32p(Py_uintptr_t(0))
    p = p + 10
    p += 2
    return Py_uintptr_t(p) # 0 + 4 * 12

@autojit_py3doc(locals={"pointer_value": Py_uintptr_t})
def test_pointer_indexing(pointer_value, type_p):
    """
    >>> a = np.array([1, 2, 3, 4], dtype=np.float32)
    >>> test_pointer_indexing(a.ctypes.data, float32.pointer())
    (1.0, 2.0, 3.0, 4.0)

    >>> a = np.array([1, 2, 3, 4], dtype=np.int64)
    >>> test_pointer_indexing(a.ctypes.data, int64.pointer())
    (1L, 2L, 3L, 4L)
    """
    # Reinterpret the raw address ``pointer_value`` as a ``type_p`` pointer
    # and read the first four elements through it. The caller must ensure
    # that at least four valid elements live at that address.
    p = type_p(pointer_value)
    return p[0], p[1], p[2], p[3]

@autojit
def test_compare_null():
    """
    >>> test_compare_null()
    True
    """
    # A void* built from address 0 compares equal to numba.NULL.
    return voidp(Py_uintptr_t(0)) == numba.NULL

# Presumably runs the doctests embedded in this module's docstrings.
numba.testing.testmod()

Objects

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import double, autojit

class MyClass(object):
    """Ordinary (non-Numba) Python class used to demonstrate object calls."""

    def mymethod(self, arg):
        """Return ``arg * 2``: repetition for sequences, doubling for numbers."""
        doubled = arg * 2
        return doubled
    
@autojit(locals=dict(mydouble=double)) # specify types for local variables
def call_method(obj):
    # Call obj.mymethod twice from Numba-compiled code: once with a str, whose
    # result stays a Python object, and once with a float, whose result is
    # converted to a native double through the ``locals`` declaration.
    print(obj.mymethod("hello"))  # object result
    mydouble = obj.mymethod(10.2) # native double
    print(mydouble * 2)           # native multiplication
    
# With MyClass.mymethod returning arg * 2 this prints "hellohello", then 40.8.
call_method(MyClass())

Mandelbrot

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import autojit
import numpy as np
from pylab import imshow, jet, show, ion

@autojit
def mandel(x, y, max_iters):
    """
    Given the real and imaginary parts of a complex number,
    determine if it is a candidate for membership in the Mandelbrot
    set given a fixed number of iterations.

    Iterates z = z*z + c from z = 0 with c = complex(x, y). Returns the
    0-based iteration index at which |z|**2 first reaches 4 (the point
    escaped), or 255 if it never escapes within max_iters iterations.
    A return value of 255 therefore marks a likely member of the set.
    """
    i = 0
    c = complex(x,y)
    z = 0.0j
    for i in range(max_iters):
        z = z*z + c
        # Compare the squared magnitude with 4 (|z| >= 2) to avoid a sqrt.
        if (z.real*z.real + z.imag*z.imag) >= 4:
            return i

    return 255

@autojit
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """
    Render the Mandelbrot set into ``image`` in place and return it.

    ``image`` is indexed image[y, x] (row, column). Column x maps to the real
    part min_x + x * (max_x - min_x) / width, and row y maps to the
    imaginary part min_y + y * (max_y - min_y) / height. Each pixel receives
    mandel(real, imag, iters).
    """
    height = image.shape[0]
    width = image.shape[1]

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    for x in range(width):
        real = min_x + x * pixel_size_x
        for y in range(height):
            imag = min_y + y * pixel_size_y
            color = mandel(real, imag, iters)
            image[y, x] = color

    return image

# A 500x750 uint8 canvas covering real [-2, 1] x imaginary [-1, 1].
image = np.zeros((500, 750), dtype=np.uint8)
imshow(create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20))
jet()
ion()
# ion() switches pyplot to interactive mode, where a bare show() returns
# immediately. When run as a script, the process would then exit and close
# the window. block=True keeps the window open until it is closed.
show(block=True)

Filterbank Correlation

# -*- coding: utf-8 -*-
"""
This file demonstrates a filterbank correlation loop.
"""
from __future__ import print_function, division, absolute_import
import numpy as np
import numba
from numba.decorators import jit
# 4-D double array type used for all three fbcorr arguments.
nd4type = numba.double[:,:,:,:]

@jit(argtypes=(nd4type, nd4type, nd4type))
def fbcorr(imgs, filters, output):
    # Valid-mode filterbank correlation, accumulated into ``output``:
    #   output[i, f, r, c] += sum over (h, w, ch) of
    #       imgs[i, r + h, c + w, ch] * filters[f, h, w, ch]
    # for every (r, c) where the filter fits inside the image. ``output``
    # must therefore have shape (n_imgs, n_filters, n_rows - height + 1,
    # n_cols - width + 1). It is added to, not zeroed, here.
    # NOTE(review): n_ch2 is never checked; this assumes the filter depth
    # equals the image channel count.
    n_imgs, n_rows, n_cols, n_channels = imgs.shape
    n_filters, height, width, n_ch2 = filters.shape

    for ii in range(n_imgs):
        for rr in range(n_rows - height + 1):
            for cc in range(n_cols - width + 1):
                # range, not xrange: xrange does not exist on Python 3.
                for hh in range(height):
                    for ww in range(width):
                        for jj in range(n_channels):
                            for ff in range(n_filters):
                                imgval = imgs[ii, rr + hh, cc + ww, jj]
                                filterval = filters[ff, hh, ww, jj]
                                output[ii, ff, rr, cc] += imgval * filterval

def main():
    """Time one fbcorr call on random data and print the elapsed seconds."""
    import time

    imgs = np.random.randn(10, 64, 64, 3)
    filt = np.random.randn(6, 5, 5, 3)

    # fbcorr writes output[image, filter, row, col] over the "valid"
    # positions only. The output shape therefore follows from the inputs:
    # (n_imgs, n_filters, n_rows - height + 1, n_cols - width + 1), here
    # (10, 6, 60, 60). Putting the filter axis last, as in (10, 60, 60, 6),
    # would let fbcorr write up to column 59 into an axis of length 6.
    n_imgs, n_rows, n_cols, _ = imgs.shape
    n_filters, height, width, _ = filt.shape
    output = np.zeros((n_imgs, n_filters,
                       n_rows - height + 1, n_cols - width + 1))

    t0 = time.time()
    fbcorr(imgs, filt, output)
    print(time.time() - t0)

if __name__ == "__main__":
    main()

Multithreading

"""
Example of multithreading by releasing the GIL through ctypes.
"""
from __future__ import print_function, division, absolute_import

from timeit import repeat
import threading
from ctypes import pythonapi, c_void_p
from math import exp

import numpy as np
from numba import jit, void, double

# Number of threads used by the multithreaded variant.
nthreads = 2
# Array length for the benchmark. It must be an int: np.random.rand() rejects
# the float 1e6 with TypeError on current NumPy.
size = 10**6

def timefunc(correct, s, func, *args, **kwargs):
    """Benchmark ``func(*args, **kwargs)`` and print the best time in ms.

    Prints the label ``s`` padded to 20 columns. It then makes one untimed
    call (so any JIT compilation happens outside the measurement) and, if
    ``correct`` is not None, asserts that result is allclose to it. Finally
    it prints the best of 2 repeats of 5 calls each, in milliseconds, and
    returns the untimed call's result.
    """
    def call():
        return func(*args, **kwargs)

    print(s.ljust(20), end=" ")
    res = call()
    if correct is not None:
        assert np.allclose(res, correct)
    best_seconds = min(repeat(call, number=5, repeat=2))
    print('{:>5.0f} ms'.format(best_seconds * 1000))
    return res

def make_singlethread(inner_func):
    """Wrap ``inner_func`` in a caller that allocates and returns its output.

    The returned ``func(*args)`` creates an uninitialised float64 array as
    long as ``args[0]``. It calls ``inner_func(result, *args)`` to fill that
    array in place, then returns it.
    """
    def func(*args):
        out = np.empty(len(args[0]), dtype=np.float64)
        inner_func(out, *args)
        return out

    return func

def make_multithread(inner_func, numthreads):
    """Wrap ``inner_func`` so it processes contiguous chunks on several threads.

    The returned ``func_mt(*args)`` allocates a float64 ``result`` array of
    ``len(args[0])`` elements. It slices ``result`` and every argument into
    ``numthreads`` contiguous chunks and calls
    ``inner_func(result_chunk, *arg_chunks)`` once per chunk:
    ``numthreads - 1`` chunks run on worker threads and the last one on the
    calling thread. It returns ``result`` after all threads have joined.
    Threads only speed things up if ``inner_func`` releases the GIL.
    """
    def func_mt(*args):
        length = len(args[0])
        result = np.empty(length, dtype=np.float64)
        args = (result,) + args
        # Ceiling division, so the chunks cover every element for any
        # thread count. (length + 1) // numthreads under-sizes chunks when
        # numthreads > 2, which would leave the tail of ``result``
        # uninitialised.
        chunklen = (length + numthreads - 1) // numthreads
        chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args]
                  for i in range(numthreads)]

        # You should make sure inner_func is compiled at this point, because
        # the compilation must happen on the main thread. This is the case
        # in this example because we use jit().
        threads = [threading.Thread(target=inner_func, args=chunk)
                   for chunk in chunks[:-1]]
        for thread in threads:
            thread.start()

        # the main thread handles the last chunk
        inner_func(*chunks[-1])

        for thread in threads:
            thread.join()
        return result
    return func_mt
  
# ctypes handles to the CPython C-API calls that release and re-acquire the
# GIL. PyEval_SaveThread() takes no arguments and returns the saved thread
# state as a pointer. PyEval_RestoreThread(state) takes that pointer back and
# returns nothing.
savethread = pythonapi.PyEval_SaveThread
savethread.argtypes = []
savethread.restype = c_void_p

restorethread = pythonapi.PyEval_RestoreThread
restorethread.argtypes = [c_void_p]
restorethread.restype = None

def inner_func(result, a, b):
    # Fill result[i] = exp(2.1 * a[i] + 3.2 * b[i]) with the GIL released,
    # so several threads can run this loop at once after it is compiled with
    # nopython=True.
    # NOTE(review): do not call this uncompiled Python version directly. It
    # would run Python bytecode while the GIL is released.
    threadstate = savethread()
    for i in range(len(result)):
        result[i] = exp(2.1 * a[i] + 3.2 * b[i])
    # Re-acquire the GIL before returning to Python code.
    restorethread(threadstate)

# Compile inner_func in nopython mode for three 1-D double arrays
# (result, a, b). Then wrap it as a single-threaded callable and as an
# nthreads-threaded callable.
signature = void(double[:], double[:], double[:])
inner_func_nb = jit(signature, nopython=True)(inner_func)
func_nb = make_singlethread(inner_func_nb)
func_nb_mt = make_multithread(inner_func_nb, nthreads)
            
def func_np(a, b):
    """Vectorised NumPy reference: elementwise exp(2.1 * a + 3.2 * b)."""
    exponent = 2.1 * a + 3.2 * b
    return np.exp(exponent)

# Random inputs of ``size`` elements. int() keeps this working even when
# ``size`` is written as a float such as 1e6, which np.random.rand rejects.
a = np.random.rand(int(size))
b = np.random.rand(int(size))

# The NumPy result is the reference that both Numba variants must match.
correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
timefunc(correct, "numba (1 thread)", func_nb, a, b)
timefunc(correct, "numba (%d threads)" % nthreads, func_nb_mt, a, b)

Strings and libc

# -*- coding: utf-8 -*-

"""
Example of using strings with numba using libc and some basic string
functionality.
"""

from __future__ import division, absolute_import
import struct
import socket

import numba as nb
import cffi

ffi = cffi.FFI()
# Declare the libc prototypes this example calls, so cffi can look the
# functions up in the library opened below.
ffi.cdef("""
void abort(void);
char *strstr(const char *s1, const char *s2);
int atoi(const char *str);
char *strtok(char *restrict str, const char *restrict sep);
""")

# dlopen(None) opens the standard C library (POSIX; not supported on Windows).
lib = ffi.dlopen(None)

# For now, we need to make these globals so numba will recognize them
abort, strstr, atoi, strtok = lib.abort, lib.strstr, lib.atoi, lib.strtok

# Pointer types: int8* (used below as a char* view of a string) and int_*.
int8_p = nb.int8.pointer()
int_p  = nb.int_.pointer()

@nb.autojit(nopython=True)
def parse_int_strtok(s):
    """
    Convert an IP address given as a string to an int, similar to
    socket.inet_aton(). Performs no error checking!

    Splits ``s`` on "." with libc strtok() and ORs each of the first four
    atoi()-parsed fields into a uint32, most significant byte first.
    """
    # NOTE(review): strtok() writes NUL bytes into the buffer it scans, so
    # this presumably mutates the memory backing ``s`` in place. That is
    # unsafe for immutable or interned Python strings; confirm before reuse.
    result = nb.uint32(0)
    current = strtok(s, ".")
    for i in range(4):
        byte = atoi(current)
        shift = (3 - i) * 8
        result |= byte << shift
        # Later strtok() calls pass NULL to continue scanning the same string.
        # NOTE(review): NULL is cast to int_p although strtok expects a char*.
        current = strtok(int_p(nb.NULL), ".")

    return result

@nb.autojit(nopython=True)
def parse_int_manual(s):
    """
    Convert an IP address given as a string to an int, similar to
    socket.inet_aton(). Performs no error checking!

    Scans ``s`` once. At each '.' and at the last character, it calls atoi()
    on the text starting at the current field and ORs the value into the next
    lower byte of a uint32. The first field becomes the most significant byte.
    """
    result = nb.uint32(0)
    end = len(s)
    start = 0
    shift = 3
    for i in range(end):
        # '.'[0] is the single character '.', compared against s[i].
        if s[i] == '.'[0] or i == end - 1:
            # atoi() stops at the first non-digit, so it parses only this
            # field even though the pointer runs to the end of ``s``.
            byte = atoi(int8_p(s) + start)
            result |= byte << (shift * 8)
            shift -= 1
            start = i + 1

    return result

result1 = parse_int_strtok('192.168.1.2')
result2 = parse_int_manual('1.2.3.4')
# Pack each uint32 big-endian ('>I') and format it as a dotted quad. This
# round-trips the inputs, so the expected output is 192.168.1.2 and 1.2.3.4.
print(socket.inet_ntoa(struct.pack('>I', result1)))
print(socket.inet_ntoa(struct.pack('>I', result2)))