Writing CUDA-Python

The CUDA JIT is a low-level entry point to the CUDA features in Numba. It translates Python functions into PTX code that executes on the CUDA hardware. The jit decorator is applied to Python functions written in our Python dialect for CUDA. Numba interacts with the CUDA Driver API to load the PTX onto the CUDA device and execute it.

Imports

Most of the public API for the CUDA features is exposed in the numba.cuda module:

from numba import cuda

Compiling

CUDA kernels and device functions are compiled by decorating a Python function with the jit or autojit decorators.

numba.cuda.jit(restype=None, argtypes=None, device=False, inline=False, bind=True, link=[], debug=False, **kws)

JIT compile a Python function conforming to the CUDA-Python specification.

To define a CUDA kernel that takes two int 1D-arrays:

@cuda.jit('void(int32[:], int32[:])')
def foo(aryA, aryB):
    ...

Note

A kernel cannot have any return value.

To launch the CUDA kernel:

griddim = 1, 2
blockdim = 3, 4
foo[griddim, blockdim](aryA, aryB)

griddim is the number of thread blocks per grid. It can be:

  • an int;
  • a 1-tuple of ints;
  • a 2-tuple of ints.

blockdim is the number of threads per block. It can be:

  • an int;
  • a 1-tuple of ints;
  • a 2-tuple of ints;
  • a 3-tuple of ints.

The above code is equivalent to the following CUDA-C:

dim3 griddim(1, 2);
dim3 blockdim(3, 4);
foo<<<griddim, blockdim>>>(aryA, aryB);
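
The launch configuration can also be given as plain ints when a single dimension is enough; a minimal sketch (the sizes here are illustrative):

foo[4, 32](aryA, aryB)  # 4 blocks of 32 threads, i.e. dim3(4, 1, 1) and dim3(32, 1, 1)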

To access the compiled PTX code:

print(foo.ptx)

To define a CUDA device function that takes two ints and returns an int:

@cuda.jit('int32(int32, int32)', device=True)
def bar(a, b):
    ...

To force inline the device function:

@cuda.jit('int32(int32, int32)', device=True, inline=True)
def bar_forced_inline(a, b):
    ...

A device function can only be called from within a kernel or another device function. It cannot be called from the host.

Using bar in a CUDA kernel:

@cuda.jit('void(int32[:], int32[:], int32[:])')
def use_bar(aryA, aryB, aryOut):
    i = cuda.grid(1) # global position of the thread for a 1D grid.
    aryOut[i] = bar(aryA[i], aryB[i])
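
A host-side call of use_bar might look like the following sketch; the array size and launch configuration are illustrative assumptions:

import numpy

n = 256
aryA = numpy.arange(n, dtype=numpy.int32)
aryB = numpy.arange(n, dtype=numpy.int32)
aryOut = numpy.zeros(n, dtype=numpy.int32)

threads_per_block = 32
blocks_per_grid = n // threads_per_block
use_bar[blocks_per_grid, threads_per_block](aryA, aryB, aryOut)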

numba.cuda.autojit(func, **kws)

JIT compile at the call site. A function signature is not needed because the argument types are captured at call time. Each compiled signature of the kernel is cached for future use.

Note

autojit can only compile CUDA kernels; it cannot produce device functions.

Example:

import numpy

@cuda.autojit
def foo(aryA, aryB):
    ...

aryA = numpy.arange(10, dtype=numpy.int32)
aryB = numpy.arange(10, dtype=numpy.float32)
foo[griddim, blockdim](aryA, aryB)

In the above code, a version of foo with the signature “void(int32[:], float32[:])” is compiled.
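
Calling foo again with arrays of different dtypes compiles and caches another specialization; a sketch continuing the example above:

aryC = numpy.arange(10, dtype=numpy.float64)
aryD = numpy.arange(10, dtype=numpy.float64)
foo[griddim, blockdim](aryC, aryD)  # compiles and caches "void(float64[:], float64[:])"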

Thread Identity by CUDA Intrinsics

A set of CUDA intrinsics is used to identify the current execution thread. These intrinsics are meaningful inside a CUDA kernel or device function only. A common pattern is to assign the computation of each element in the output array to a thread.

For a 1D grid:

tx = cuda.threadIdx.x
bx = cuda.blockIdx.x
bw = cuda.blockDim.x
i = tx + bx * bw
array[i] = something(i)

For a 2D grid:

tx = cuda.threadIdx.x
ty = cuda.threadIdx.y
bx = cuda.blockIdx.x
by = cuda.blockIdx.y
bw = cuda.blockDim.x
bh = cuda.blockDim.y
x = tx + bx * bw
y = ty + by * bh
array[x, y] = something(x, y)

Since these patterns are so common, there is a shorthand function to produce the same result.

For a 1D grid:

i = cuda.grid(1)
array[i] = something(i)

For a 2D grid:

x, y = cuda.grid(2)
array[x, y] = something(x, y)
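
Putting this together, a minimal sketch of a complete kernel using cuda.grid(1); the bounds check guards against the extra threads of the last block (the kernel itself is illustrative):

@cuda.jit('void(float32[:])')
def increment_by_one(arr):
    i = cuda.grid(1)
    if i < arr.shape[0]:  # the grid may contain more threads than array elements
        arr[i] += 1.0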

Memory Transfer

By default, any NumPy array used as an argument to a CUDA kernel is transferred automatically to and from the device. However, to achieve maximum performance and minimize redundant memory transfers, users should manage the memory transfers explicitly.

Host->device transfers are asynchronous to the host. Device->host transfers are synchronous to the host. If a non-zero CUDA stream is provided, the transfer becomes asynchronous.

numba.cuda.to_device(ary, stream=0, copy=True, to=None)

Allocate and transfer a numpy ndarray to the device.

To copy a numpy array host->device:

ary = numpy.arange(10)
d_ary = cuda.to_device(ary)

To enqueue the transfer to a stream:

stream = cuda.stream()
d_ary = cuda.to_device(ary, stream=stream)

The resulting d_ary is a DeviceNDArray.

To copy device->host:

hary = d_ary.copy_to_host()

To copy device->host to an existing array:

ary = numpy.empty(shape=d_ary.shape, dtype=d_ary.dtype)
d_ary.copy_to_host(ary)

To enqueue the transfer to a stream:

hary = d_ary.copy_to_host(stream=stream)

DeviceNDArray.copy_to_host(ary=None, stream=0)

Copy self to ary or create a new numpy ndarray if ary is None.

Always returns the host array.
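
To avoid redundant transfers, a device array can be reused across several kernel launches; a sketch (the kernel name and launch configuration are illustrative assumptions):

d_ary = cuda.to_device(ary)                # one host->device copy
for _ in range(10):
    some_kernel[griddim, blockdim](d_ary)  # works on device memory; no implicit transfers
result = d_ary.copy_to_host()              # one device->host copy at the end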

The following are special DeviceNDArray factories:

numba.cuda.device_array(shape, dtype=np.float, strides=None, order='C', stream=0)

Allocate an empty device ndarray. Similar to numpy.empty()
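
For instance, an output buffer can be allocated directly on the device and copied back only after a kernel has filled it; a sketch with an illustrative shape and dtype:

d_out = cuda.device_array(shape=(10,), dtype=numpy.float32)
# ... launch a kernel that writes its results into d_out ...
out = d_out.copy_to_host()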

numba.cuda.pinned_array(shape, dtype=np.float, strides=None, order='C')

Allocate a numpy.ndarray with a buffer that is pinned (page-locked). Similar to numpy.empty().

numba.cuda.mapped_array(shape, dtype=np.float, strides=None, order='C', stream=0, portable=False, wc=False)

Allocate a mapped ndarray with a buffer that is pinned and mapped onto the device. Similar to numpy.empty().

Parameters:
  • portable – a boolean flag to allow the allocated device memory to be usable from multiple devices.
  • wc – a boolean flag to enable write-combined allocation, which is faster for the host to write and for the device to read, but slower for the host to read.
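
A brief sketch of these two factories with illustrative shapes and dtypes; the pinned (page-locked) buffer speeds up transfers, and the mapped array is also accessible from the device:

pinned = cuda.pinned_array(shape=(10,), dtype=numpy.float32)  # page-locked host buffer
mapped = cuda.mapped_array(shape=(10,), dtype=numpy.float32)  # pinned and mapped to the device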

Memory Lifetime

The lifetime of a device array is bound to the lifetime of the DeviceNDArray instance.
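
For example, a sketch (the exact point of deallocation depends on Python garbage collection):

d_ary = cuda.to_device(ary)  # device memory is held while d_ary is referenced
del d_ary                    # the device allocation can be released once the instance is collected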

CUDA Stream

A CUDA stream is a command queue for the CUDA device. By specifying a stream, the CUDA API calls become asynchronous, meaning that the call may return before the command has been completed. Memory transfer instructions and kernel invocations can take a CUDA stream:

stream = cuda.stream()
devary = cuda.to_device(an_array, stream=stream)
a_cuda_kernel[griddim, blockdim, stream](devary)
devary.copy_to_host(an_array, stream=stream)
# data may not be available in an_array
stream.synchronize()
# data available in an_array

numba.cuda.stream()

Create a CUDA stream that represents a command queue for the device.

An alternative syntax is available for use as a Python context manager:

stream = cuda.stream()
with stream.auto_synchronize():
    devary = cuda.to_device(an_array, stream=stream)
    a_cuda_kernel[griddim, blockdim, stream](devary)
    devary.copy_to_host(an_array, stream=stream)
# data available in an_array

When the Python with block exits, the stream is automatically synchronized.

Shared Memory

For maximum performance, a CUDA kernel needs to use shared memory for manual caching of data. The CUDA JIT supports the use of cuda.shared.array(shape, dtype) for declaring a NumPy-array-like object in shared memory inside a kernel.

For example:

from numba import cuda, jit, float32

bpg = 50
tpb = 32
n = bpg * tpb

@jit(argtypes=[float32[:,:], float32[:,:], float32[:,:]], target='gpu')
def cu_square_matrix_mul(A, B, C):
    sA = cuda.shared.array(shape=(tpb, tpb), dtype=float32)
    sB = cuda.shared.array(shape=(tpb, tpb), dtype=float32)

    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    bx = cuda.blockIdx.x
    by = cuda.blockIdx.y
    bw = cuda.blockDim.x
    bh = cuda.blockDim.y

    x = tx + bx * bw
    y = ty + by * bh

    acc = 0.
    for i in range(bpg):
        if x < n and y < n:
            sA[ty, tx] = A[y, tx + i * tpb]
            sB[ty, tx] = B[ty + i * tpb, x]

        cuda.syncthreads()

        if x < n and y < n:
            for j in range(tpb):
                acc += sA[ty, j] * sB[j, tx]

        cuda.syncthreads()

    if x < n and y < n:
        C[y, x] = acc

The equivalent code in CUDA-C would be:

#define pos2d(Y, X, W) ((Y) * (W) + (X))

const unsigned int BPG = 50;
const unsigned int TPB = 32;
const unsigned int N = BPG * TPB;

__global__
void cuMatrixMul(const float A[], const float B[], float C[]){
    __shared__ float sA[TPB * TPB];
    __shared__ float sB[TPB * TPB];

    unsigned int tx = threadIdx.x;
    unsigned int ty = threadIdx.y;
    unsigned int bx = blockIdx.x;
    unsigned int by = blockIdx.y;
    unsigned int bw = blockDim.x;
    unsigned int bh = blockDim.y;

    unsigned int x = tx + bx * bw;
    unsigned int y = ty + by * bh;

    float acc = 0.0;

    for (int i = 0; i < BPG; ++i) {
        if (x < N && y < N) {
            sA[pos2d(ty, tx, TPB)] = A[pos2d(y, tx + i * TPB, N)];
            sB[pos2d(ty, tx, TPB)] = B[pos2d(ty + i * TPB, x, N)];
        }
        __syncthreads();
        if (x < N && y < N) {
            for (int j = 0; j < TPB; ++j) {
                acc += sA[pos2d(ty, j, TPB)] * sB[pos2d(j, tx, TPB)];
            }
        }
        __syncthreads();
    }

    if (x < N && y < N) {
        C[pos2d(y, x, N)] = acc;
    }
}

The return value of cuda.shared.array is a NumPy-array-like object. The shape argument is similar to that in the NumPy API, with the requirement that it must be a constant expression. The dtype argument takes Numba types.
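
For instance, a shared buffer inside a kernel might be declared as follows (the size is illustrative); float32 here is the Numba type, not a NumPy dtype:

sbuf = cuda.shared.array(shape=(16, 16), dtype=float32)  # shape must be a constant expression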

Synchronization Primitives

We currently support cuda.syncthreads() only. It is the same as __syncthreads() in CUDA-C.
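
A minimal sketch of its use: the kernel below reverses the elements handled by one thread block through shared memory, and assumes it is launched with exactly one block of 64 threads (an illustrative configuration):

from numba import cuda, float32

@cuda.jit('void(float32[:])')
def reverse_block(arr):
    buf = cuda.shared.array(shape=64, dtype=float32)
    tx = cuda.threadIdx.x
    buf[tx] = arr[tx]
    cuda.syncthreads()  # wait until every thread in the block has written its element
    arr[tx] = buf[63 - tx]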