Demo 1: Tanh Approximation for GELU Activation Layer¶
(Depends on Ch.06)
This demo notebook shows how to use e-graphs and cost-based rewriting to
optimize the GELU activation function, which is widely used in deep learning.
We demonstrate how to encode and optimize a Padé-style rational approximation
("Pade44") for tanh() as used in GELU, and how to lower the optimized
computation to MLIR and run it efficiently.
The notebook demonstrates:
- How to extend the compiler for NumPy and GELU-specific rewrites
- How to encode and optimize the Pade44 approximation for tanh
- How to lower and run the optimized computation using MLIR
- How to compare the results and performance of the original and optimized versions
$$ \text{GELU}(x) \approx 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715x^3\right)\right)\right) $$
Imports and Setup¶
import mlir.dialects.arith as arith
import mlir.dialects.math as math
import mlir.ir as ir
import numpy as np
from egglog import (
    Expr,
    StringLike,
    Vec,
    function,
    i64,
    i64Like,
    rewrite,
    rule,
    ruleset,
    set_,
    subsume,
    union,
)
from sealir import rvsdg
from sealir.eqsat.py_eqsat import (
    Py_AddIO,
    Py_AttrIO,
    Py_Call,
    Py_DivIO,
    Py_LoadGlobal,
    Py_MulIO,
    Py_PowIO,
)
from sealir.eqsat.rvsdg_eqsat import (
    Term,
    TermList,
)
from sealir.rvsdg import grammar as rg

from ch04_1_typeinfer_ifelse import (
    ExtendEGraphToRVSDG as ch04_1_ExtendEGraphToRVSDG,
)
from ch04_1_typeinfer_ifelse import (
    Grammar,
    NbOp_Base,
    String,
    Type,
    TypeFloat64,
    TypeInt64,
    TypeVar,
    make_rules_for_binop,
    setup_argtypes,
)
from ch05_typeinfer_array import MyCostModel as ch06_CostModel
from ch05_typeinfer_array import (
    base_ruleset,
)
from ch06_mlir_backend import LowerStates, jit_compiler, run_test
from ch07_mlir_ufunc import Backend as UfuncBackend
from ch07_mlir_ufunc import (
    Float32,
    TypeFloat32,
    ufunc_compiler,
    ufunc_vectorize,
)
from utils.report import Report
The GELU Function¶
Define the GELU activation function using NumPy, with the tanh-based approximation.
def gelu_tanh_forward(a):
    dt = np.float32
    result = (
        dt(0.5)
        * a
        * (
            dt(1)
            + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
        )
    )
    return result
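As a quick standalone check (assuming SciPy is available; this cell is illustrative and not part of the pipeline), we can compare the tanh-based form against the exact GELU, which is defined via the Gaussian error function:

from scipy.special import erf  # extra dependency used only for this check

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

xs = np.linspace(-5.0, 5.0, 1001)
diff = np.abs(gelu_exact(xs) - gelu_tanh_forward(xs.astype(np.float32)))
print("max |exact - tanh approx|:", diff.max())
# The deviation is small (on the order of 1e-3 or less), which is why
# the tanh form is a popular stand-in for the exact definition.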
Extend the Compiler for Needed Features¶
Add type inference and rules for NumPy modules, operations, and attributes required for the GELU function.
Type Inference for NumPy Module¶
class Module(Expr):
    def __init__(self, name: StringLike): ...
    def toType(self) -> Type: ...


@function
def ModuleGetAttr(mod: Module, attrname: StringLike) -> Term: ...
@ruleset
def facts_numpy_module(io: Term, name: String, op: Term, args: Vec[Term]):
    yield rule(
        op == Py_LoadGlobal(io, name),
        name == String("np"),
    ).then(set_(TypeVar(op).getType()).to(Module("numpy").toType()))

    # ------ attributes ------
    numpy_mod = Module("numpy")

    def unary_func(fname, target_func):
        return rule(
            op
            == (
                stmt := Py_Call(
                    func=ModuleGetAttr(numpy_mod, fname),
                    io=io,
                    args=TermList(args),
                )
            ),
            args.length() == i64(1),
        ).then(
            subsume(stmt),
            union(op.getPort(0)).with_(io),
            union(op.getPort(1)).with_(target_func(args[0])),
        )

    # np.pi
    const_pi = ModuleGetAttr(numpy_mod, "pi")
    yield rewrite(
        const_pi,
        subsume=True,
    ).to(Term.LiteralF64(np.pi))
    # np.float32
    yield unary_func("float32", Npy_float32)
    # np.sqrt
    yield unary_func("sqrt", Npy_sqrt)
    # np.tanh
    yield unary_func("tanh", Npy_tanh)
Type Inference for NumPy Operations¶
@function
def Npy_float32(val: Term) -> Term: ...
@function
def Npy_sqrt(val: Term) -> Term: ...
@function
def Npy_tanh(val: Term) -> Term: ...
@function
def Npy_cast_f64_to_f32(val: Term) -> Term: ...
@function
def Npy_cast_i64_to_f32(val: Term) -> Term: ...
@function
def Npy_sqrt_float32(val: Term) -> Term: ...
@function
def Npy_tanh_float32(val: Term) -> Term: ...
@ruleset
def ruleset_typeinfer_numpy_functions(res: Term, arg: Term):
    # float32()
    yield rewrite(Npy_float32(arg), subsume=True).to(
        Npy_cast_f64_to_f32(arg),
        TypeVar(arg).getType() == TypeFloat64,
    )
    yield rewrite(Npy_float32(arg), subsume=True).to(
        Npy_cast_i64_to_f32(arg),
        TypeVar(arg).getType() == TypeInt64,
    )
    for fn in [Npy_cast_f64_to_f32, Npy_cast_i64_to_f32]:
        yield rule(
            res == fn(arg),
        ).then(set_(TypeVar(res).getType()).to(TypeFloat32))

    # others
    for func, typed_func in [
        (Npy_sqrt, Npy_sqrt_float32),
        (Npy_tanh, Npy_tanh_float32),
    ]:
        yield rewrite(func(arg), subsume=True).to(
            typed_func(arg),
            TypeVar(arg).getType() == TypeFloat32,
        )
        yield rule(
            res == typed_func(arg),
        ).then(set_(TypeVar(res).getType()).to(TypeFloat32))
Handle module.attr¶
@ruleset
def ruleset_module(
    io: Term, name: String, modname: String, op: Term, obj: Term
):
    # Get attribute
    yield rule(
        op == Py_AttrIO(io, obj, name),
        TypeVar(obj).getType() == Module(modname).toType(),
    ).then(
        # Shortcut io
        union(op.getPort(0)).with_(io),
        # Setup getattr
        union(op.getPort(1)).with_(ModuleGetAttr(Module(modname), name)),
    )
Type Inference for float32 Operations¶
@function
def Nb_Add_Float32(lhs: Term, rhs: Term) -> Term: ...
@function
def Nb_Mul_Float32(lhs: Term, rhs: Term) -> Term: ...
@function
def Nb_Div_Float32(lhs: Term, rhs: Term) -> Term: ...
@function
def Nb_Pow_Float32_Int64(lhs: Term, rhs: Term) -> Term: ...
@ruleset
def ruleset_typeinfer_f32_ops(res: Term, x: Term, y: Term):
    yield from make_rules_for_binop(
        Py_AddIO, TypeFloat32, TypeFloat32, Nb_Add_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_MulIO, TypeFloat32, TypeFloat32, Nb_Mul_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_DivIO, TypeFloat32, TypeFloat32, Nb_Div_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_PowIO, TypeFloat32, TypeInt64, Nb_Pow_Float32_Int64, TypeFloat32
    )
additional_rules = (
    facts_numpy_module
    | ruleset_module
    | ruleset_typeinfer_numpy_functions
    | ruleset_typeinfer_f32_ops
)
Extend the RVSDG Grammar¶
SExpr = rvsdg.grammar.SExpr


class NbOp_F64_to_F32(NbOp_Base):
    operand: SExpr


class NbOp_I64_to_F32(NbOp_Base):
    operand: SExpr


class NpyOp_Sqrt_Float32(NbOp_Base):
    operand: SExpr


class NpyOp_Tanh_Float32(NbOp_Base):
    operand: SExpr


class NbOp_Mul_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Div_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Add_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Pow_Float32_Int64(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_module(NbOp_Base):
    name: str
class ExtendEGraphToRVSDG(ch04_1_ExtendEGraphToRVSDG):
    def handle_Term(self, op: str, children: dict | list, grm: Grammar):
        match op, children:
            case "Py_Float", {"val": float(arg)}:
                return grm.write(rg.PyFloat(arg))
            case "Npy_cast_f64_to_f32", {"val": expr}:
                return grm.write(NbOp_F64_to_F32(expr))
            case "Npy_cast_i64_to_f32", {"val": expr}:
                return grm.write(NbOp_I64_to_F32(expr))
            case "Nb_Mul_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Mul_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Add_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Add_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Div_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Div_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Pow_Float32_Int64", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Pow_Float32_Int64(lhs=lhs, rhs=rhs))
            case "Npy_sqrt_float32", {"val": val}:
                return grm.write(NpyOp_Sqrt_Float32(val))
            case "Npy_tanh_float32", {"val": val}:
                return grm.write(NpyOp_Tanh_Float32(val))
            # ---
            case "ModuleGetAttr", {"mod": mod, "attrname": str(attrname)}:
                return grm.write(rg.Undef(str(op)))
            case _:
                # Use parent's implementation for other terms.
                return super().handle_Term(op, children, grm)

    def handle_Module(
        self, key: str, op: str, children: dict | list, grm: Grammar
    ):
        return grm.write(rg.Undef(str(key)))
Extend the Backend to Support NumPy Operations¶
The Backend class extends UfuncBackend to handle the specific NumPy operations used in the GELU function (see the driver sketch after this list for how the `yield`-based lowering protocol works). It provides MLIR lowering for:
- Basic arithmetic operations (add, multiply, divide)
- Type conversions (f64->f32, i64->f32)
- NumPy mathematical functions (sqrt, tanh, pow)
- Fallback handling for undefined operations
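Note the coroutine style in lower_expr below: each case yields a child expression and receives the corresponding lowered MLIR value back. A minimal sketch of a driver for this protocol (simplified and hypothetical; the real LowerStates driver also threads state and is not shown here) could look like:

def drive_lowering(expr, lower_expr):
    # Hypothetical driver for generator-based lowering (sketch only).
    gen = lower_expr(expr)
    try:
        child = next(gen)  # the generator requests a child expression
        while True:
            value = drive_lowering(child, lower_expr)  # lower the child first
            child = gen.send(value)  # hand back its MLIR value, get next request
    except StopIteration as stop:
        return stop.value  # the generator's return value is the lowered op

This style keeps the backend cases flat: they never recurse explicitly, so traversal order, caching, and state management can live in one place.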
class Backend(UfuncBackend):
    def __init__(self):
        super().__init__()
        self.f32 = ir.F32Type.get(context=self.context)

    def get_mlir_type(self, seal_ty):
        match seal_ty.name:
            case "Float32":
                return self.f32
        return super().get_mlir_type(seal_ty)

    def lower_expr(self, expr: SExpr, state: LowerStates):
        match expr:
            case NbOp_Add_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.addf(lhs, rhs)
            case NbOp_Mul_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.mulf(lhs, rhs)
            case NbOp_Div_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.divf(lhs, rhs)
            case NbOp_F64_to_F32(val):
                val = yield val
                return arith.truncf(self.f32, val)
            case NbOp_I64_to_F32(val):
                val = yield val
                return arith.sitofp(self.f32, val)
            case NpyOp_Tanh_Float32(val):
                val = yield val
                return math.tanh(val)
            case NpyOp_Sqrt_Float32(val):
                val = yield val
                return math.sqrt(val)
            case NbOp_Pow_Float32_Int64(val, p):
                val = yield val
                p = yield p
                return math.powf(val, arith.sitofp(val.type, p))
            case rg.Undef(str(name)):
                return arith.constant(self.i32, 0)
        return (yield from super().lower_expr(expr, state))
Cost Model¶
Assign a higher cost to the transcendental functions and a lower cost to arithmetic and type-conversion operations. This encourages the optimizer to prefer rational approximations over transcendental functions when possible.
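To see why these numbers steer extraction toward the rewrite, here is a back-of-the-envelope tally (illustrative only; it assumes simple arithmetic ops cost about 1 each, which is not the model's exact accounting):

# Hypothetical cost comparison, not the real extraction arithmetic.
COSTS = {"tanh": 100, "pow": 50, "mul": 1, "add": 1, "div": 1}

original = COSTS["tanh"] + COSTS["pow"]  # tanh(...) plus the a**3 inside GELU
# Fully expanded Padé form: 12 multiplies, 3 adds, 1 divide.
pade = 12 * COSTS["mul"] + 3 * COSTS["add"] + COSTS["div"]
print(original, "vs", pade)  # 150 vs 16: extraction picks the rational form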
class MyCostModel(ch06_CostModel):
    def get_cost_function(self, nodename, op, ty, cost, children):
        match op:
            case "Npy_tanh" | "Npy_sqrt" | "Npy_float32":
                cost = float("inf")  # suppress untyped op
            case "Npy_tanh_float32":
                cost = 100
            case "Npy_sqrt_float32":
                cost = 50
            case "Nb_Pow_Float32_Int64":
                cost = 50
        # Fallthrough to parent's cost function
        return super().get_cost_function(nodename, op, ty, cost, children)
Run the Extended Pipeline¶
Compile and run the original GELU function using the extended pipeline and report the results.
compiler_config = dict(
    converter_class=ExtendEGraphToRVSDG,
    backend=Backend(),
    cost_model=MyCostModel(),
)
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    jit_func = jit_compiler(
        fn=gelu_tanh_forward,
        argtypes=(Float32,),
        ruleset=(
            base_ruleset | setup_argtypes(TypeFloat32) | additional_rules
        ),
        pipeline_report=report,
        **compiler_config,
    ).jit_func
    report.display()
    run_test(gelu_tanh_forward, jit_func, (0.234,), verbose=True)
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %cst_0 = arith.constant 3.1415926535897931 : f64
    %cst_1 = arith.constant 4.471500e-02 : f64
    %c3_i64 = arith.constant 3 : i64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.sitofp %c2_i64 : i64 to f32
    %4 = arith.truncf %cst_0 : f64 to f32
    %5 = arith.divf %3, %4 : f32
    %6 = math.sqrt %5 : f32
    %7 = arith.truncf %cst_1 : f64 to f32
    %8 = arith.sitofp %c3_i64 : i64 to f32
    %9 = math.powf %arg0, %8 : f32
    %10 = arith.mulf %7, %9 : f32
    %11 = arith.addf %arg0, %10 : f32
    %12 = arith.mulf %6, %11 : f32
    %13 = math.tanh %12 : f32
    %14 = arith.addf %2, %13 : f32
    %15 = arith.mulf %1, %14 : f32
    return %15 : f32
  }
}
time elapsed 2.73ms timing breakdown: 2.73ms: Lowered module
module {
  llvm.func @tanhf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @powf(f32, f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(3.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %2 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %4 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %5 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %6 = llvm.fmul %arg0, %3 : f32
    %7 = llvm.fptrunc %4 : f64 to f32
    %8 = llvm.fdiv %1, %7 : f32
    %9 = llvm.call @sqrtf(%8) : (f32) -> f32
    %10 = llvm.fptrunc %5 : f64 to f32
    %11 = llvm.call @powf(%arg0, %0) : (f32, f32) -> f32
    %12 = llvm.fmul %10, %11 : f32
    %13 = llvm.fadd %arg0, %12 : f32
    %14 = llvm.fmul %9, %13 : f32
    %15 = llvm.call @tanhf(%14) : (f32) -> f32
    %16 = llvm.fadd %15, %2 : f32
    %17 = llvm.fmul %6, %16 : f32
    llvm.return %17 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
}
time elapsed 1.90ms timing breakdown: 1.90ms: MLIR optimized
Testing report
(0.234,)
0.13864579796791077
0.1386458
Add Rules to Optimize¶
Add rewrite rules for the Pade44 approximation of tanh(x) and for expanding power operations into multiplication chains. These rules enable the optimizer to replace expensive transcendental functions with efficient arithmetic.
Define Pade44 approximation rewrite rule for tanh(x)¶
The Pade44 approximation replaces tanh(x) with a rational function:
$$ \tanh(x) \approx \frac{10x^3 + 105x}{x^4 + 45x^2 + 105} $$
This approximation provides good accuracy for small to moderate values of x while avoiding the computational cost of the transcendental tanh function. The rational form allows for efficient evaluation using only basic arithmetic operations (addition, multiplication, division) and power operations.
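As a quick numerical sanity check (plain NumPy, independent of the e-graph machinery):

def tanh_pade44(x):
    # The rational form encoded by the rewrite rule below.
    return (10 * x**3 + 105 * x) / (x**4 + 45 * x**2 + 105)

xs = np.linspace(-3.0, 3.0, 601)
print("max abs error on [-3, 3]:", np.abs(np.tanh(xs) - tanh_pade44(xs)).max())
# Accuracy degrades for large |x|: the rational form behaves like 10/x
# as x grows instead of saturating at 1, so it suits moderate inputs only.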
@ruleset
def pade44_tanh_expansion(x: Term):
    flt = lambda f: Npy_float32(Term.LiteralF64(float(f)))
    liti64 = Term.LiteralI64
    pow = Nb_Pow_Float32_Int64
    mul = Nb_Mul_Float32
    add = Nb_Add_Float32
    div = Nb_Div_Float32
    # Rewrite tanh(x) to the pade44 approximation.
    # Note the use of pow(); these calls will be optimized
    # by the `pow_expansion` ruleset.
    yield rewrite(Npy_tanh_float32(x)).to(
        div(
            add(mul(flt(10), pow(x, liti64(3))), mul(flt(105), x)),
            add(
                add(pow(x, liti64(4)), mul(flt(45), pow(x, liti64(2)))),
                flt(105),
            ),
        )
    )
Define the expansion of power operations (e.g., x^N) into a sequence of multiplications. This converts expensive power operations into more efficient multiplication chains; for example, x^3 becomes x * x * x, and x^0 becomes 1.0. Note: the exponent N must be a compile-time constant for this optimization to apply.
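For intuition, the same recursion written as an ordinary (hypothetical) Python function over strings:

def expand_pow(x, n):
    # Mirrors the rewrite below: pow(x, n) -> x * pow(x, n - 1); pow(x, 0) -> 1.0
    if n == 0:
        return "1.0"
    return f"({x} * {expand_pow(x, n - 1)})"

print(expand_pow("x", 3))  # (x * (x * (x * 1.0)))

The trailing multiply by 1.0 is still visible in the unoptimized MLIR further down (%9 holds 1.0) and is folded away by the later MLIR/LLVM optimization passes.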
@ruleset
def pow_expansion(x: Term, ival: i64):
    # Rules to expand pow(x, i) into multiplications
    powf = Nb_Pow_Float32_Int64
    lit64 = Term.LiteralI64
    mulf = Nb_Mul_Float32
    yield rewrite(powf(x, lit64(ival))).to(
        mulf(x, powf(x, lit64(ival - 1))),
        ival >= 1,
    )
    yield rewrite(powf(x, lit64(i64(0))), subsume=True).to(
        Npy_float32(Term.LiteralF64(float(1))),
    )
Combine the rules. Rulesets compose with the `|` operator.
optimize_rules = pade44_tanh_expansion | pow_expansion
Run the Optimized Function¶
Compile and run the optimized GELU function using the new rules, and report the results.
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    jit_func = jit_compiler(
        fn=gelu_tanh_forward,
        argtypes=(Float32,),
        ruleset=(
            base_ruleset
            | setup_argtypes(TypeFloat32)
            | additional_rules
            | optimize_rules
        ),
        pipeline_report=report,
        **compiler_config,
    ).jit_func
    report.display()
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant 1.000000e+01 : f64
    %c2_i64 = arith.constant 2 : i64
    %cst_1 = arith.constant 3.1415926535897931 : f64
    %cst_2 = arith.constant 4.471500e-02 : f64
    %cst_3 = arith.constant 1.000000e+00 : f64
    %cst_4 = arith.constant 1.050000e+02 : f64
    %cst_5 = arith.constant 4.500000e+01 : f64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.truncf %cst_0 : f64 to f32
    %4 = arith.sitofp %c2_i64 : i64 to f32
    %5 = arith.truncf %cst_1 : f64 to f32
    %6 = arith.divf %4, %5 : f32
    %7 = math.sqrt %6 : f32
    %8 = arith.truncf %cst_2 : f64 to f32
    %9 = arith.truncf %cst_3 : f64 to f32
    %10 = arith.mulf %arg0, %9 : f32
    %11 = arith.mulf %arg0, %10 : f32
    %12 = arith.mulf %arg0, %11 : f32
    %13 = arith.mulf %8, %12 : f32
    %14 = arith.addf %arg0, %13 : f32
    %15 = arith.mulf %7, %14 : f32
    %16 = arith.mulf %15, %9 : f32
    %17 = arith.mulf %15, %16 : f32
    %18 = arith.mulf %15, %17 : f32
    %19 = arith.mulf %3, %18 : f32
    %20 = arith.truncf %cst_4 : f64 to f32
    %21 = arith.mulf %20, %15 : f32
    %22 = arith.addf %19, %21 : f32
    %23 = arith.mulf %15, %18 : f32
    %24 = arith.truncf %cst_5 : f64 to f32
    %25 = arith.mulf %24, %17 : f32
    %26 = arith.addf %23, %25 : f32
    %27 = arith.addf %26, %20 : f32
    %28 = arith.divf %22, %27 : f32
    %29 = arith.addf %2, %28 : f32
    %30 = arith.mulf %1, %29 : f32
    return %30 : f32
  }
}
time elapsed 2.13ms timing breakdown: 2.13ms: Lowered module
module {
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %8 = llvm.fmul %arg0, %5 : f32
    %9 = llvm.fptrunc %6 : f64 to f32
    %10 = llvm.fdiv %2, %9 : f32
    %11 = llvm.call @sqrtf(%10) : (f32) -> f32
    %12 = llvm.fptrunc %7 : f64 to f32
    %13 = llvm.fmul %arg0, %arg0 : f32
    %14 = llvm.fmul %arg0, %13 : f32
    %15 = llvm.fmul %12, %14 : f32
    %16 = llvm.fadd %arg0, %15 : f32
    %17 = llvm.fmul %11, %16 : f32
    %18 = llvm.fmul %17, %17 : f32
    %19 = llvm.fmul %17, %18 : f32
    %20 = llvm.fmul %19, %3 : f32
    %21 = llvm.fmul %17, %1 : f32
    %22 = llvm.fadd %20, %21 : f32
    %23 = llvm.fmul %17, %19 : f32
    %24 = llvm.fmul %18, %0 : f32
    %25 = llvm.fadd %23, %24 : f32
    %26 = llvm.fadd %25, %1 : f32
    %27 = llvm.fdiv %22, %26 : f32
    %28 = llvm.fadd %27, %4 : f32
    %29 = llvm.fmul %8, %28 : f32
    llvm.return %29 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
}
time elapsed 1.69ms timing breakdown: 1.69ms: MLIR optimized
Compare the Result¶
Compare the output of the original and optimized GELU functions, allowing for a small tolerance due to floating-point approximation.
if __name__ == "__main__":
    relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
    run_test(
        gelu_tanh_forward, jit_func, (0.234,), equal=relclose, verbose=True
    )
Testing report
(0.234,)
0.13864579796791077
0.1386458
Ufunc Version¶
Demonstrate vectorized (ufunc) compilation and execution of the optimized GELU function, and compare results for a batch of inputs.
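Conceptually, ufunc_vectorize wraps the compiled scalar kernel in an elementwise loop over the input array; the @ufunc function in the MLIR dump below is exactly that loop. As a mental model only (hypothetical helper, not the library's API), the behavior corresponds to:

def reference_vectorize(scalar_fn):
    # Pure-Python model of the generated ufunc wrapper: apply the scalar
    # kernel to each element, writing results into `out`.
    def wrapper(arr, out=None):
        out = np.empty_like(arr) if out is None else out
        for i, v in enumerate(arr.ravel()):
            out.flat[i] = scalar_fn(v)
        return out
    return wrapper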
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    vectorized_gelu = ufunc_vectorize(
        input_type=Float32,
        ndim=1,
        compiler_config={**compiler_config, "pipeline_report": report},
        extra_ruleset=additional_rules | optimize_rules,
    )(gelu_tanh_forward)
    report.display()

    relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
    input_val = np.random.random(100).astype(np.float32)
    run_test(
        gelu_tanh_forward,
        vectorized_gelu,
        (input_val,),
        equal=relclose,
        verbose=True,
    )
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant 1.000000e+01 : f64
    %c2_i64 = arith.constant 2 : i64
    %cst_1 = arith.constant 3.1415926535897931 : f64
    %cst_2 = arith.constant 4.471500e-02 : f64
    %cst_3 = arith.constant 1.000000e+00 : f64
    %cst_4 = arith.constant 1.050000e+02 : f64
    %cst_5 = arith.constant 4.500000e+01 : f64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.truncf %cst_0 : f64 to f32
    %4 = arith.sitofp %c2_i64 : i64 to f32
    %5 = arith.truncf %cst_1 : f64 to f32
    %6 = arith.divf %4, %5 : f32
    %7 = math.sqrt %6 : f32
    %8 = arith.truncf %cst_2 : f64 to f32
    %9 = arith.truncf %cst_3 : f64 to f32
    %10 = arith.mulf %arg0, %9 : f32
    %11 = arith.mulf %arg0, %10 : f32
    %12 = arith.mulf %arg0, %11 : f32
    %13 = arith.mulf %8, %12 : f32
    %14 = arith.addf %arg0, %13 : f32
    %15 = arith.mulf %7, %14 : f32
    %16 = arith.mulf %15, %9 : f32
    %17 = arith.mulf %15, %16 : f32
    %18 = arith.mulf %15, %17 : f32
    %19 = arith.mulf %3, %18 : f32
    %20 = arith.truncf %cst_4 : f64 to f32
    %21 = arith.mulf %20, %15 : f32
    %22 = arith.addf %19, %21 : f32
    %23 = arith.mulf %15, %18 : f32
    %24 = arith.truncf %cst_5 : f64 to f32
    %25 = arith.mulf %24, %17 : f32
    %26 = arith.addf %23, %25 : f32
    %27 = arith.addf %26, %20 : f32
    %28 = arith.divf %22, %27 : f32
    %29 = arith.addf %2, %28 : f32
    %30 = arith.mulf %1, %29 : f32
    return %30 : f32
  }
}
time elapsed 2.06ms timing breakdown: 2.06ms: Lowered module
module {
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %8 = llvm.fmul %arg0, %5 : f32
    %9 = llvm.fptrunc %6 : f64 to f32
    %10 = llvm.fdiv %2, %9 : f32
    %11 = llvm.call @sqrtf(%10) : (f32) -> f32
    %12 = llvm.fptrunc %7 : f64 to f32
    %13 = llvm.fmul %arg0, %arg0 : f32
    %14 = llvm.fmul %arg0, %13 : f32
    %15 = llvm.fmul %12, %14 : f32
    %16 = llvm.fadd %arg0, %15 : f32
    %17 = llvm.fmul %11, %16 : f32
    %18 = llvm.fmul %17, %17 : f32
    %19 = llvm.fmul %17, %18 : f32
    %20 = llvm.fmul %19, %3 : f32
    %21 = llvm.fmul %17, %1 : f32
    %22 = llvm.fadd %20, %21 : f32
    %23 = llvm.fmul %17, %19 : f32
    %24 = llvm.fmul %18, %0 : f32
    %25 = llvm.fadd %23, %24 : f32
    %26 = llvm.fadd %25, %1 : f32
    %27 = llvm.fdiv %22, %26 : f32
    %28 = llvm.fadd %27, %4 : f32
    %29 = llvm.fmul %8, %28 : f32
    llvm.return %29 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
  llvm.func @ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.insertvalue %arg5, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.insertvalue %arg6, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.insertvalue %arg7, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.insertvalue %arg8, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.insertvalue %arg9, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.insertvalue %arg4, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %12 = llvm.mlir.constant(1 : index) : i64
    %13 = llvm.mlir.constant(0 : index) : i64
    %14 = llvm.extractvalue %11[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    llvm.br ^bb1(%13 : i64)
  ^bb1(%15: i64):  // 2 preds: ^bb0, ^bb2
    %16 = llvm.icmp "slt" %15, %14 : i64
    llvm.cond_br %16, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %17 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %18 = llvm.getelementptr %17[%15] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %19 = llvm.load %18 : !llvm.ptr -> f32
    %20 = llvm.call @func(%19) : (f32) -> f32
    %21 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %22 = llvm.getelementptr %21[%15] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %20, %22 : f32, !llvm.ptr
    %23 = llvm.add %15, %12 : i64
    llvm.br ^bb1(%23 : i64)
  ^bb3:  // pred: ^bb1
    llvm.return
  }
  llvm.func @_mlir_ciface_ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    llvm.call @ufunc(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
    llvm.return
  }
}
time elapsed 3.10ms timing breakdown: 3.10ms: MLIR optimized
Testing report
(array([0.53754056, 0.43233588, 0.10236798, 0.46508574, 0.58549654, 0.8923938 , 0.2757588 , 0.08432309, 0.8884488 , 0.8587762 , 0.6457952 , 0.11737864, 0.9524025 , 0.9533351 , 0.3317381 , 0.1680605 , 0.0680922 , 0.17509076, 0.07749847, 0.972892 , 0.43774995, 0.9594856 , 0.3802283 , 0.4063943 , 0.6109139 , 0.82032937, 0.96378887, 0.21142137, 0.8676044 , 0.8083702 , 0.53161585, 0.36552882, 0.2643556 , 0.00743109, 0.21400109, 0.45045698, 0.87366974, 0.7190363 , 0.54310775, 0.18760236, 0.72153026, 0.61276686, 0.9093058 , 0.91447484, 0.5454792 , 0.5947745 , 0.9831064 , 0.7109356 , 0.15139958, 0.9206733 , 0.8234191 , 0.41100514, 0.445924 , 0.15229666, 0.13950907, 0.19849318, 0.02099547, 0.2286102 , 0.02923151, 0.2675173 , 0.8535502 , 0.43683502, 0.48123544, 0.91435367, 0.8591862 , 0.60497624, 0.6497597 , 0.47559726, 0.745626 , 0.068115 , 0.93472457, 0.32189476, 0.51888883, 0.08371709, 0.08502585, 0.24354592, 0.34441778, 0.91207695, 0.6931497 , 0.3731865 , 0.9115862 , 0.10682532, 0.09896134, 0.53871924, 0.1919585 , 0.24836382, 0.7622515 , 0.9566354 , 0.53254044, 0.7762569 , 0.8494233 , 0.5688198 , 0.61417764, 0.8689236 , 0.5286076 , 0.07920949, 0.7584067 , 0.72011465, 0.7771502 , 0.89118475], dtype=float32),)
[0.3787034 0.28846663 0.05535727 0.31581023 0.42205015 0.72621244 0.16783419 0.0449948 0.7220623 0.69101226 0.47835886 0.06417319 0.78993255 0.7909312 0.2089769 0.09524503 0.03589438 0.09971316 0.04114288 0.8119259 0.2929388 0.7975227 0.24642435 0.26730743 0.44554824 0.65123934 0.8021408 0.12341041 0.70021915 0.63897943 0.37344074 0.23489869 0.15973456 0.00373758 0.12513152 0.30351046 0.70656013 0.5492407 0.38366732 0.10775949 0.55169916 0.44727504 0.74406016 0.74953294 0.3857873 0.43058652 0.8229314 0.54127514 0.0848093 0.75610644 0.6544156 0.27103543 0.2997272 0.08536569 0.07749382 0.11486163 0.01067358 0.13497381 0.0149566 0.16197063 0.6855748 0.2921817 0.3295473 0.7494046 0.6914393 0.44002742 0.482128 0.3247327 0.5755953 0.03590702 0.77105045 0.20157808 0.36220664 0.04465127 0.04539355 0.14520308 0.2186094 0.74699306 0.5238929 0.2408843 0.7464735 0.05795657 0.05338125 0.3797528 0.11058928 0.14853863 0.5922312 0.7944669 0.37426063 0.60633624 0.68128777 0.40682715 0.44859096 0.7015972 0.37077665 0.04210514 0.5883735 0.5503034 0.60723865 0.72494 ]
[0.3787034 0.28846663 0.05535727 0.31581023 0.42205015 0.7262126 0.16783419 0.0449948 0.72206247 0.6910124 0.4783589 0.06417319 0.78993297 0.7909316 0.2089769 0.09524503 0.03589438 0.09971316 0.04114288 0.8119263 0.2929388 0.79752314 0.24642432 0.26730743 0.44554824 0.6512394 0.8021411 0.12341041 0.7002193 0.63897943 0.37344074 0.23489869 0.15973456 0.00373758 0.12513152 0.30351046 0.7065603 0.54924077 0.38366732 0.10775949 0.55169916 0.44727504 0.74406034 0.74953324 0.3857873 0.43058652 0.82293195 0.54127514 0.0848093 0.75610673 0.65441567 0.27103543 0.2997272 0.08536569 0.07749382 0.11486163 0.01067358 0.13497381 0.0149566 0.16197063 0.68557495 0.2921817 0.3295473 0.7494048 0.6914394 0.44002742 0.48212805 0.3247327 0.5755953 0.03590702 0.77105075 0.20157808 0.36220664 0.04465127 0.04539355 0.14520308 0.2186094 0.7469933 0.5238929 0.2408843 0.7464738 0.05795657 0.05338125 0.37975284 0.11058928 0.14853863 0.5922312 0.7944673 0.37426063 0.60633624 0.6812878 0.40682715 0.44859102 0.7015974 0.37077665 0.04210514 0.58837354 0.55030346 0.6072387 0.7249402 ]
Benchmark¶
if __name__ == "__main__":
    input_val = np.random.random(300000).astype(np.float32)
    out = np.zeros_like(input_val)
    print("original")
    %timeit gelu_tanh_forward(input_val)
    print("superoptimized")
    %timeit vectorized_gelu(input_val, out=out)
original
4.35 ms ± 16.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
superoptimized
260 μs ± 3.46 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)