Demo 2: CUDA Backend for Tanh Approximation in GELU Activation Layer¶
(Depends on Ch.08)
This demo notebook shows how to use a CUDA backend to accelerate the GELU activation function, using a Pade44 rational approximation in place of tanh. Building on the previous demo, we offload the computation to the GPU with Numba and a custom backend.
The notebook demonstrates:
- How to configure and use a GPU backend for vectorized ufuncs
- How to run and test the optimized GELU function on CUDA
- How to compare results with the original NumPy implementation
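For reference, the function being compiled, gelu_tanh_forward from demo01, computes the tanh-based GELU. The sketch below is an illustrative NumPy restatement (not the demo01 source); its constants match the RVSDG and MLIR dumps further down.
import numpy as np

def gelu_tanh_reference(x):
    # gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))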
from numba import cuda
from ch08_gpu_offload import GPUBackend
from ch08_gpu_offload import gpu_compiler_config as _ch08_gpu_compiler_config
from demo01_gelu_tanh_approx import *
from utils.report import Report
Set Up the GPU Backend¶
Define a backend that combines the ufunc backend and GPU backend, enabling compilation and execution of vectorized functions on CUDA devices.
class GpuUfuncBackend(Backend, GPUBackend):
# Ufunc + GPU backend
def __init__(self, compile_only: bool = False):
GPUBackend.__init__(self, compile_only)
gpu_compiler_config = {
**_ch08_gpu_compiler_config,
"converter_class": ExtendEGraphToRVSDG,
"cost_model": MyCostModel(),
"backend": GpuUfuncBackend(compile_only=not cuda.is_available()),
}
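Because the config passes compile_only=not cuda.is_available(), the backend falls back to compile-only mode on machines without a CUDA device, so the rest of the notebook still exercises the full compilation pipeline. A purely informational check (illustrative, not required by the pipeline):
if __name__ == "__main__":
    # Report whether kernels will actually run on the GPU or whether the
    # backend above was constructed in compile-only mode.
    print("CUDA available:", cuda.is_available())
    print("compile-only backend:", not cuda.is_available())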
Configure the CUDA Ufunc Pipeline¶
Set up the pipeline to compile the GELU function as a CUDA-accelerated vectorized ufunc, using the GPU backend and the Pade44 tanh approximation.
report = Report("Pipeline execution report", enable_nested_metadata=True)
cuda_vectorized_gelu = ufunc_vectorize(
input_type=Float32,
ndim=1,
compiler_config={**gpu_compiler_config, "pipeline_report": report, "pipeline_debug": True},
extra_ruleset=additional_rules | optimize_rules,
)(gelu_tanh_forward)
if __name__ == "__main__":
report.display()
Pipeline execution report
_Region_1 = Region("804", InPorts(Vec[String]("!io", "a"))) GraphRoot( Term.Func( "1274", "transformed_gelu_tanh_forward", Term.RegionEnd( _Region_1, PortList( Vec[Port]( Port("!io", _Region_1.get(0)), Port( "!ret", Nb_Mul_Float32( Nb_Mul_Float32(Npy_cast_f64_to_f32(Term.LiteralF64(0.5)), _Region_1.get(1)), Nb_Add_Float32( Npy_cast_i64_to_f32(Term.LiteralI64(1)), Npy_tanh_float32( Nb_Mul_Float32( Npy_sqrt_float32(Nb_Div_Float32(Npy_cast_i64_to_f32(Term.LiteralI64(2)), Npy_cast_f64_to_f32(Term.LiteralF64(3.141592653589793)))), Nb_Add_Float32( _Region_1.get(1), Nb_Mul_Float32(Npy_cast_f64_to_f32(Term.LiteralF64(0.044715)), Nb_Pow_Float32_Int64(_Region_1.get(1), Term.LiteralI64(3))), ), ) ), ), ), ), ) ), ), ) )
time elapsed 365.70ms
timing breakdown:
  66.69ms: [debug] initial egraph
  258.65ms: [debug] saturated egraph
  40.36ms: [debug] egglog.extract
module { func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} { %cst = arith.constant 5.000000e-01 : f64 %c1_i64 = arith.constant 1 : i64 %cst_0 = arith.constant 1.000000e+01 : f64 %c2_i64 = arith.constant 2 : i64 %cst_1 = arith.constant 3.1415926535897931 : f64 %cst_2 = arith.constant 4.471500e-02 : f64 %cst_3 = arith.constant 1.000000e+00 : f64 %cst_4 = arith.constant 1.050000e+02 : f64 %cst_5 = arith.constant 4.500000e+01 : f64 cf.br ^bb1 ^bb1: // pred: ^bb0 %c0_i32 = arith.constant 0 : i32 %0 = arith.truncf %cst : f64 to f32 %1 = arith.mulf %0, %arg0 : f32 %2 = arith.sitofp %c1_i64 : i64 to f32 %3 = arith.truncf %cst_0 : f64 to f32 %4 = arith.sitofp %c2_i64 : i64 to f32 %5 = arith.truncf %cst_1 : f64 to f32 %6 = arith.divf %4, %5 : f32 %7 = math.sqrt %6 : f32 %8 = arith.truncf %cst_2 : f64 to f32 %9 = arith.truncf %cst_3 : f64 to f32 %10 = arith.mulf %arg0, %9 : f32 %11 = arith.mulf %arg0, %10 : f32 %12 = arith.mulf %arg0, %11 : f32 %13 = arith.mulf %8, %12 : f32 %14 = arith.addf %arg0, %13 : f32 %15 = arith.mulf %7, %14 : f32 %16 = arith.mulf %15, %9 : f32 %17 = arith.mulf %15, %16 : f32 %18 = arith.mulf %15, %17 : f32 %19 = arith.mulf %3, %18 : f32 %20 = arith.truncf %cst_4 : f64 to f32 %21 = arith.mulf %20, %15 : f32 %22 = arith.addf %19, %21 : f32 %23 = arith.mulf %15, %18 : f32 %24 = arith.truncf %cst_5 : f64 to f32 %25 = arith.mulf %24, %17 : f32 %26 = arith.addf %23, %25 : f32 %27 = arith.addf %26, %20 : f32 %28 = arith.divf %22, %27 : f32 %29 = arith.addf %2, %28 : f32 %30 = arith.mulf %1, %29 : f32 return %30 : f32 } }
time elapsed 3.42ms
timing breakdown:
  3.42ms: Lowered module
module attributes {gpu.container_module} { llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} { %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32 %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32 %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32 %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32 %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32 %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64 %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64 %8 = llvm.fmul %arg0, %5 : f32 %9 = llvm.fptrunc %6 : f64 to f32 %10 = llvm.fdiv %2, %9 : f32 %11 = llvm.intr.sqrt(%10) : (f32) -> f32 %12 = llvm.fptrunc %7 : f64 to f32 %13 = llvm.fmul %arg0, %arg0 : f32 %14 = llvm.fmul %arg0, %13 : f32 %15 = llvm.fmul %12, %14 : f32 %16 = llvm.fadd %arg0, %15 : f32 %17 = llvm.fmul %11, %16 : f32 %18 = llvm.fmul %17, %17 : f32 %19 = llvm.fmul %17, %18 : f32 %20 = llvm.fmul %19, %3 : f32 %21 = llvm.fmul %17, %1 : f32 %22 = llvm.fadd %20, %21 : f32 %23 = llvm.fmul %17, %19 : f32 %24 = llvm.fmul %18, %0 : f32 %25 = llvm.fadd %23, %24 : f32 %26 = llvm.fadd %25, %1 : f32 %27 = llvm.fdiv %22, %26 : f32 %28 = llvm.fadd %27, %4 : f32 %29 = llvm.fmul %8, %28 : f32 llvm.return %29 : f32 } llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} { %0 = llvm.call @func(%arg0) : (f32) -> f32 llvm.return %0 : f32 } llvm.func @ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} { %0 = llvm.mlir.constant(-1 : index) : i64 %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.mlir.constant(4.500000e+01 : f32) : f32 %3 = llvm.mlir.constant(1.050000e+02 : f32) : f32 %4 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %5 = llvm.mlir.constant(1.000000e+01 : f32) : f32 %6 = llvm.mlir.constant(1.000000e+00 : f32) : f32 %7 = llvm.mlir.constant(5.000000e-01 : f32) : f32 %8 = llvm.mlir.constant(3.1415926535897931 : f64) : f64 %9 = llvm.mlir.constant(4.471500e-02 : f64) : f64 %10 = llvm.mlir.constant(1 : index) : i64 %11 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %12 = llvm.insertvalue %arg0, %11[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %13 = llvm.insertvalue %arg1, %12[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %14 = llvm.insertvalue %arg2, %13[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %15 = llvm.insertvalue %arg3, %14[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %16 = llvm.extractvalue %15[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %17 = llvm.alloca %10 x !llvm.array<1 x i64> : (i64) -> !llvm.ptr llvm.store %16, %17 : !llvm.array<1 x i64>, !llvm.ptr %18 = llvm.getelementptr %17[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<1 x i64> %19 = llvm.load %18 : !llvm.ptr -> i64 %20 = llvm.mul %1, %0 : i64 %21 = llvm.add %19, %20 : i64 %22 = llvm.icmp "sle" %21, %1 : i64 %23 = llvm.sub %1, %21 : i64 %24 = llvm.sub %21, %10 : i64 %25 = llvm.select %22, %23, %24 : i1, i64 %26 = llvm.sdiv %25, %10 : i64 %27 = llvm.sub %1, %26 : i64 %28 = llvm.add %26, %10 : i64 %29 = llvm.select %22, %27, %28 : i1, i64 gpu.launch_func @ufunc_kernel::@ufunc_kernel blocks in (%29, %10, %10) threads in (%10, %10, %10) : i64 args(%10 : i64, %1 : i64, %arg0 : !llvm.ptr, %arg1 : !llvm.ptr, %arg2 : i64, %arg3 : i64, %arg4 : i64, %7 : f32, %8 : f64, %4 : f32, %9 : f64, %5 : f32, %3 : f32, %2 : 
f32, %6 : f32, %arg5 : !llvm.ptr, %arg6 : !llvm.ptr, %arg7 : i64, %arg8 : i64, %arg9 : i64) llvm.return } llvm.func @_mlir_ciface_ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} { %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.call @ufunc(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () llvm.return } gpu.binary @ufunc_kernel [#gpu.object<#nvvm.target, "P\EDU\BA\01\00\10\00\F0\15\00\00\00\00\00\00\02\00\01\01H\00\00\00\08\12\00\00\00\00\00\00\00\00\00\00\00\00\00\00\07\00\01\002\00\00\00\00\00\00\00\00\00\00\00\11\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\7FELF\02\01\013\07\00\00\00\00\00\00\00\02\00\BE\00|\00\00\00\00\00\00\00\00\00\00\00`\11\00\00\00\00\00\00\A0\0E\00\00\00\00\00\002\052\00@\008\00\03\00@\00\0B\00\01\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.info\00.text.ufunc_kernel\00.nv.info.ufunc_kernel\00.nv.shared.ufunc_kernel\00.nv.constant2.ufunc_kernel\00.nv.constant0.ufunc_kernel\00.rel.nv.constant0.ufunc_kernel\00.nv.callgraph\00.nv.prototype\00.nv.rel.action\00\00.shstrtab\00.strtab\00.symtab\00.symtab_shndx\00.nv.info\00.text.ufunc_kernel\00.nv.info.ufunc_kernel\00.nv.shared.ufunc_kernel\00.nv.constant2.ufunc_kernel\00$__internal_0_$__cuda_sm3x_div_rn_noftz_f32_slowpath\00.rel.nv.constant0.ufunc_kernel\00.nv.constant0.ufunc_kernel\00.nv.callgraph\00.nv.prototype\00.nv.rel.action\00ufunc_kernel\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\002\00\00\00\03\00\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00s\00\00\00\03\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\8E\00\00\00\22\00\0A\00P\03\00\00\00\00\00\000\04\00\00\00\00\00\00\E2\00\00\00\03\00\09\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\FD\00\00\00\03\00\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\19\01\00\00\03\00\07\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00(\01\00\00\12\10\0A\00\00\00\00\00\00\00\00\00\80\07\00\00\00\00\00\00\04/\08\00\07\00\00\00\10\00\00\00\04\12\08\00\03\00\00\00\00\00\00\00\04\11\08\00\03\00\00\00\00\00\00\00\04\12\08\00\07\00\00\00\00\00\00\00\04\11\08\00\07\00\00\00\00\00\00\00\04\12\08\00\07\00\00\00\00\00\00\00\047\04\00|\00\00\00\010\00\00\01*\00\00\04\0A\08\00\04\00\00\00@\01\90\00\03\19\90\00\04\17\0C\00\00\00\00\00\13\00\88\00\00\F0!\00\04\17\0C\00\00\00\00\00\12\00\80\00\00\F0!
\00\04\17\0C\00\00\00\00\00\11\00x\00\00\F0!\00\04\17\0C\00\00\00\00\00\10\00p\00\00\F0!\00\04\17\0C\00\00\00\00\00\0F\00h\00\00\F0!\00\04\17\0C\00\00\00\00\00\0E\00d\00\00\F0\11\00\04\17\0C\00\00\00\00\00\0D\00`\00\00\F0\11\00\04\17\0C\00\00\00\00\00\0C\00\\\00\00\F0\11\00\04\17\0C\00\00\00\00\00\0B\00X\00\00\F0\11\00\04\17\0C\00\00\00\00\00\0A\00P\00\00\F0!\00\04\17\0C\00\00\00\00\00\09\00H\00\00\F0\11\00\04\17\0C\00\00\00\00\00\08\00@\00\00\F0!\00\04\17\0C\00\00\00\00\00\07\008\00\00\F0\11\00\04\17\0C\00\00\00\00\00\06\000\00\00\F0!\00\04\17\0C\00\00\00\00\00\05\00(\00\00\F0!\00\04\17\0C\00\00\00\00\00\04\00 \00\00\F0!\00\04\17\0C\00\00\00\00\00\03\00\18\00\00\F0!\00\04\17\0C\00\00\00\00\00\02\00\10\00\00\F0!\00\04\17\0C\00\00\00\00\00\01\00\08\00\00\F0!\00\04\17\0C\00\00\00\00\00\00\00\00\00\00\F0!\00\03\1B\FF\00\04\1D\04\00\18\00\00\00\04\1C\04\00H\03\00\00\04\05\0C\00\01\00\00\00\01\00\00\00\01\00\00\00\04\1E\04\00\10\00\00\00\00\00\00\00\FF\FF\FF\FF\00\00\00\00\FE\FF\FF\FF\00\00\00\00\FD\FF\FF\FF\00\00\00\00\FC\FF\FF\FF\00\00\00\00s\00\00\00\00\00\00\00\00\00\00\11%\00\056\00\00\80?\FF\FF\FF\7F\00\00\00\80\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F6\07\00\FE\00|\1C\00\01\00\87\00\80\07\98L\0B\00'\06\80\07\98L\00\00W\02\00\00\C8\F0\F1\0F\22\FE@\C4\1F\08\05\00\07\05\80\7F\00N\02\00\07\05\80\7F\00N\03\00\07\05\80\7F\10O\F1\07\22\FE@\84\1F\00\06\00\07\05\80\7F\10N\07\00\07\05\80\7F0N\09\00\17\05\80\7F\00N\F1\07\22\FE@\D0\1F\08\0A\00\17\05\80\7F\10O\05\00\07\05\80\02(N\02\007\00\18\010[\F1\07 \FE\00\94\1F\00\04\00\A7\00\98\040[\03\05g\00\A0\03\C0\\\00\02'\05\00\80\10L\F1\07\A0\FE@\C4\1F\00\03\047\00\00\00\10\\\02\00'\00\00\00H8\03\037\05\00\08\10L\F5\07\C0\FE\00\88\1F\00\04\02g\05\00\80\10L\00\00'\00\C0\01\F86\05\00w\05\00\08\10L\B1\00 \FE\00\D0\1F\00\03\04\07\00\00 \D4\EE\06\00\07\06\80\07\98L\07\00\17\06\80\07\98L?\07 
\E6\02\C4\1F\00\06\0Eg\00\00\00\A8\\\08\06G\00\00\00\80P\00\0Bg\00\00\00\88\\\E6\17\C0\FE\00\D8\1F\08\09\06\07\00\08\04\81Q\09\08\97\00\00\04\80Y\08\09'\06\80\7F\80I\F6\07\00\FE\00\F4\1F\00\0A\06'\06\00\04\81Q\08\09\A7\00\00\04\80Y\0F\00\08\03\00\00@\E2\F1\0F\C0\FE\00\F4\1F\04\05\00g\00\80\07\98\\\04\00'\06\80\07\98L@\00\00\1D\00\00`\E2\F6\0F \FE\01\D0\1F\00\08\00w\00\80\07\98\\\04\00G\06\80\07\98L\05\00W\06\80\07\98L\F0\07 \E2\00\C4\1F\0C\09\00\87\00\80\07\98\\\04\0EG\00\00\00\A8\\\06\037\00\00\00h\\\F4\07 \FC\00\B0\1F\00\0B\03\E7\05\00\00hL\87\09\07\80\80\03\BE6\07\03g\00\00\00h\\\F2\07\A0\E7\00\F4\\\00\09\09\08\80K\00h8\09\09W\00\00\00\80P\06\09G\00\00\00\80P\F1\0F\A0\FE\02\D8\1F\00\08\04w\00\00\00h\\\06\06\08\809\00h8\08\03\87\00\00\00X\\\F6\07\C0\FC\00\C4\1F\08\08\06\87\00\00\00h\\\05\08\87\00\00\00h\\\06\08W\00\00\00h\\\F5\07 \FE@\C4\1F\00\05\05\87\06\00\00hL\04\08g\00\00\00h\\\08\08w\06\00\00hL\F4\07@\FC\00\D0\1F\00\06\06g\06\00\00hL\04\04W\00\00\00X\\\07\06\87\00\00\00X\\\F2\07\80\E2\00\C4\1F\00\0A\04w\06\00\00XL\04\0AG\00\00\00\80P\00\07\A7\00\00\00\88\\\F6\0F\A0\FF\00\98\1F\00\05\0A\07\00\08\02\81Q\05\04W\00\00\02\80Y\04\07W\00\80\7F\80Y\F6\07\00\FE\00\F4\1F\00\06\0AG\00\80\03\81Y\04\05g\00\00\02\80Y\0F\00\08\03\00\00@\E2\F1\07\C0\FE\00\F4\1F\00\04\00w\00\80\07\98\\\05\00\A7\00\80\07\98\\@\00\00\05\00\00`\E2\F6\0F \FE\00\D4\1F\00\04\00w\00\80\07\98\\\02\02\C7\06\00\80\10L\04\04\97\06\00\00XL\F1\07@\FE\00\D0\1F\00\03\00\D7\06\00\08\10L\04\04\B7\00\00\00h\\\04\02\07\00\00 \DC\EE\FF\07 \FE\00\D4\1F\00\0F\00\07\00\00\00\00\E3\0C\05w\81\00\00\008\0A\04w\81\00\00\008\F1\07\A0\FE\00\B4\1F\00\09\0C\F7\FF\FF\FF\0F\1C\08\0A\F7\FF\FF\FF\0F\1C\07\09\D7\0F\80\03h6\ED\07\00\FE\00\F4\1F\00\07\08\D7\0F\00 h6\06\00\F8\0F\80\07\98\\\0F\00\08\10\00\00@\E2\F1\07\C0\FE\00\B4\1F\00\06\04\07\80\FF\03\CC0\07\05\07\80\FF\03\CC0\FF\06w\00\002@\\\FD\07\C0\FE\00\B4\1F\00\0F\00\006\00\00@\E2\06\00\17\00\88\07\98L\FF\05g\80<\02\E0[\FD\07 \FE\00\B4\1F\00\0F\00\082\00\00@\E2\FF\04\07\80\FF\83\CD0\97\05\07\80\FF\83\BD6\F0\07\A0\FD\00\B4\1F\00\87\04\07\80\FF\83\BD6\02\00\8A/\00\00@\E2\FF\04\17\00\080AL\ED\07\A0\FF\00\B4\1F\00\0F\A0\07!\81\03\90P\0F\00\01,\00\00@\E2\FF\05\17\00\080AL\ED\07\A0\FF\00\C4\1F\00\07\80\07!\81\03\90P\0F\00\00(\00\00@\E2\07\08\F7\0F\80\03m[\EC\07 \FE\00\C4\1F\00\0F\09\F7\0F\80\03m[\06\00\F0\0F\80\07\98\\\06\F0\08\FC\FF\FF\0F\01\F1\07\80\FE\00\98\1F\00\04\04\08\80\DF\7F\802\05\05\09\80\DF\7F\802\06\06\09\04\00\00\00\1C\F6\07@\FE\00\C0\1F\00\07\0C\07\00\00\08\EC\16\07\07W\00\00\00\12\\\05\0A\17\F8\FF\FF\0F\1C\14\07@\FE\00\C4\1F\08\08\07G\00\00\00\80P\09\07\F7\0F\000Y\\\04\05G\00\80\0B\1A\\\E3\07`\FEA\CC\1F\00\05\05\F7\07\00\06\C28\0A\08\07\00\88\04\80Q\05\05g\00\00\00\10\\\F6\07\C0\FE\00\D8\1F\00\0D\08\A7\00\00\04\80Y\0A\04\D7\00\80\7F\80Y\08\09\A7\00\00\02\80Y\FD\07\C0\FC\00\D8\1F\00\0A\0D\87\00\00\05\80Y\0F\09\A7\00\00\02\80Y\07\0D\F7\00\00\05\80Y\F6\07\C0\FE\00\D8\1F\00\04\07w\81\00\00\008\09\04W\00\00\00\10\\\04\09\F7\FF\FF\FF\0F\1C\ED\07\A0\FF\00\B4\1F\00\07\04\E7\0F\80\03l6\0F\00\88\14\00\00@\E2\07\09\E7\0F\80\03i6\FD\07\A0\FD\00\F4\1F\00\0F\00\00\11\00\00@\E2\07\09\17\00\80\03m6\0F\00\00\00\00\00 \E3\ED\07\00\FE\00\F4\1F\00\07\09\87\FE\FF\03m7\07\07\07\00\00\00\08\04\0F\00\08\00\00\00 \E3\F1\07.\FE@\C4\1F\00\04\0D\F7\00\00\05\98Y\0F\09\F7\0F\80\03k[\08\09\07\02\00\00\00\1C\F3\07\C0\FE\00\84\1F\00\05\0D\F7\00\00\05\88Y\04\04\F7\FF\FF\07\00\04\06\04\07\00\00\08 \04\F5\07 
\FE\00\C4\1F\00\04\0D\F7\00\00\05\90Y\08\06\87\00\00\00H\\\07\04W\00\80\83\BD[\F4\07@\FE\00\98\1F\00\04\09\F7\0F\00\00\12\\\0F\08\F7\0F\80\00k[\04\04\F7\0F\80\04K[\F5\07 \FE\00\B0\1F\00\04\06G\00\00\00(\\\07\00\07!\80\03\90P\06\04\17\00\00\00(8\FD\07\C0\FE\00\D8\1F\00\05\FF\17\00\00\04\A08\05\05\17\00\00\03\F8<\05\05G\00\00\00G\\\F6\07\00\FE\00\FC\1F\00\05\06W\00\00\00\10\\\07\05w\00\00\02G\\\0F\00\07\00\00\00 \E3\F6\07\00\FE\00\FC\1F\00\07\07\07\00\00\00\08\04\07\07\07\00\00\F8'\04\0F\00\07\00\00\00 \E3\F0\07\E0\FF\00\D8\1F\00\07\05w\00\80\0B\18\\\0F\00\07\00\00\00 \E3\04\05'\00\08\02H\02\F0\07\E0\FF\00\C0\1F\00\07\04\07\00\00\F8'\04\0F\00\07\00\00\00 \E3\07\05'\00\08\02H\02\FF\07@\FE\00D\1C\00\0F\00\07\00\00\00 \E3\07\F0\07\00\00\FC\0F\01\07\07W\00\00\00\80P\FF\07\00\FE\00\FC\1F\00\0F\00\07\00\00\00 \E3\07\04W\00\00\10X\\\0F\00\07\00\00\00 \E3\FF\07\00\FC\00\80\1F\00\0F\00\07\FF\FF\0F@\E2\00\0F\07\00\00\00\B0P\00\0F\07\00\00\00\B0P\E0\07\00\FC\00\80\1F\00\00\0F\07\00\00\00\B0P\00\0F\07\00\00\00\B0P\00\0F\07\00\00\00\B0P\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\F3\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\0B\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\003\01\00\00\00\00\00\005\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\13\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00h\02\00\00\00\00\00\00\C0\00\00\00\00\00\00\00\02\00\00\00\07\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00)\00\00\00\00\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00(\03\00\00\00\00\00\00H\00\00\00\00\00\00\00\03\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00E\00\00\00\00\00\00p@\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00p\03\00\00\00\00\00\00\8C\01\00\00\00\00\00\00\03\00\00\00\0A\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C8\00\00\00\01\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\FC\04\00\00\00\00\00\00 \00\00\00\00\00\00\00\03\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\E4\00\00\00\0B\00\00p\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00 \05\00\00\00\00\00\00\10\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00s\00\00\00\01\00\00\00B\00\00\00\00\00\00\00\00\00\00\00\00\00\00\000\05\00\00\00\00\00\00\0C\00\00\00\00\00\00\00\00\00\00\00\0A\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\8E\00\00\00\01\00\00\00B\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00<\05\00\00\00\00\00\00\D0\01\00\00\00\00\00\00\00\00\00\00\0A\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\002\00\00\00\01\00\00\00\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00 \07\00\00\00\00\00\00\80\07\00\00\00\00\00\00\03\00\00\00\07\00\00\10 
\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\06\00\00\00\05\00\00\00`\11\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\05\00\00\000\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00p\09\00\00\00\00\00\00p\09\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\05\00\00\00`\11\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\A8\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\01\01X\00\00\00H\03\00\00\00\00\00\00B\03\00\00H\00\00\00\00\00\06\002\00\00\00\00\00\00\00\00\00\00\00\11 \00\00\00\00\00\00\00\00\00\00\00\00\00\00\08\09\00\00\00\00\00\00\00\00\00\00\00\00\00\00P\00\00\00\00\00\00\00\00\00\00\00\00\00\00\000\0A//\03\00\F3\1E\0A.version 6.0\0A.target sm_50\0A.address_size 64\0A1\00\F9\18isible .entry ufunc_kernel(\0A.param .u64\1A\00\11_\18\00?_0,\22\00\0D\1F1\22\00\0E\1F2\22\00\0E\1F3\22\00\0E\1F4\22\00\0E\1F5\22\00\0E\166\22\00?f32\22\00\01\177\22\00\0FD\00\03\1F8D\00\0E\1F9D\00\0E/10E\00\0E/11#\00\0F\1F2#\00\0F\1F3#\00\0F\0FY\01\0F\1F1Z\01\0F/16#\00\0F\1F7#\00\0F\1F8#\00\0F\F1\009\0A)\0A.maxntid 1,\03\00\F3\04\0A{\0A.reg .b32 %r<2>;\11\00\00\F5\00U%f<30\12\00\A6b64 %rd<13%\00\00\13\000fd<\12\002\0Ald\82\00\22.u)\00O1, [\88\00\00=0];+\00\1F2+\00\03\F4\0416];\0Acvta.to.global2\00!3,8\00\0EK\00\1F4K\00\04\0Fv\00\00\1F5+\00\03\1F3u\00\06\116u\00\815;\0Amov.u2\01\B11, %ctaid.x6\00\00+\00\03\1B\00 d7/\00\B21;\0Amul.lo.sE\00%8,\1C\00td1;\0Aadd\1A\00&9, \00d4;\0Ashlt\01310,!\00\1922\00611,\98\00'10\E3\00\02\BF\01\0F\83\01\04\127\0D\01\03\DD\00\02+\00\00\83\01\01R\00\08)\01\03\E0\01\0FG\00\04\228]\DC\00#rnD\00\223,I\008%f1E\00\01\1C\00\0F\98\01\04\129m\01\03D\00\03q\00!5,v\00\0E\89\00\0F(\02\05\00T\026div\8A\00\226,s\00=%f5\8A\00\1F7E\00\04\00\FA\00\B3sqrt.approx.\00\00\8F\01-f6E\00\1F9E\00\04\1F2\D0\00\01\01\BF\01=fd2F\00/11G\00\04\1C3[\01%12\\\01\0FG\00\00\1F3G\00\04\1D4G\00\164\A3\01\1C2\1C\00%5,\22\00\2210h\02\06\1D\00\1669\00\1C59\00$7,)\01,16\1C\00%8,\22\00,17\1D\00\179\1D\00\1B8\1D\00520,#\00\0C9\00&219\00\1B9\AA\00522,\22\00,209\00\173r\00\1C9\1D\00&4,\B2\00\1C1W\00%5,#\00,23\1D\00%6,#\00\1B9\8B\02627,\96\00\1C69\00%8,#\00,13\90\00%9,U\03\1A8\F8\03&2,\05\05\00\AC\01(st\CE\03\01\C9\03!2]9\00\C09;\0Aret;\0A\0A}\0A\00\00\00\00\00\00\00">] }
time elapsed 65.02ms
timing breakdown:
  65.02ms: MLIR optimized
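Note that neither lowered module above contains a tanh call: the rewrite rules have replaced Npy_tanh_float32 with the Pade44 rational form, whose constants 10, 45, and 105 are visible in the MLIR. A minimal host-side NumPy sketch of that form, for intuition only and not part of the pipeline:
if __name__ == "__main__":
    # tanh(t) ~= (10*t**3 + 105*t) / (t**4 + 45*t**2 + 105), read off the MLIR above.
    def tanh_pade44(t):
        t2 = t * t
        return (10.0 * t2 * t + 105.0 * t) / (t2 * t2 + 45.0 * t2 + 105.0)

    # For GELU inputs in [0, 1] the tanh argument stays within roughly [-1, 1].
    t = np.linspace(-1.0, 1.0, 9, dtype=np.float32)
    print("max |tanh_pade44 - np.tanh| on [-1, 1]:", np.max(np.abs(tanh_pade44(t) - np.tanh(t))))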
Test GELU Ufunc on CUDA¶
Run the compiled CUDA ufunc on a random input and compare the result to the original NumPy implementation. If CUDA is unavailable, skip the test.
if __name__ == "__main__":
if not cuda.is_available():
print("SKIPPED. CUDA unavailable")
else:
relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
input_val = np.random.random(100).astype(np.float32)
run_test(
gelu_tanh_forward,
cuda_vectorized_gelu,
(input_val,),
equal=relclose,
verbose=True,
)
Testing report
(array([0.38916567, 0.01010664, 0.20098524, 0.93654305, 0.26475447, 0.02711666, 0.43591085, 0.7909023 , 0.02640386, 0.04154724, 0.65611833, 0.47337994, 0.65105635, 0.9289255 , 0.5399822 , 0.06580561, 0.64303064, 0.15780859, 0.5683575 , 0.6420226 , 0.38611475, 0.31390396, 0.5064093 , 0.15218998, 0.5781999 , 0.13255835, 0.45071495, 0.73919684, 0.51800865, 0.3687768 , 0.83240014, 0.11023618, 0.14749587, 0.89334303, 0.9199019 , 0.8821255 , 0.99873996, 0.2324525 , 0.2587263 , 0.09865194, 0.91596186, 0.86126864, 0.37732396, 0.6635682 , 0.17678984, 0.07984433, 0.45669428, 0.60289496, 0.99220383, 0.3832815 , 0.08342831, 0.6211122 , 0.91324633, 0.49621433, 0.8436951 , 0.00544985, 0.18173647, 0.7941845 , 0.1175717 , 0.32131624, 0.3600599 , 0.16189952, 0.8056253 , 0.8696681 , 0.8966021 , 0.5313608 , 0.980649 , 0.6631159 , 0.5826805 , 0.03971836, 0.9231509 , 0.7995121 , 0.3719242 , 0.49562123, 0.24867107, 0.5806849 , 0.39288136, 0.79061097, 0.8867972 , 0.70988667, 0.17520966, 0.06835352, 0.5016661 , 0.03497642, 0.10607172, 0.30873862, 0.32629046, 0.03611915, 0.79392153, 0.26159677, 0.8521278 , 0.30235165, 0.291528 , 0.7643648 , 0.9302362 , 0.12348044, 0.64986646, 0.09082226, 0.05749565, 0.6363848 ], dtype=float32),)
[0.25350475 0.00509407 0.11649956 0.7729887 0.16001624 0.01385164 0.2914175 0.62117213 0.01348003 0.02146206 0.48819003 0.3228447 0.4833625 0.7648759 0.3808782 0.03462912 0.47573528 0.08879809 0.40640742 0.47477958 0.2510816 0.1956229 0.35128585 0.08529948 0.41537035 0.0732687 0.30372617 0.56919426 0.36143333 0.2374325 0.6636684 0.05995619 0.08239541 0.7272118 0.7552877 0.7154208 0.83982706 0.13758926 0.15577169 0.05320223 0.7511089 0.6936089 0.24413522 0.49531832 0.10079875 0.04246275 0.3087379 0.43809676 0.83275586 0.24883701 0.04448767 0.45507485 0.7482315 0.34243485 0.6753472 0.00273677 0.10397204 0.6245089 0.06428774 0.2011454 0.23064888 0.0913609 0.6361733 0.7023753 0.7306451 0.37321466 0.82028127 0.49488473 0.41946867 0.02048836 0.7587372 0.6299342 0.23989482 0.3419219 0.14875194 0.4176419 0.25646454 0.6208762 0.7203263 0.5402459 0.09978905 0.03603924 0.34716 0.01797615 0.057516 0.19179793 0.20487356 0.01857992 0.6242414 0.15778947 0.6840965 0.18709503 0.17919308 0.5943542 0.7662706 0.06780756 0.48222965 0.04869734 0.0300659 0.4694444 ]
[0.25350475 0.00509407 0.11649956 0.77298903 0.16001624 0.01385164 0.2914175 0.6211722 0.01348003 0.02146206 0.48819003 0.3228447 0.48336256 0.7648762 0.3808782 0.03462912 0.47573528 0.08879809 0.40640742 0.4747796 0.25108156 0.1956229 0.35128585 0.08529948 0.41537035 0.0732687 0.30372617 0.56919426 0.36143333 0.2374325 0.66366845 0.05995619 0.08239541 0.727212 0.755288 0.71542096 0.83982766 0.13758926 0.15577169 0.05320224 0.7511091 0.693609 0.24413522 0.49531832 0.10079875 0.04246275 0.3087379 0.43809676 0.83275646 0.24883698 0.04448767 0.4550749 0.7482317 0.34243485 0.6753474 0.00273677 0.10397203 0.624509 0.06428774 0.2011454 0.2306489 0.0913609 0.6361733 0.7023754 0.7306453 0.37321466 0.82028174 0.4948848 0.41946867 0.02048836 0.7587375 0.62993425 0.23989482 0.3419219 0.14875193 0.4176419 0.2564645 0.62087625 0.7203265 0.54024595 0.09978905 0.03603924 0.34716004 0.01797615 0.057516 0.19179791 0.20487356 0.01857992 0.6242415 0.15778947 0.6840967 0.18709503 0.17919308 0.5943543 0.7662709 0.06780756 0.48222965 0.04869734 0.0300659 0.46944442]
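Beyond the boolean allclose check, it can be useful to quantify the float32 discrepancy between the two implementations shown above. The follow-up below is illustrative only and assumes the CUDA ufunc returns its result when called without out=, as run_test does above; it is again skipped when CUDA is unavailable.
if __name__ == "__main__" and cuda.is_available():
    # Illustrative: measure the worst-case absolute and relative differences
    # between the NumPy reference and the CUDA ufunc on the same input.
    expected = gelu_tanh_forward(input_val)
    got = cuda_vectorized_gelu(input_val)
    diff = np.abs(expected - got)
    print("max abs diff:", diff.max())
    print("max rel diff:", (diff / np.abs(expected)).max())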
Benchmark¶
if __name__ == "__main__":
input_val = np.random.random(50000000).astype(np.float32)
out = np.zeros_like(input_val)
print("original")
t_original = %timeit -o gelu_tanh_forward(input_val)
print("superoptimized")
t_superopt = %timeit -o cuda_vectorized_gelu(input_val, out=out)
print("t_original / t_superopt", t_original.best / t_superopt.best)
original
532 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
superoptimized
198 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
t_original / t_superopt 2.7060199654441752
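For context, a back-of-the-envelope effective-bandwidth figure can be derived from the timings above; this is illustrative only and ignores temporaries and host-device transfer details.
if __name__ == "__main__":
    # Rough effective bandwidth: one read of the input plus one write of the
    # output per call, divided by the best time measured by %timeit above.
    bytes_moved = input_val.nbytes + out.nbytes
    print("original       GB/s:", bytes_moved / t_original.best / 1e9)
    print("superoptimized GB/s:", bytes_moved / t_superopt.best / 1e9)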