Demo 2: CUDA Backend for Tanh Approximation in GELU Activation Layer
(Depends on Ch.08)
This demo notebook shows how to use a CUDA backend to accelerate the GELU activation function, using a Padé-type rational approximation (Pade44) for tanh. We build on the previous demo (demo01) and show how to offload the computation to the GPU using Numba and a custom backend; a reference sketch of the formulas involved follows the list below.
The notebook demonstrates:
- How to configure and use a GPU backend for vectorized ufuncs
- How to run and test the optimized GELU function on CUDA
- How to compare results with the original NumPy implementation
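For orientation, here is a minimal NumPy sketch of the two formulas involved, assuming demo01's gelu_tanh_forward follows the standard tanh-based GELU (the constants 0.5, 2/π, and 0.044715 are visible in the extracted program later in this notebook) and that Pade44 denotes the rational tanh approximation whose coefficients (10, 105, 45) appear in the lowered MLIR:

import numpy as np

def gelu_tanh_reference(x):
    # Tanh-based GELU approximation; matches the constants in the
    # extracted program shown below.
    t = np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + np.tanh(t))

def pade44_tanh(t):
    # Rational tanh approximation; the coefficients 10, 105, and 45
    # match those visible in the lowered MLIR further down.
    return (10 * t**3 + 105 * t) / (t**4 + 45 * t**2 + 105)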
from numba import cuda
from ch08_gpu_offload import GPUBackend
from ch08_gpu_offload import gpu_compiler_config as _ch08_gpu_compiler_config
from demo01_gelu_tanh_approx import *
from utils.report import Report
Setup GPU Backend
Define a backend that combines the ufunc backend and GPU backend, enabling compilation and execution of vectorized functions on CUDA devices.
class GpuUfuncBackend(Backend, GPUBackend):
    """Ufunc + GPU backend."""

    def __init__(self, compile_only: bool = False):
        # Initialize the GPU side directly; `compile_only=True` keeps
        # the backend usable on machines without a CUDA device.
        GPUBackend.__init__(self, compile_only)
gpu_compiler_config = {
    **_ch08_gpu_compiler_config,
    "converter_class": ExtendEGraphToRVSDG,
    "cost_model": MyCostModel(),
    # Fall back to compile-only mode when no CUDA device is present.
    "backend": GpuUfuncBackend(compile_only=not cuda.is_available()),
}
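When no CUDA device is present, the backend is constructed with compile_only=True, so the compilation pipeline below can still run end to end; the test and benchmark cells further down additionally guard actual execution with cuda.is_available(). A quick way to check which mode you are in:

if __name__ == "__main__":
    print("CUDA available:", cuda.is_available())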
Configure the CUDA Ufunc Pipeline
Set up the pipeline to compile the GELU function as a CUDA-accelerated vectorized ufunc, using the GPU backend and the Pade44 tanh approximation.
report = Report("Pipeline execution report", enable_nested_metadata=True)
cuda_vectorized_gelu = ufunc_vectorize(
    input_type=Float32,
    ndim=1,
    compiler_config={
        **gpu_compiler_config,
        "pipeline_report": report,
        "pipeline_debug": True,
    },
    # Reuse the rulesets imported from demo01.
    extra_ruleset=additional_rules | optimize_rules,
)(gelu_tanh_forward)
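The decorated result, cuda_vectorized_gelu, is called like a ufunc: the test below passes a 1-D float32 array, and the benchmark additionally supplies a preallocated output array via out=.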
if __name__ == "__main__":
    report.display()
Pipeline execution report
_Region_1 = Region("804", InPorts(Vec[String]("!io", "a")))
GraphRoot(
    Term.Func(
        "1274",
        "transformed_gelu_tanh_forward",
        Term.RegionEnd(
            _Region_1,
            PortList(
                Vec[Port](
                    Port("!io", _Region_1.get(0)),
                    Port(
                        "!ret",
                        Nb_Mul_Float32(
                            Nb_Mul_Float32(
                                Npy_cast_f64_to_f32(Term.LiteralF64(0.5)),
                                _Region_1.get(1),
                            ),
                            Nb_Add_Float32(
                                Npy_cast_i64_to_f32(Term.LiteralI64(1)),
                                Npy_tanh_float32(
                                    Nb_Mul_Float32(
                                        Npy_sqrt_float32(
                                            Nb_Div_Float32(
                                                Npy_cast_i64_to_f32(Term.LiteralI64(2)),
                                                Npy_cast_f64_to_f32(Term.LiteralF64(3.141592653589793)),
                                            )
                                        ),
                                        Nb_Add_Float32(
                                            _Region_1.get(1),
                                            Nb_Mul_Float32(
                                                Npy_cast_f64_to_f32(Term.LiteralF64(0.044715)),
                                                Nb_Pow_Float32_Int64(_Region_1.get(1), Term.LiteralI64(3)),
                                            ),
                                        ),
                                    )
                                ),
                            ),
                        ),
                    ),
                )
            ),
        ),
    )
)
time elapsed 269.30ms
timing breakdown:
  46.46ms: [debug] initial egraph
  194.40ms: [debug] saturated egraph
  28.44ms: [debug] egglog.extract
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant 1.000000e+01 : f64
    %c2_i64 = arith.constant 2 : i64
    %cst_1 = arith.constant 3.1415926535897931 : f64
    %cst_2 = arith.constant 4.471500e-02 : f64
    %cst_3 = arith.constant 1.000000e+00 : f64
    %cst_4 = arith.constant 1.050000e+02 : f64
    %cst_5 = arith.constant 4.500000e+01 : f64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.truncf %cst_0 : f64 to f32
    %4 = arith.sitofp %c2_i64 : i64 to f32
    %5 = arith.truncf %cst_1 : f64 to f32
    %6 = arith.divf %4, %5 : f32
    %7 = math.sqrt %6 : f32
    %8 = arith.truncf %cst_2 : f64 to f32
    %9 = arith.truncf %cst_3 : f64 to f32
    %10 = arith.mulf %arg0, %9 : f32
    %11 = arith.mulf %arg0, %10 : f32
    %12 = arith.mulf %arg0, %11 : f32
    %13 = arith.mulf %8, %12 : f32
    %14 = arith.addf %arg0, %13 : f32
    %15 = arith.mulf %7, %14 : f32
    %16 = arith.mulf %15, %9 : f32
    %17 = arith.mulf %15, %16 : f32
    %18 = arith.mulf %15, %17 : f32
    %19 = arith.mulf %3, %18 : f32
    %20 = arith.truncf %cst_4 : f64 to f32
    %21 = arith.mulf %20, %15 : f32
    %22 = arith.addf %19, %21 : f32
    %23 = arith.mulf %15, %18 : f32
    %24 = arith.truncf %cst_5 : f64 to f32
    %25 = arith.mulf %24, %17 : f32
    %26 = arith.addf %23, %25 : f32
    %27 = arith.addf %26, %20 : f32
    %28 = arith.divf %22, %27 : f32
    %29 = arith.addf %2, %28 : f32
    %30 = arith.mulf %1, %29 : f32
    return %30 : f32
  }
}
time elapsed 3.23ms
timing breakdown:
  3.23ms: Lowered module
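Between the extracted term (which still contains Npy_tanh_float32) and the lowered module above, the tanh call has been expanded into a rational function. Reading off the constants 10, 105, and 45:

$$\tanh(t) \approx \frac{10t^3 + 105t}{t^4 + 45t^2 + 105},$$

which is exactly what truncating the classical continued fraction tanh(t) = t/(1 + t²/(3 + t²/(5 + t²/7))) yields, consistent with the Pade44 approximation from demo01.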
#map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>
#map1 = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
module attributes {gpu.container_module} {
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    %8 = llvm.fmul %arg0, %5 : f32
    %9 = llvm.fptrunc %6 : f64 to f32
    %10 = llvm.fdiv %2, %9 : f32
    %11 = llvm.call @sqrtf(%10) : (f32) -> f32
    %12 = llvm.fptrunc %7 : f64 to f32
    %13 = llvm.fmul %arg0, %arg0 : f32
    %14 = llvm.fmul %arg0, %13 : f32
    %15 = llvm.fmul %12, %14 : f32
    %16 = llvm.fadd %arg0, %15 : f32
    %17 = llvm.fmul %11, %16 : f32
    %18 = llvm.fmul %17, %17 : f32
    %19 = llvm.fmul %17, %18 : f32
    %20 = llvm.fmul %19, %3 : f32
    %21 = llvm.fmul %17, %1 : f32
    %22 = llvm.fadd %20, %21 : f32
    %23 = llvm.fmul %17, %19 : f32
    %24 = llvm.fmul %18, %0 : f32
    %25 = llvm.fadd %23, %24 : f32
    %26 = llvm.fadd %25, %1 : f32
    %27 = llvm.fdiv %22, %26 : f32
    %28 = llvm.fadd %27, %4 : f32
    %29 = llvm.fmul %8, %28 : f32
    llvm.return %29 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
  llvm.func @ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.insertvalue %arg5, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.insertvalue %arg6, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.insertvalue %arg7, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.insertvalue %arg8, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.insertvalue %arg9, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %6 = builtin.unrealized_conversion_cast %5 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xf32>
    %7 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.insertvalue %arg0, %7[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.insertvalue %arg1, %8[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.insertvalue %arg2, %9[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.insertvalue %arg3, %10[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %12 = llvm.insertvalue %arg4, %11[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %13 = builtin.unrealized_conversion_cast %12 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xf32>
    %14 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    %15 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %16 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %17 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %18 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %19 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %20 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %21 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %22 = llvm.mlir.constant(0 : index) : i64
    %23 = builtin.unrealized_conversion_cast %arg3 : i64 to index
    %24 = llvm.mlir.constant(0 : index) : i64
    %25 = builtin.unrealized_conversion_cast %24 : i64 to index
    %26 = llvm.mlir.constant(1 : index) : i64
    %27 = builtin.unrealized_conversion_cast %26 : i64 to index
    %28 = llvm.mlir.constant(1 : index) : i64
    %29 = builtin.unrealized_conversion_cast %28 : i64 to index
    %30 = affine.apply #map(%23)[%25, %27]
    gpu.launch_func @ufunc_kernel::@ufunc_kernel blocks in (%30, %29, %29) threads in (%29, %29, %29) args(%27 : index, %25 : index, %13 : memref<?xf32>, %16 : f32, %15 : f64, %19 : f32, %14 : f64, %18 : f32, %20 : f32, %21 : f32, %17 : f32, %6 : memref<?xf32>)
    llvm.return
  }
  llvm.func @_mlir_ciface_ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    llvm.call @ufunc(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
    llvm.return
  }
  gpu.module @ufunc_kernel {
    llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
    gpu.func @ufunc_kernel(%arg0: index, %arg1: index, %arg2: memref<?xf32>, %arg3: f32, %arg4: f64, %arg5: f32, %arg6: f64, %arg7: f32, %arg8: f32, %arg9: f32, %arg10: f32, %arg11: memref<?xf32>) kernel attributes {known_block_size = array<i32: 1, 1, 1>} {
      %0 = builtin.unrealized_conversion_cast %arg11 : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
      %1 = builtin.unrealized_conversion_cast %arg2 : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
      %block_id_x = gpu.block_id x
      %block_id_y = gpu.block_id y
      %block_id_z = gpu.block_id z
      %thread_id_x = gpu.thread_id x
      %thread_id_y = gpu.thread_id y
      %thread_id_z = gpu.thread_id z
      %grid_dim_x = gpu.grid_dim x
      %grid_dim_y = gpu.grid_dim y
      %grid_dim_z = gpu.grid_dim z
      %block_dim_x = gpu.block_dim x
      %block_dim_y = gpu.block_dim y
      %block_dim_z = gpu.block_dim z
      %2 = affine.apply #map1(%block_id_x)[%arg0, %arg1]
      %3 = builtin.unrealized_conversion_cast %2 : index to i64
      %4 = llvm.extractvalue %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
      %5 = llvm.getelementptr %4[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
      %6 = llvm.load %5 : !llvm.ptr -> f32
      %7 = llvm.fmul %6, %arg3 : f32
      %8 = llvm.fptrunc %arg4 : f64 to f32
      %9 = llvm.fdiv %arg5, %8 : f32
      %10 = llvm.call @sqrtf(%9) : (f32) -> f32
      %11 = llvm.fptrunc %arg6 : f64 to f32
      %12 = llvm.fmul %6, %6 : f32
      %13 = llvm.fmul %6, %12 : f32
      %14 = llvm.fmul %11, %13 : f32
      %15 = llvm.fadd %6, %14 : f32
      %16 = llvm.fmul %10, %15 : f32
      %17 = llvm.fmul %16, %16 : f32
      %18 = llvm.fmul %16, %17 : f32
      %19 = llvm.fmul %18, %arg7 : f32
      %20 = llvm.fmul %16, %arg8 : f32
      %21 = llvm.fadd %19, %20 : f32
      %22 = llvm.fmul %16, %18 : f32
      %23 = llvm.fmul %17, %arg9 : f32
      %24 = llvm.fadd %22, %23 : f32
      %25 = llvm.fadd %24, %arg8 : f32
      %26 = llvm.fdiv %21, %25 : f32
      %27 = llvm.fadd %26, %arg10 : f32
      %28 = llvm.fmul %7, %27 : f32
      %29 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
      %30 = llvm.getelementptr %29[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32
      llvm.store %28, %30 : f32, !llvm.ptr
      gpu.return
    }
  }
}
time elapsed 6.50ms
timing breakdown:
  6.50ms: MLIR optimized
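Note the launch configuration in the GPU module above: known_block_size is 1×1×1 and the grid size comes from #map (a ceildiv over the array length), so the kernel launches one single-thread block per element; each kernel instance loads one f32, evaluates the fused rational GELU, and stores one result.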
Test GELU Ufunc on CUDA
Run the compiled CUDA ufunc on a random input and compare the result to the original NumPy implementation. If CUDA is unavailable, skip the test.
if __name__ == "__main__":
    if not cuda.is_available():
        print("SKIPPED. CUDA unavailable")
    else:
        # Relative tolerance: the Pade44 tanh is close to, but not
        # bit-identical with, np.tanh in the reference implementation.
        relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
        input_val = np.random.random(100).astype(np.float32)
        report.display()
        run_test(
            gelu_tanh_forward,
            cuda_vectorized_gelu,
            (input_val,),
            equal=relclose,
            verbose=True,
        )
SKIPPED. CUDA unavailable
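The relaxed tolerance above (rtol=1e-6) reflects that the Pade44 tanh is not bit-identical to np.tanh. A host-only sketch (no GPU required) of the approximation error over the argument range this demo actually produces; for x in [0, 1), the tanh argument stays below about 0.84:

import numpy as np

x = np.linspace(0.0, 1.0, 1001)
t = np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)  # tanh argument range
pade = (10 * t**3 + 105 * t) / (t**4 + 45 * t**2 + 105)
# The error grows with |t|: about 5e-6 at t = 1, much smaller below.
print("max |pade44 - tanh|:", np.max(np.abs(pade - np.tanh(t))))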
Benchmark
if __name__ == "__main__":
    if not cuda.is_available():
        print("SKIPPED. CUDA unavailable")
    else:
        input_val = np.random.random(300000).astype(np.float32)
        out = np.zeros_like(input_val)
        print("original")
        %timeit gelu_tanh_forward(input_val)
        print("superoptimized")
        %timeit cuda_vectorized_gelu(input_val, out=out)
SKIPPED. CUDA unavailable