Examples

A Simple Function

Suppose we want to write an image-processing function in Python. Here’s how it might look.

import numpy

def filter2d(image, filt):
    """Filter a 2-D ``image`` with the 2-D kernel ``filt`` using plain loops.

    The kernel is read in reverse along both axes while the image window
    is read forwards.  The loops skip a border of half the kernel size
    (integer division) on every side; those entries of the result keep
    the zeros they were initialised with.

    Returns a new array with the same shape and dtype as ``image``.
    """
    rows, cols = image.shape
    f_rows, f_cols = filt.shape
    half_r = f_rows // 2
    half_c = f_cols // 2
    result = numpy.zeros_like(image)
    for r in range(half_r, rows - half_r):
        top = r - half_r
        for c in range(half_c, cols - half_c):
            left = c - half_c
            acc = 0.0
            for dr in range(f_rows):
                for dc in range(f_cols):
                    weight = filt[f_rows - 1 - dr, f_cols - 1 - dc]
                    acc += (weight * image[top + dr, left + dc])
            result[r, c] = acc
    return result

# This kind of quadruply-nested for-loop is going to be quite slow.
# Using Numba we can compile this code to LLVM which then gets
# compiled to machine code:

from numba import double, jit

# Compile filter2d with an explicit signature: two 2-D arrays of doubles in,
# one 2-D array of doubles out.
fastfilter_2d = jit(double[:,:](double[:,:], double[:,:]))(filter2d)

# Now fastfilter_2d runs at speeds as if you had first translated
# it to C, compiled the code and wrapped it with Python.
# Try it on a random 100x100 image with a random 10x10 filter.
image = numpy.random.random((100, 100))
filt = numpy.random.random((10, 10))
res = fastfilter_2d(image, filt)

Numba actually produces two functions. The first is the low-level compiled version of filter2d. The second is a Python wrapper around that low-level function, so that it can be called from Python. The first function can also be called directly from other Numba functions, which eliminates all Python overhead in function calling.

Objects

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import jit

class MyClass(object):
    """Ordinary Python class used to show method calls on Python objects."""

    def mymethod(self, arg):
        """Return ``arg * 2``, whatever ``*`` means for the type of ``arg``."""
        doubled = arg * 2
        return doubled

@jit
def call_method(obj):
    """Call ``obj.mymethod`` with a string and with a float, printing both.

    The string result is printed as returned; the float result is
    multiplied by 2 before being printed.
    """
    text_result = obj.mymethod("hello")  # object result
    print(text_result)
    mydouble = obj.mymethod(10.2)        # native double
    product = mydouble * 2               # native multiplication
    print(product)
    
# Run the jit-decorated function on a fresh MyClass instance.
call_method(MyClass())

UFuncs

from numba import vectorize
from numba import autojit, double, jit
import math
import numpy as np

@vectorize(['f8(f8)','f4(f4)'])
def sinc(x):
    """Return sin(pi*x) / (pi*x), defined as 1.0 at x == 0."""
    # Guard the removable singularity first.
    if x == 0:
        return 1.0
    pix = x*math.pi
    return math.sin(pix) / pix

@vectorize(['int8(int8,int8)',
            'int16(int16,int16)',
            'int32(int32,int32)',
            'int64(int64,int64)',
            'f4(f4,f4)',
            'f8(f8,f8)'])
def add(x, y):
    """Return ``x + y``; one loop is declared per listed signature."""
    total = x + y
    return total

@vectorize(['f8(f8)','f4(f4)'])
def logit(x):
    """Return the log-odds ``log(x / (1 - x))``."""
    odds = x / (1-x)
    return math.log(odds)

@vectorize(['f8(f8)','f4(f4)'])
def expit(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``.

    Two algebraically equal forms are used so that ``exp`` is only ever
    applied to a non-positive number.  That keeps it from overflowing for
    inputs of large magnitude, where the result should saturate at 1 for
    large positive ``x`` and at 0 for large negative ``x``.
    """
    if x >= 0:
        # exp(-x) <= 1 here, so the denominator stays in [1, 2].
        return 1 / (1 + math.exp(-x))
    else:
        # exp(x) < 1 here; use the form exp(x) / (1 + exp(x)).
        ex = math.exp(x)
        return ex / (1 + ex)

@jit('f8(f8,f8[:])')
def polevl(x, coef):
    """Evaluate a polynomial at ``x`` with Horner's rule.

    ``coef`` lists the coefficients from the highest power down to the
    constant term.
    """
    acc = coef[0]
    for k in range(1, len(coef)):
        acc = acc * x + coef[k]
    return acc

@jit('f8(f8,f8[:])')
def p1evl(x, coef):
    """Evaluate a polynomial whose leading coefficient is an implicit 1.0.

    ``coef`` lists the remaining coefficients from the next-highest power
    down to the constant term.  The Horner recurrence starts from
    ``x + coef[0]``, i.e. ``1.0 * x + coef[0]``.
    """
    acc = x + coef[0]
    for k in range(1, len(coef)):
        acc = acc * x + coef[k]
    return acc


# Coefficient tables and constants read by j0() below.  Every table is a
# double-precision ('d') array listed from the highest power down to the
# constant term, which is the order polevl()/p1evl() consume them in.

# PP / PQ: numerator and denominator of the first rational function used
# by j0 for x > 5.
PP = np.array([
  7.96936729297347051624E-4,
  8.28352392107440799803E-2,
  1.23953371646414299388E0,
  5.44725003058768775090E0,
  8.74716500199817011941E0,
  5.30324038235394892183E0,
  9.99999999999999997821E-1], 'd')

PQ = np.array([
  9.24408810558863637013E-4,
  8.56288474354474431428E-2,
  1.25352743901058953537E0,
  5.47097740330417105182E0,
  8.76190883237069594232E0,
  5.30605288235394617618E0,
  1.00000000000000000218E0], 'd')
  
# DR1 / DR2: subtracted from x*x in the x <= 5 branch of j0.
# NOTE(review): these look like the squares of the first two zeros of J0
# (2.4048...**2 and 5.5201...**2) - confirm against the reference tables.
DR1 = 5.783185962946784521175995758455807035071
DR2 = 30.47126234366208639907816317502275584842

# RP / RQ: numerator and denominator of the rational function used by j0
# for x <= 5.  RQ's leading coefficient of 1.0 is left out of the table
# (see the commented-out entry).
RP = np.array([
-4.79443220978201773821E9,
 1.95617491946556577543E12,
-2.49248344360967716204E14,
 9.70862251047306323952E15], 'd')

RQ = np.array([
    # 1.00000000000000000000E0,
 4.99563147152651017219E2,
 1.73785401676374683123E5,
 4.84409658339962045305E7,
 1.11855537045356834862E10,
 2.11277520115489217587E12,
 3.10518229857422583814E14,
 3.18121955943204943306E16,
 1.71086294081043136091E18], 'd')

# QP / QQ: numerator and denominator of the second rational function used
# by j0 for x > 5.  QQ's leading coefficient of 1.0 is left out of the
# table (see the commented-out entry).
QP = np.array([
-1.13663838898469149931E-2,
-1.28252718670509318512E0,
-1.95539544257735972385E1,
-9.32060152123768231369E1,
-1.77681167980488050595E2,
-1.47077505154951170175E2,
-5.14105326766599330220E1,
-6.05014350600728481186E0], 'd')

QQ = np.array([
    # 1.00000000000000000000E0,
  6.43178256118178023184E1,
  8.56430025976980587198E2,
  3.88240183605401609683E3,
  7.24046774195652478189E3,
  5.93072701187316984827E3,
  2.06209331660327847417E3,
  2.42005740240291393179E2], 'd')

# pi / 4 and sqrt(2 / pi), used for the phase shift and the amplitude
# factor in the x > 5 branch of j0.
NPY_PI_4 = .78539816339744830962
SQ2OPI  = .79788456080286535587989

@jit('f8(f8)')
def j0(x):
    """Return the Bessel function of the first kind of order zero, J0(x).

    The structure follows the Cephes ``j0`` routine:

    * ``|x| < 1e-5``: the two-term series ``1 - x**2 / 4``.
    * ``|x| <= 5``: ``(z - DR1) * (z - DR2) * RP(z) / RQ(z)`` with
      ``z = x**2``.
    * ``|x| > 5``: an asymptotic form built from the rational functions
      ``PP/PQ`` and ``QP/QQ`` evaluated at ``25 / x**2``.
    """
    # Only the magnitude of x is used below.
    if (x < 0):
        x = -x

    if (x <= 5.0):
        z = x * x
        if (x < 1.0e-5):
            return (1.0 - z / 4.0)
        p = (z-DR1) * (z-DR2)
        # RQ is stored without its leading 1.0 coefficient (like QQ), so it
        # has to be evaluated with p1evl; polevl would drop the z**8 term.
        p = p * polevl(z, RP) / p1evl(z, RQ)
        return p

    w = 5.0 / x
    q = 25.0 / (x*x)
    p = polevl(q, PP) / polevl(q, PQ)
    # QQ also omits its leading 1.0 coefficient, hence p1evl.
    q = polevl(q, QP) / p1evl(q, QQ)
    xn = x - NPY_PI_4
    p = p*math.cos(xn) - w * q * math.sin(xn)
    return p * SQ2OPI / math.sqrt(x)


x = np.arange(10000, dtype='i8')
y = np.arange(10000, dtype='i8')
# Exercise the vectorized ``add`` ufunc defined above.  The builtin
# sum(x, y) would instead treat ``y`` as the start value and add the
# elements of ``x`` to it one by one in the interpreter.
print(add(x, y))

Mandelbrot

# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import jit
import numpy as np
from pylab import imshow, jet, show, ion

@jit
def mandel(x, y, max_iters):
    """
    Iterate z -> z*z + c from z = 0 for the point c = x + y*1j.

    Returns the index of the first iteration after which |z|**2 >= 4.
    If that does not happen within ``max_iters`` iterations, the point is
    treated as a candidate member of the Mandelbrot set and 255 is
    returned.
    """
    step = 0
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit*orbit + point
        mag_sq = orbit.real*orbit.real + orbit.imag*orbit.imag
        if mag_sq >= 4:
            return step

    return 255

@jit
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """Fill ``image`` in place with ``mandel`` values and return it.

    Column ``col`` maps to the real part ``min_x + col * step_x`` and row
    ``row`` to the imaginary part ``min_y + row * step_y``, where the
    steps divide the requested ranges by the image width and height.
    """
    height, width = image.shape[0], image.shape[1]

    step_x = (max_x - min_x) / width
    step_y = (max_y - min_y) / height
    for col in range(width):
        real = min_x + col * step_x
        for row in range(height):
            imag = min_y + row * step_y
            image[row, col] = mandel(real, imag, iters)

    return image

# Render the region [-2, 1] x [-1, 1] into a 500-row by 750-column uint8
# image using 20 iterations per pixel, then display it with pylab.
image = np.zeros((500, 750), dtype=np.uint8)
imshow(create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20))
jet()   # select the "jet" colormap
ion()   # switch pylab to interactive mode
show()

Filterbank Correlation

# -*- coding: utf-8 -*-

"""
This file demonstrates a filterbank correlation loop.
"""
from __future__ import print_function, division, absolute_import
import numpy as np
import numba
from numba.utils import IS_PY3
from numba.decorators import jit
# Numba type for a 4-dimensional array of doubles; used for every argument
# of fbcorr below.
nd4type = numba.double[:,:,:,:]

# Python 3 has no ``xrange`` builtin, so alias it to ``range`` there.
if IS_PY3:
    xrange = range

@jit(argtypes=(nd4type, nd4type, nd4type))
def fbcorr(imgs, filters, output):
    """Accumulate the correlation of every image with every filter.

    ``imgs`` is indexed ``[image, row, col, channel]`` and ``filters``
    ``[filter, row, col, channel]``.  Products are added in place into
    ``output``, which is indexed ``[image, filter, row, col]``; only
    positions where the filter lies fully inside the image are visited.
    Nothing is returned.
    """
    n_imgs, n_rows, n_cols, n_channels = imgs.shape
    n_filters, height, width, n_ch2 = filters.shape
    out_rows = n_rows - height + 1
    out_cols = n_cols - width + 1

    for img in range(n_imgs):
        for row in range(out_rows):
            for col in range(out_cols):
                for dy in range(height):
                    for dx in range(width):
                        for ch in range(n_channels):
                            for flt in range(n_filters):
                                pixel = imgs[img, row + dy, col + dx, ch]
                                weight = filters[flt, dy, dx, ch]
                                output[img, flt, row, col] += pixel * weight

def main():
    """Run fbcorr once on random data and print the elapsed seconds."""
    imgs = np.random.randn(10, 64, 64, 3)
    filt = np.random.randn(6, 5, 5, 3)

    # fbcorr writes output[image, filter, row, col], so the filter axis must
    # come second and the two spatial axes last.  Deriving the shape from
    # the inputs keeps every index fbcorr uses inside the array bounds.
    n_imgs, n_rows, n_cols = imgs.shape[:3]
    n_filters, height, width = filt.shape[:3]
    output = np.zeros((n_imgs, n_filters,
                       n_rows - height + 1, n_cols - width + 1))

    import time
    t0 = time.time()
    fbcorr(imgs, filt, output)
    print(time.time() - t0)

if __name__ == "__main__":
    main()

Multi threading

# -*- coding: utf-8 -*-
"""
Example of multithreading by releasing the GIL through ctypes.
"""
from __future__ import print_function, division, absolute_import

from timeit import repeat
import threading
from ctypes import pythonapi, c_void_p
from math import exp

import numpy as np
from numba import jit, void, double

# Number of worker chunks used by the multithreaded benchmark.
nthreads = 2
# Length of the benchmark arrays.  This must be an int: np.random.rand()
# rejects float dimensions such as 1e6.
size = 10 ** 6

def timefunc(correct, s, func, *args, **kwargs):
    """Print the label ``s`` and the best timing of ``func(*args, **kwargs)``.

    The function is called once up front (so that any compilation happens
    before timing) and that result is compared against ``correct`` with
    ``np.allclose`` unless ``correct`` is None.  It is then timed with
    ``timeit.repeat`` (2 repeats of 5 calls) and the smallest total is
    printed in milliseconds.

    Returns the result of the initial call.
    """
    print(s.ljust(20), end=" ")
    # Make sure the function is compiled before we start the benchmark
    warmup = func(*args, **kwargs)
    if correct is not None:
        assert np.allclose(warmup, correct)

    def run_once():
        return func(*args, **kwargs)

    # time it
    timings = repeat(run_once, number=5, repeat=2)
    best_ms = min(timings) * 1000
    print('{:>5.0f} ms'.format(best_ms))
    return warmup

def make_singlethread(inner_func):
    """Wrap ``inner_func(result, *args)`` so that it allocates its own output.

    The returned function creates an uninitialised float64 array as long
    as its first argument, lets ``inner_func`` fill it, and returns it.
    """
    def func(*args):
        out = np.empty(len(args[0]), dtype=np.float64)
        inner_func(out, *args)
        return out
    return func

def make_multithread(inner_func, numthreads):
    """Wrap ``inner_func(result, *args)`` to run over ``numthreads`` chunks.

    The returned function allocates a float64 result as long as its first
    argument, slices the result and every argument into ``numthreads``
    consecutive chunks, runs ``inner_func`` on all but the last chunk in
    separate threads and on the last chunk in the calling thread, then
    joins the threads and returns the result.
    """
    def func_mt(*args):
        length = len(args[0])
        result = np.empty(length, dtype=np.float64)
        args = (result,) + args
        # Ceiling division: numthreads chunks of this size always cover the
        # whole array, whatever the thread count.  Trailing chunks may be
        # shorter or empty.
        chunklen = (length + numthreads - 1) // numthreads
        chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args]
                  for i in range(numthreads)]

        # You should make sure inner_func is compiled at this point, because
        # the compilation must happen on the main thread. This is the case
        # in this example because we use jit().
        threads = [threading.Thread(target=inner_func, args=chunk)
                   for chunk in chunks[:-1]]
        for thread in threads:
            thread.start()

        # the main thread handles the last chunk
        inner_func(*chunks[-1])

        for thread in threads:
            thread.join()
        return result
    return func_mt
  
# ctypes prototypes for the two C-API calls that release and re-acquire
# the GIL.  PyEval_RestoreThread takes the pointer that PyEval_SaveThread
# returned and gives nothing back.
restorethread = pythonapi.PyEval_RestoreThread
restorethread.restype = None
restorethread.argtypes = [c_void_p]

savethread = pythonapi.PyEval_SaveThread
savethread.restype = c_void_p
savethread.argtypes = []

def inner_func(result, a, b):
    """Store ``exp(2.1 * a[i] + 3.2 * b[i])`` into ``result[i]`` for every i.

    The loop is bracketed by PyEval_SaveThread / PyEval_RestoreThread
    (called through ctypes), so it runs with the GIL released.
    NOTE(review): this appears to be meant only for the nopython-compiled
    version created below; confirm before calling it as plain Python.
    """
    # Release the GIL and remember the thread state needed to take it back.
    threadstate = savethread()
    for i in range(len(result)):
        result[i] = exp(2.1 * a[i] + 3.2 * b[i])
    # Re-acquire the GIL before returning to the caller.
    restorethread(threadstate)

# Compile inner_func in nopython mode for three 1-D double arrays and no
# return value, then build the single- and multi-threaded wrappers on it.
signature = void(double[:], double[:], double[:])
inner_func_nb = jit(signature, nopython=True)(inner_func)
func_nb = make_singlethread(inner_func_nb)
func_nb_mt = make_multithread(inner_func_nb, nthreads)
            
def func_np(a, b):
    """NumPy reference implementation: ``exp(2.1 * a + 3.2 * b)``."""
    exponent = 2.1 * a + 3.2 * b
    return np.exp(exponent)

# Random benchmark inputs of length ``size``.
a = np.random.rand(size)
b = np.random.rand(size)
# NOTE(review): ``c`` is not used by any of the calls below.
c = np.random.rand(size)

# Time the NumPy version first; its result is the reference that both
# Numba versions are checked against.
correct = timefunc(None, "numpy (1 thread)", func_np, a, b)
timefunc(correct, "numba (1 thread)", func_nb, a, b)
timefunc(correct, "numba (%d threads)" % nthreads, func_nb_mt, a, b)