Demo 1: Tanh Approximation for GELU Activation Layer¶

(Depends on Ch.04 through Ch.07; see the imports below)

This demo notebook shows how to use e-graphs and cost-based rewriting to optimize the GELU activation function, which is widely used in deep learning. We demonstrate how to encode and optimize a Pade44 rational approximation for tanh() as used in GELU, and how to lower the optimized computation to MLIR and run it efficiently.

The notebook demonstrates:

  • How to extend the compiler for NumPy and GELU-specific rewrites
  • How to encode and optimize the Pade44 approximation for tanh
  • How to lower and run the optimized computation using MLIR
  • How to compare the results and performance of the original and optimized versions

$$ \text{GELU}(x) \approx 0.5x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715x^3\right)\right)\right) $$
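
As a quick sanity check of the formula above, the following sketch evaluates it in plain Python with the standard math module. The function name gelu_tanh_reference is made up for this illustration. It works in double precision rather than float32, so it should agree with the float32 results later in the notebook only to about seven significant digits.

```python
import math


def gelu_tanh_reference(x: float) -> float:
    """Tanh-based GELU approximation, evaluated in double precision."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# GELU is close to 0 for very negative inputs and close to x for large positive ones.
for x in (-3.0, 0.0, 0.234, 3.0):
    print(f"{x:6.3f} -> {gelu_tanh_reference(x):.7f}")
```

The input 0.234 is the same test value used in the testing report further down, which makes it easy to compare against the compiled result.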

Imports and Setup¶

In [1]:
import mlir.dialects.arith as arith
import mlir.dialects.math as math
import mlir.ir as ir
import numpy as np
from egglog import (
    Expr,
    StringLike,
    Vec,
    function,
    i64,
    i64Like,
    rewrite,
    rule,
    ruleset,
    set_,
    subsume,
    union,
)
from sealir import rvsdg
from sealir.eqsat.py_eqsat import (
    Py_AddIO,
    Py_AttrIO,
    Py_Call,
    Py_DivIO,
    Py_LoadGlobal,
    Py_MulIO,
    Py_PowIO,
)
from sealir.eqsat.rvsdg_eqsat import (
    Term,
    TermList,
)
from sealir.rvsdg import grammar as rg
In [2]:
from ch04_1_typeinfer_ifelse import (
    ExtendEGraphToRVSDG as ch04_1_ExtendEGraphToRVSDG,
)
from ch04_1_typeinfer_ifelse import (
    Grammar,
    NbOp_Base,
    String,
    Type,
    TypeFloat64,
    TypeInt64,
    TypeVar,
    make_rules_for_binop,
    setup_argtypes,
)
from ch05_typeinfer_array import MyCostModel as ch06_CostModel
from ch05_typeinfer_array import (
    base_ruleset,
)
from ch06_mlir_backend import LowerStates, jit_compiler, run_test
from ch07_mlir_ufunc import Backend as UfuncBackend
from ch07_mlir_ufunc import (
    Float32,
    TypeFloat32,
    ufunc_compiler,
    ufunc_vectorize,
)
from utils.report import Report

The GELU Function¶

Define the GELU activation function using NumPy, with the tanh-based approximation.

In [3]:
def gelu_tanh_forward(a):
    dt = np.float32
    result = (
        dt(0.5)
        * a
        * (
            dt(1)
            + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
        )
    )
    return result

Extend the Compiler for Needed Features¶

Add type inference and rules for NumPy modules, operations, and attributes required for the GELU function.

Type Inference for NumPy Module¶

In [4]:
class Module(Expr):
    def __init__(self, name: StringLike): ...

    def toType(self) -> Type: ...
In [5]:
@function
def ModuleGetAttr(mod: Module, attrname: StringLike) -> Term: ...
In [6]:
@ruleset
def facts_numpy_module(io: Term, name: String, op: Term, args: Vec[Term]):

    yield rule(
        op == Py_LoadGlobal(io, name),
        name == String("np"),
    ).then(set_(TypeVar(op).getType()).to(Module("numpy").toType()))

    # ------ attributes ------
    numpy_mod = Module("numpy")

    def unary_func(fname, target_func):
        return rule(
            op
            == (
                stmt := Py_Call(
                    func=ModuleGetAttr(numpy_mod, fname),
                    io=io,
                    args=TermList(args),
                )
            ),
            args.length() == i64(1),
        ).then(
            subsume(stmt),
            union(op.getPort(0)).with_(io),
            union(op.getPort(1)).with_(target_func(args[0])),
        )

    # np.pi
    const_pi = ModuleGetAttr(numpy_mod, "pi")
    yield rewrite(
        const_pi,
        subsume=True,
    ).to(Term.LiteralF64(np.pi))
    # np.float32
    yield unary_func("float32", Npy_float32)
    # np.sqrt
    yield unary_func("sqrt", Npy_sqrt)
    # np.tanh
    yield unary_func("tanh", Npy_tanh)

Type Inference for NumPy Operations¶

In [7]:
@function
def Npy_float32(val: Term) -> Term: ...
@function
def Npy_sqrt(val: Term) -> Term: ...
@function
def Npy_tanh(val: Term) -> Term: ...
In [8]:
@function
def Npy_cast_f64_to_f32(val: Term) -> Term: ...
@function
def Npy_cast_i64_to_f32(val: Term) -> Term: ...
@function
def Npy_sqrt_float32(val: Term) -> Term: ...
@function
def Npy_tanh_float32(val: Term) -> Term: ...
In [9]:
@ruleset
def ruleset_typeinfer_numpy_functions(res: Term, arg: Term):
    # float32()
    yield rewrite(Npy_float32(arg), subsume=True).to(
        Npy_cast_f64_to_f32(arg),
        TypeVar(arg).getType() == TypeFloat64,
    )
    yield rewrite(Npy_float32(arg), subsume=True).to(
        Npy_cast_i64_to_f32(arg),
        TypeVar(arg).getType() == TypeInt64,
    )

    for fn in [Npy_cast_f64_to_f32, Npy_cast_i64_to_f32]:
        yield rule(
            res == fn(arg),
        ).then(set_(TypeVar(res).getType()).to(TypeFloat32))
    # others

    for func, typed_func in [
        (Npy_sqrt, Npy_sqrt_float32),
        (Npy_tanh, Npy_tanh_float32),
    ]:
        yield rewrite(func(arg), subsume=True).to(
            typed_func(arg),
            TypeVar(arg).getType() == TypeFloat32,
        )
        yield rule(
            res == typed_func(arg),
        ).then(set_(TypeVar(res).getType()).to(TypeFloat32))
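
The rules above amount to a lookup keyed on the operation and the inferred argument type. The sketch below mimics that dispatch with a plain dictionary. It is not egglog code, and TYPED_OPS and resolve are hypothetical names used only for this illustration.

```python
# Maps (untyped op, argument type) to (typed op, result type), mirroring
# the rewrites in ruleset_typeinfer_numpy_functions.
TYPED_OPS = {
    ("Npy_float32", "Float64"): ("Npy_cast_f64_to_f32", "Float32"),
    ("Npy_float32", "Int64"): ("Npy_cast_i64_to_f32", "Float32"),
    ("Npy_sqrt", "Float32"): ("Npy_sqrt_float32", "Float32"),
    ("Npy_tanh", "Float32"): ("Npy_tanh_float32", "Float32"),
}


def resolve(op: str, arg_type: str):
    """Return (typed_op, result_type), or None when no rule matches."""
    return TYPED_OPS.get((op, arg_type))


print(resolve("Npy_float32", "Int64"))
print(resolve("Npy_sqrt", "Float64"))  # no float64 sqrt rule is defined here
```

A combination without a matching rule, such as a Float64 argument to np.sqrt, keeps its untyped form, which the cost model in this notebook prices at infinity.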

Handle module.attr¶

In [10]:
@ruleset
def ruleset_module(
    io: Term, name: String, modname: String, op: Term, obj: Term
):
    # Getattribute
    yield rule(
        op == Py_AttrIO(io, obj, name),
        TypeVar(obj).getType() == Module(modname).toType(),
    ).then(
        # Shortcut io
        union(op.getPort(0)).with_(io),
        # Setup getattr
        union(op.getPort(1)).with_(ModuleGetAttr(Module(modname), name)),
    )

Type Inference for float32 Operations¶

In [11]:
@function
def Nb_Add_Float32(lhs: Term, rhs: Term) -> Term: ...


@function
def Nb_Mul_Float32(lhs: Term, rhs: Term) -> Term: ...


@function
def Nb_Div_Float32(lhs: Term, rhs: Term) -> Term: ...


@function
def Nb_Pow_Float32_Int64(lhs: Term, rhs: Term) -> Term: ...
In [12]:
@ruleset
def ruleset_typeinfer_f32_ops(res: Term, x: Term, y: Term):
    yield from make_rules_for_binop(
        Py_AddIO, TypeFloat32, TypeFloat32, Nb_Add_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_MulIO, TypeFloat32, TypeFloat32, Nb_Mul_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_DivIO, TypeFloat32, TypeFloat32, Nb_Div_Float32, TypeFloat32
    )
    yield from make_rules_for_binop(
        Py_PowIO, TypeFloat32, TypeInt64, Nb_Pow_Float32_Int64, TypeFloat32
    )
In [13]:
additional_rules = (
    facts_numpy_module
    | ruleset_module
    | ruleset_typeinfer_numpy_functions
    | ruleset_typeinfer_f32_ops
)

Extend the RVSDG Grammar¶

In [14]:
SExpr = rvsdg.grammar.SExpr


class NbOp_F64_to_F32(NbOp_Base):
    operand: SExpr


class NbOp_I64_to_F32(NbOp_Base):
    operand: SExpr


class NpyOp_Sqrt_Float32(NbOp_Base):
    operand: SExpr


class NpyOp_Tanh_Float32(NbOp_Base):
    operand: SExpr


class NbOp_Mul_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Div_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Add_Float32(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Pow_Float32_Int64(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_module(NbOp_Base):
    name: str
In [15]:
class ExtendEGraphToRVSDG(ch04_1_ExtendEGraphToRVSDG):

    def handle_Term(self, op: str, children: dict | list, grm: Grammar):

        match op, children:
            case "Py_Float", {"val": float(arg)}:
                return grm.write(rg.PyFloat(arg))

            case "Npy_cast_f64_to_f32", {"val": expr}:
                return grm.write(NbOp_F64_to_F32(expr))

            case "Npy_cast_i64_to_f32", {"val": expr}:
                return grm.write(NbOp_I64_to_F32(expr))

            case "Nb_Mul_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Mul_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Add_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Add_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Div_Float32", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Div_Float32(lhs=lhs, rhs=rhs))
            case "Nb_Pow_Float32_Int64", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Pow_Float32_Int64(lhs=lhs, rhs=rhs))
            case "Npy_sqrt_float32", {"val": val}:
                return grm.write(NpyOp_Sqrt_Float32(val))
            case "Npy_tanh_float32", {"val": val}:
                return grm.write(NpyOp_Tanh_Float32(val))
            # ---
            case "ModuleGetAttr", {"mod": mod, "attrname": str(attrname)}:
                return grm.write(rg.Undef(str(op)))
            case _:
                # Use parent's implementation for other terms.
                return super().handle_Term(op, children, grm)

    def handle_Module(
        self, key: str, op: str, children: dict | list, grm: Grammar
    ):
        return grm.write(rg.Undef(str(key)))

Extend the Backend to Support NumPy Operations¶

The Backend class extends UfuncBackend to handle the specific NumPy operations used in the GELU function. It provides MLIR lowering for:

  • Basic float32 arithmetic (add, multiply, divide)
  • Type conversions (f64->f32, i64->f32)
  • Mathematical functions (sqrt and tanh from NumPy, plus the ** power operator with an int64 exponent)
  • A fallback that lowers undefined (Undef) nodes to a placeholder i32 zero constant
In [16]:
class Backend(UfuncBackend):
    def __init__(self):
        super().__init__()
        self.f32 = ir.F32Type.get(context=self.context)

    def get_mlir_type(self, seal_ty):
        match seal_ty.name:
            case "Float32":
                return self.f32
        return super().get_mlir_type(seal_ty)

    def lower_expr(self, expr: SExpr, state: LowerStates):
        match expr:
            case NbOp_Add_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.addf(lhs, rhs)
            case NbOp_Mul_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.mulf(lhs, rhs)
            case NbOp_Div_Float32(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                return arith.divf(lhs, rhs)
            case NbOp_F64_to_F32(val):
                val = yield val
                return arith.truncf(self.f32, val)
            case NbOp_I64_to_F32(val):
                val = yield val
                return arith.sitofp(self.f32, val)
            case NpyOp_Tanh_Float32(val):
                val = yield val
                return math.tanh(val)
            case NpyOp_Sqrt_Float32(val):
                val = yield val
                return math.sqrt(val)
            case NbOp_Pow_Float32_Int64(val, p):
                val = yield val
                p = yield p
                return math.powf(val, arith.sitofp(val.type, p))
            case rg.Undef(str(name)):
                return arith.constant(self.i32, 0)
        return (yield from super().lower_expr(expr, state))

Cost Model¶

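Assign a higher cost to the expensive functions (100 for the typed tanh, 50 each for the typed sqrt and pow) and leave arithmetic and type conversions at the lower costs of the parent cost model. The untyped NumPy operations get infinite cost so that extraction never selects them. This encourages the optimizer to prefer rational approximations over transcendental functions when possible.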

In [17]:
class MyCostModel(ch06_CostModel):
    def get_cost_function(self, nodename, op, ty, cost, children):
        match op:
            case "Npy_tanh" | "Npy_sqrt" | "Npy_float32":
                cost = float("inf")  # suppress untyped op
            case "Npy_tanh_float32":
                cost = 100
            case "Npy_sqrt_float32":
                cost = 50
            case "Nb_Pow_Float32_Int64":
                cost = 50

        # Fallthrough to parent's cost function
        return super().get_cost_function(nodename, op, ty, cost, children)
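
To get a feel for what these weights imply, the sketch below compares three ways of computing tanh under a deliberately simplified additive cost model. The weights for tanh, sqrt, and pow are taken from MyCostModel above; the unit weight for add, mul, and div is an assumption made for this sketch, and the parent cost model may combine node costs differently. The names WEIGHTS and total_cost are hypothetical, and constant casts and shared subexpressions are ignored.

```python
# Per-node weights. tanh, sqrt, and pow use the values from MyCostModel;
# the unit weight for add/mul/div is an assumption made for this sketch.
WEIGHTS = {"tanh": 100, "sqrt": 50, "pow": 50, "add": 1, "mul": 1, "div": 1}


def total_cost(op_counts: dict) -> int:
    """Sum of weight * occurrence count over all operation kinds."""
    return sum(WEIGHTS[op] * count for op, count in op_counts.items())


# A single tanh call.
tanh_call = {"tanh": 1}
# (10*x**3 + 105*x) / (x**4 + 45*x**2 + 105) with the three powers kept as pow nodes.
rational_with_pow = {"pow": 3, "mul": 3, "add": 3, "div": 1}
# The same expression with x**2, x**3, x**4 unrolled into 1 + 2 + 3 multiplications.
rational_unrolled = {"mul": 3 + 6, "add": 3, "div": 1}

for name, counts in [
    ("tanh", tanh_call),
    ("rational, pow nodes", rational_with_pow),
    ("rational, unrolled", rational_unrolled),
]:
    print(f"{name:22s} {total_cost(counts)}")
```

Under these assumed weights the rational form is cheaper than tanh only once the powers are unrolled into multiplications, which suggests why the approximation is paired with a power-expansion rule later in the notebook.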

Run the Extended Pipeline¶

Compile and run the original GELU function using the extended pipeline and report the results.

In [18]:
compiler_config = dict(
    converter_class=ExtendEGraphToRVSDG,
    backend=Backend(),
    cost_model=MyCostModel(),
)
In [19]:
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    jit_func = jit_compiler(
        fn=gelu_tanh_forward,
        argtypes=(Float32,),
        ruleset=(
            base_ruleset | setup_argtypes(TypeFloat32) | additional_rules
        ),
        pipeline_report=report,
        **compiler_config,
    ).jit_func
    report.display()
    run_test(gelu_tanh_forward, jit_func, (0.234,), verbose=True)

Pipeline execution report

1. Frontend (12.66ms) ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def gelu_tanh_forward(a):
   2|    dt = np.float32
   3|    result = (
   4|        dt(0.5)
   5|        * a
   6|        * (
   7|            dt(1)
   8|            + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
   9|        )
  10|    )
  11|    return result
----------------------------------inter source----------------------------------
   1|def transformed_gelu_tanh_forward(a):
   2|    """#file: /tmp/ipykernel_3859/2556971375.py"""
   3|    '#loc: 2:8-2:23'
   4|    dt = np.float32
   5|    '#loc: 3:8-10:9'
   6|    result = dt(0.5) * a * (dt(1) + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a ** 3)))
   7|    '#loc: 11:8-11:21'
   8|    return result
RVSDG ▶
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[804] <- !io a
{
  $1 = PyLoadGlobal $0[0] 'np'
  $2 = PyAttr $0[0] $1 'float32'
  $3 = DbgValue 'dt' $2[1]
  $4 = PyFloat 0.5
  $5 = PyCall $3 $2[0] $4
  $6 = PyBinOp * $5[0] $5[1], $0[1]
  $7 = PyInt 1
  $8 = PyCall $3 $6[0] $7
  $9 = PyInt 2
  $10 = PyCall $3 $8[0] $9
  $11 = PyLoadGlobal $10[0] 'np'
  $12 = PyAttr $10[0] $11 'pi'
  $13 = PyCall $3 $12[0] $12[1]
  $14 = PyBinOp / $13[0] $10[1], $13[1]
  $15 = PyLoadGlobal $14[0] 'np'
  $16 = PyAttr $14[0] $15 'sqrt'
  $17 = PyCall $16[1] $16[0] $14[1]
  $18 = PyFloat 0.044715
  $19 = PyCall $3 $17[0] $18
  $20 = PyInt 3
  $21 = PyBinOp ** $19[0] $0[1], $20
  $22 = PyBinOp * $21[0] $19[1], $21[1]
  $23 = PyBinOp + $22[0] $0[1], $22[1]
  $24 = PyBinOp * $23[0] $17[1], $23[1]
  $25 = PyLoadGlobal $24[0] 'np'
  $26 = PyAttr $24[0] $25 'tanh'
  $27 = PyCall $26[1] $26[0] $24[1]
  $28 = PyBinOp + $27[0] $8[1], $27[1]
  $29 = PyBinOp * $28[0] $6[1], $28[1]
  $30 = DbgValue 'result' $29[1]
} [1268] -> !io=$29[0] !ret=$30
[metadata] ▶
time elapsed 12.66ms
timing breakdown:
  8.82ms: Debug Info on RVSDG 
  3.84ms: RVSDG               
2. EGraph Conversion (69.66ms) ▶
EGraph Conversion
EGraph ▶
[EGraph visualization omitted: Graphviz rendering of the initial e-graph for transformed_gelu_tanh_forward, with nodes such as Py_LoadGlobal, Py_AttrIO, Py_Call, Py_MulIO, Py_AddIO, Py_DivIO, and Py_PowIO, and the literal terms 0.5, 1, 2, 3, and 0.044715.]
[metadata] ▶
time elapsed 69.66ms
timing breakdown:
  69.66ms: EGraph              
3. Egraph Saturation (0.00ms) ▶
Egraph Saturation
[metadata] ▶
time elapsed 0.00ms
timing breakdown:
4. EGraph Extraction (11.96ms) ▶
EGraph Extraction
Extracted RVSDG ▶
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[1575] <- !io a; #attrs (_, Float32)->(_, Float32)
{
  $1 = PyFloat 0.5
  $2 = NbOp_F64_to_F32 $1
  $3 = NbOp_Mul_Float32 $2 $0[1]
  $4 = PyInt 1
  $5 = NbOp_I64_to_F32 $4
  $6 = PyInt 2
  $7 = NbOp_I64_to_F32 $6
  $8 = PyFloat 3.141592653589793
  $9 = NbOp_F64_to_F32 $8
  $10 = NbOp_Div_Float32 $7 $9
  $11 = NpyOp_Sqrt_Float32 $10
  $12 = PyFloat 0.044715
  $13 = NbOp_F64_to_F32 $12
  $14 = PyInt 3
  $15 = NbOp_Pow_Float32_Int64 $0[1] $14
  $16 = NbOp_Mul_Float32 $13 $15
  $17 = NbOp_Add_Float32 $0[1] $16
  $18 = NbOp_Mul_Float32 $11 $17
  $19 = NpyOp_Tanh_Float32 $18
  $20 = NbOp_Add_Float32 $5 $19
  $21 = NbOp_Mul_Float32 $3 $20
} [1693] -> !io=$0[0] !ret=$21
Extracted cost ▶
34139.0
[metadata] ▶
time elapsed 11.96ms
timing breakdown:
  11.95ms: Extracted RVSDG     
  0.01ms: Extracted cost      
5. Backend (2.73ms) ▶
Backend
Lowered module ▶
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %c2_i64 = arith.constant 2 : i64
    %cst_0 = arith.constant 3.1415926535897931 : f64
    %cst_1 = arith.constant 4.471500e-02 : f64
    %c3_i64 = arith.constant 3 : i64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.sitofp %c2_i64 : i64 to f32
    %4 = arith.truncf %cst_0 : f64 to f32
    %5 = arith.divf %3, %4 : f32
    %6 = math.sqrt %5 : f32
    %7 = arith.truncf %cst_1 : f64 to f32
    %8 = arith.sitofp %c3_i64 : i64 to f32
    %9 = math.powf %arg0, %8 : f32
    %10 = arith.mulf %7, %9 : f32
    %11 = arith.addf %arg0, %10 : f32
    %12 = arith.mulf %6, %11 : f32
    %13 = math.tanh %12 : f32
    %14 = arith.addf %2, %13 : f32
    %15 = arith.mulf %1, %14 : f32
    return %15 : f32
  }
}
[metadata] ▶
time elapsed 2.73ms
timing breakdown:
  2.73ms: Lowered module      
6. MLIR passes (1.90ms) ▶
MLIR passes
MLIR optimized ▶
module {
  llvm.func @tanhf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @powf(f32, f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(3.000000e+00 : f32) : f32
    %1 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %2 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %4 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %5 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %6 = llvm.fmul %arg0, %3  : f32
    %7 = llvm.fptrunc %4 : f64 to f32
    %8 = llvm.fdiv %1, %7  : f32
    %9 = llvm.call @sqrtf(%8) : (f32) -> f32
    %10 = llvm.fptrunc %5 : f64 to f32
    %11 = llvm.call @powf(%arg0, %0) : (f32, f32) -> f32
    %12 = llvm.fmul %10, %11  : f32
    %13 = llvm.fadd %arg0, %12  : f32
    %14 = llvm.fmul %9, %13  : f32
    %15 = llvm.call @tanhf(%14) : (f32) -> f32
    %16 = llvm.fadd %15, %2  : f32
    %17 = llvm.fmul %6, %16  : f32
    llvm.return %17 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
}
[metadata] ▶
time elapsed 1.90ms
timing breakdown:
  1.90ms: MLIR optimized      

Testing report

1. Args ▶
(0.234,)
2. JIT output ▶
0.13864579796791077
3. Expected output ▶
0.1386458

Add Rules to Optimize¶

Add rewrite rules for the Pade44 approximation of tanh(x) and for expanding power operations into multiplication chains. These rules enable the optimizer to replace expensive transcendental functions with efficient arithmetic.

Define Pade44 approximation rewrite rule for tanh(x)¶

The Pade44 approximation replaces tanh(x) with a rational function:

$$ \tanh(x) \approx \frac{10x^3 + 105x}{x^4 + 45x^2 + 105} $$

This approximation provides good accuracy for small to moderate values of x while avoiding the computational cost of the transcendental tanh function. The rational form needs only additions, multiplications, one division, and integer powers of x; the `pow_expansion` ruleset defined in the next cell turns those powers into multiplications.
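To get a feel for the accuracy claim, the rational function can be evaluated directly in NumPy and compared against `np.tanh`. This is a minimal sketch outside the compiler pipeline; the helper name `pade44_tanh` and the sample points are assumptions made for illustration and are not part of the notebook.

```python
import numpy as np


def pade44_tanh(x):
    """Rational approximation of tanh(x) from the formula above (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float64)
    x2 = x * x
    numerator = x * (10.0 * x2 + 105.0)      # 10x^3 + 105x
    denominator = x2 * (x2 + 45.0) + 105.0   # x^4 + 45x^2 + 105
    return numerator / denominator


if __name__ == "__main__":
    # Sample points chosen for illustration only.
    for x in (0.1, 0.5, 1.0, 2.0, 4.0):
        approx = float(pade44_tanh(x))
        exact = float(np.tanh(x))
        print(f"x={x:4.1f}  pade={approx:.6f}  tanh={exact:.6f}  abs_err={abs(approx - exact):.2e}")
```

The error grows with |x| because the denominator has a higher degree than the numerator: the rational function decays toward 0 for large |x|, whereas tanh saturates at ±1.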

In [20]:
@ruleset
def pade44_tanh_expansion(x: Term):
    flt = lambda f: Npy_float32(Term.LiteralF64(float(f)))
    liti64 = Term.LiteralI64
    pow = Nb_Pow_Float32_Int64
    mul = Nb_Mul_Float32
    add = Nb_Add_Float32
    div = Nb_Div_Float32
    # Rewrite tanh(x) to the pade44 approximation.
    # Note the pow() calls: they are expanded into multiplications
    # by the `pow_expansion` ruleset.
    yield rewrite(Npy_tanh_float32(x)).to(
        div(
            add(mul(flt(10), pow(x, liti64(3))), mul(flt(105), x)),
            add(
                add(pow(x, liti64(4)), mul(flt(45), pow(x, liti64(2)))),
                flt(105),
            ),
        )
    )

Define the expansion of power operations (e.g., x^N) into a sequence of multiplications. This converts expensive power operations into cheaper multiplication chains. For example, x^3 becomes x * x * x (with a trailing multiplication by 1.0 that comes from the base case), and x^0 becomes 1.0. Note: the exponent N must be a compile-time constant for this optimization to apply.

In [21]:
@ruleset
def pow_expansion(x: Term, ival: i64):
    # Rules to expand pow(x, i) to multiplications
    powf = Nb_Pow_Float32_Int64
    lit64 = Term.LiteralI64
    mulf = Nb_Mul_Float32
    yield rewrite(powf(x, lit64(ival))).to(
        mulf(x, powf(x, lit64(ival - 1))),
        ival >= 1,
    )

    yield rewrite(powf(x, lit64(i64(0))), subsume=True).to(
        Npy_float32(Term.LiteralF64(float(1))),
    )
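The two rules in `pow_expansion` amount to a recursion on the constant exponent. The following plain-Python sketch mirrors that recursion on a toy expression format; the tuple encoding and the names `expand_pow` and `evaluate` are assumptions for illustration and do not use the egglog `Term` type.

```python
def expand_pow(base, exponent):
    """Expand base**exponent into nested multiplications (toy model of the rewrite).

    Expressions are plain tuples such as ("mul", "x", 1.0); this encoding is
    hypothetical and only stands in for the e-graph terms.
    """
    if exponent < 0:
        raise ValueError("only non-negative constant exponents are expanded")
    if exponent == 0:
        return 1.0  # base case: pow(x, 0) -> 1.0
    # recursive case: pow(x, n) -> mul(x, pow(x, n - 1)) for n >= 1
    return ("mul", base, expand_pow(base, exponent - 1))


def evaluate(expr, env):
    """Evaluate a toy expression given variable bindings in env."""
    if isinstance(expr, tuple):
        _, lhs, rhs = expr
        return evaluate(lhs, env) * evaluate(rhs, env)
    if isinstance(expr, str):
        return env[expr]
    return expr


if __name__ == "__main__":
    chain = expand_pow("x", 3)
    print(chain)                        # ('mul', 'x', ('mul', 'x', ('mul', 'x', 1.0)))
    print(evaluate(chain, {"x": 2.0}))  # 8.0
```

The innermost multiplication by 1.0 is the same one that shows up in the extracted RVSDG further down, where `a` is first multiplied by the constant 1.0; the optimized MLIR no longer contains it and starts the chain with `llvm.fmul %arg0, %arg0`.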

Combine the two rulesets with the `|` operator; rulesets are composable.

In [22]:
optimize_rules = pade44_tanh_expansion | pow_expansion

Run the Optimized Function¶

Compile and run the optimized GELU function using the new rules, and report the results.

In [23]:
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    jit_func = jit_compiler(
        fn=gelu_tanh_forward,
        argtypes=(Float32,),
        ruleset=(
            base_ruleset
            | setup_argtypes(TypeFloat32)
            | additional_rules
            | optimize_rules
        ),
        pipeline_report=report,
        **compiler_config,
    ).jit_func
    report.display()

Pipeline execution report

1. Frontend (9.57ms) ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def gelu_tanh_forward(a):
   2|    dt = np.float32
   3|    result = (
   4|        dt(0.5)
   5|        * a
   6|        * (
   7|            dt(1)
   8|            + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
   9|        )
  10|    )
  11|    return result
----------------------------------inter source----------------------------------
   1|def transformed_gelu_tanh_forward(a):
   2|    """#file: /tmp/ipykernel_3859/2556971375.py"""
   3|    '#loc: 2:8-2:23'
   4|    dt = np.float32
   5|    '#loc: 3:8-10:9'
   6|    result = dt(0.5) * a * (dt(1) + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a ** 3)))
   7|    '#loc: 11:8-11:21'
   8|    return result
RVSDG ▶
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[804] <- !io a
{
  $1 = PyLoadGlobal $0[0] 'np'
  $2 = PyAttr $0[0] $1 'float32'
  $3 = DbgValue 'dt' $2[1]
  $4 = PyFloat 0.5
  $5 = PyCall $3 $2[0] $4
  $6 = PyBinOp * $5[0] $5[1], $0[1]
  $7 = PyInt 1
  $8 = PyCall $3 $6[0] $7
  $9 = PyInt 2
  $10 = PyCall $3 $8[0] $9
  $11 = PyLoadGlobal $10[0] 'np'
  $12 = PyAttr $10[0] $11 'pi'
  $13 = PyCall $3 $12[0] $12[1]
  $14 = PyBinOp / $13[0] $10[1], $13[1]
  $15 = PyLoadGlobal $14[0] 'np'
  $16 = PyAttr $14[0] $15 'sqrt'
  $17 = PyCall $16[1] $16[0] $14[1]
  $18 = PyFloat 0.044715
  $19 = PyCall $3 $17[0] $18
  $20 = PyInt 3
  $21 = PyBinOp ** $19[0] $0[1], $20
  $22 = PyBinOp * $21[0] $19[1], $21[1]
  $23 = PyBinOp + $22[0] $0[1], $22[1]
  $24 = PyBinOp * $23[0] $17[1], $23[1]
  $25 = PyLoadGlobal $24[0] 'np'
  $26 = PyAttr $24[0] $25 'tanh'
  $27 = PyCall $26[1] $26[0] $24[1]
  $28 = PyBinOp + $27[0] $8[1], $27[1]
  $29 = PyBinOp * $28[0] $6[1], $28[1]
  $30 = DbgValue 'result' $29[1]
} [1268] -> !io=$29[0] !ret=$30
[metadata] ▶
time elapsed 9.57ms
timing breakdown:
  6.40ms: Debug Info on RVSDG 
  3.17ms: RVSDG               
2. EGraph Conversion (62.73ms) ▶
EGraph Conversion
EGraph ▶
[EGraph visualization: rendered graph of the e-graph built from the RVSDG above; node labels include Py_Call, Py_MulIO, Py_AddIO, Py_DivIO, Py_PowIO, Py_AttrIO, and Py_LoadGlobal.]
[metadata] ▶
time elapsed 62.73ms
timing breakdown:
  62.73ms: EGraph              
3. Egraph Saturation (0.00ms) ▶
Egraph Saturation
[metadata] ▶
time elapsed 0.00ms
timing breakdown:
4. EGraph Extraction (17.75ms) ▶
EGraph Extraction
Extracted RVSDG ▶
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[1575] <- !io a; #attrs (_, Float32)->(_, Float32)
{
  $1 = PyFloat 0.5
  $2 = NbOp_F64_to_F32 $1
  $3 = NbOp_Mul_Float32 $2 $0[1]
  $4 = PyInt 1
  $5 = NbOp_I64_to_F32 $4
  $6 = PyFloat 10.0
  $7 = NbOp_F64_to_F32 $6
  $8 = PyInt 2
  $9 = NbOp_I64_to_F32 $8
  $10 = PyFloat 3.141592653589793
  $11 = NbOp_F64_to_F32 $10
  $12 = NbOp_Div_Float32 $9 $11
  $13 = NpyOp_Sqrt_Float32 $12
  $14 = PyFloat 0.044715
  $15 = NbOp_F64_to_F32 $14
  $16 = PyFloat 1.0
  $17 = NbOp_F64_to_F32 $16
  $18 = NbOp_Mul_Float32 $0[1] $17
  $19 = NbOp_Mul_Float32 $0[1] $18
  $20 = NbOp_Mul_Float32 $0[1] $19
  $21 = NbOp_Mul_Float32 $15 $20
  $22 = NbOp_Add_Float32 $0[1] $21
  $23 = NbOp_Mul_Float32 $13 $22
  $24 = NbOp_Mul_Float32 $23 $17
  $25 = NbOp_Mul_Float32 $23 $24
  $26 = NbOp_Mul_Float32 $23 $25
  $27 = NbOp_Mul_Float32 $7 $26
  $28 = PyFloat 105.0
  $29 = NbOp_F64_to_F32 $28
  $30 = NbOp_Mul_Float32 $29 $23
  $31 = NbOp_Add_Float32 $27 $30
  $32 = NbOp_Mul_Float32 $23 $26
  $33 = PyFloat 45.0
  $34 = NbOp_F64_to_F32 $33
  $35 = NbOp_Mul_Float32 $34 $25
  $36 = NbOp_Add_Float32 $32 $35
  $37 = NbOp_Add_Float32 $36 $29
  $38 = NbOp_Div_Float32 $31 $37
  $39 = NbOp_Add_Float32 $5 $38
  $40 = NbOp_Mul_Float32 $3 $39
} [1782] -> !io=$0[0] !ret=$40
Extracted cost ▶
14747.0
[metadata] ▶
time elapsed 17.75ms
timing breakdown:
  17.74ms: Extracted RVSDG     
  0.01ms: Extracted cost      
5. Backend (2.13ms) ▶
Backend
Lowered module ▶
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant 1.000000e+01 : f64
    %c2_i64 = arith.constant 2 : i64
    %cst_1 = arith.constant 3.1415926535897931 : f64
    %cst_2 = arith.constant 4.471500e-02 : f64
    %cst_3 = arith.constant 1.000000e+00 : f64
    %cst_4 = arith.constant 1.050000e+02 : f64
    %cst_5 = arith.constant 4.500000e+01 : f64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.truncf %cst_0 : f64 to f32
    %4 = arith.sitofp %c2_i64 : i64 to f32
    %5 = arith.truncf %cst_1 : f64 to f32
    %6 = arith.divf %4, %5 : f32
    %7 = math.sqrt %6 : f32
    %8 = arith.truncf %cst_2 : f64 to f32
    %9 = arith.truncf %cst_3 : f64 to f32
    %10 = arith.mulf %arg0, %9 : f32
    %11 = arith.mulf %arg0, %10 : f32
    %12 = arith.mulf %arg0, %11 : f32
    %13 = arith.mulf %8, %12 : f32
    %14 = arith.addf %arg0, %13 : f32
    %15 = arith.mulf %7, %14 : f32
    %16 = arith.mulf %15, %9 : f32
    %17 = arith.mulf %15, %16 : f32
    %18 = arith.mulf %15, %17 : f32
    %19 = arith.mulf %3, %18 : f32
    %20 = arith.truncf %cst_4 : f64 to f32
    %21 = arith.mulf %20, %15 : f32
    %22 = arith.addf %19, %21 : f32
    %23 = arith.mulf %15, %18 : f32
    %24 = arith.truncf %cst_5 : f64 to f32
    %25 = arith.mulf %24, %17 : f32
    %26 = arith.addf %23, %25 : f32
    %27 = arith.addf %26, %20 : f32
    %28 = arith.divf %22, %27 : f32
    %29 = arith.addf %2, %28 : f32
    %30 = arith.mulf %1, %29 : f32
    return %30 : f32
  }
}
[metadata] ▶
time elapsed 2.13ms
timing breakdown:
  2.13ms: Lowered module      
6. MLIR passes (1.69ms) ▶
MLIR passes
MLIR optimized ▶
module {
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %8 = llvm.fmul %arg0, %5  : f32
    %9 = llvm.fptrunc %6 : f64 to f32
    %10 = llvm.fdiv %2, %9  : f32
    %11 = llvm.call @sqrtf(%10) : (f32) -> f32
    %12 = llvm.fptrunc %7 : f64 to f32
    %13 = llvm.fmul %arg0, %arg0  : f32
    %14 = llvm.fmul %arg0, %13  : f32
    %15 = llvm.fmul %12, %14  : f32
    %16 = llvm.fadd %arg0, %15  : f32
    %17 = llvm.fmul %11, %16  : f32
    %18 = llvm.fmul %17, %17  : f32
    %19 = llvm.fmul %17, %18  : f32
    %20 = llvm.fmul %19, %3  : f32
    %21 = llvm.fmul %17, %1  : f32
    %22 = llvm.fadd %20, %21  : f32
    %23 = llvm.fmul %17, %19  : f32
    %24 = llvm.fmul %18, %0  : f32
    %25 = llvm.fadd %23, %24  : f32
    %26 = llvm.fadd %25, %1  : f32
    %27 = llvm.fdiv %22, %26  : f32
    %28 = llvm.fadd %27, %4  : f32
    %29 = llvm.fmul %8, %28  : f32
    llvm.return %29 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
}
[metadata] ▶
time elapsed 1.69ms
timing breakdown:
  1.69ms: MLIR optimized      

Compare the Result¶

Compare the output of the original and optimized GELU functions. The check uses a relative tolerance of 1e-6 to allow for the small difference introduced by the approximation and by float32 rounding.

In [24]:
if __name__ == "__main__":
    relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
    run_test(
        gelu_tanh_forward, jit_func, (0.234,), equal=relclose, verbose=True
    )

Testing report

1. Args ▶
(0.234,)
2. JIT output ▶
0.13864579796791077
3. Expected output ▶
0.1386458
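The same comparison can be reproduced without the compiler by transcribing the optimized arithmetic into NumPy. This is a hedged sketch: `gelu_pade_numpy` is a hypothetical hand-written transcription that follows the operation order of the optimized MLIR shown above, not code generated by the pipeline, and `gelu_tanh_reference` restates the notebook's formula with an explicit float32 conversion of the input.

```python
import numpy as np


def gelu_tanh_reference(a):
    """GELU (tanh form) in float32, same formula as gelu_tanh_forward."""
    dt = np.float32
    a = np.asarray(a, dtype=dt)
    return dt(0.5) * a * (
        dt(1) + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
    )


def gelu_pade_numpy(a):
    """Hypothetical NumPy transcription of the optimized arithmetic (float32)."""
    dt = np.float32
    a = np.asarray(a, dtype=dt)
    # Argument of tanh, with a**3 written as a multiplication chain.
    u = np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * (a * (a * a)))
    u2 = u * u
    u3 = u * u2
    numerator = dt(10) * u3 + dt(105) * u              # 10u^3 + 105u
    denominator = (u * u3 + dt(45) * u2) + dt(105)     # u^4 + 45u^2 + 105
    return dt(0.5) * a * (dt(1) + numerator / denominator)


if __name__ == "__main__":
    x = 0.234  # same test argument as in the notebook
    ref = gelu_tanh_reference(x)
    opt = gelu_pade_numpy(x)
    print(float(ref), float(opt))
    print(np.allclose(opt, ref, rtol=1e-6))
```

Because the rational form only tracks tanh closely near the origin, a check like this is most meaningful for inputs of small magnitude.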

Ufunc Version¶

Demonstrate vectorized (ufunc) compilation and execution of the optimized GELU function, and compare results for a batch of inputs.
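As background, a ufunc applies a scalar kernel to every element of an array. The NumPy-only sketch below imitates that calling convention with `np.vectorize`, which loops in Python; it is an illustration under that assumption and says nothing about how `ufunc_vectorize` is implemented. The names `scalar_gelu` and `vectorized_scalar_gelu` are hypothetical.

```python
import math

import numpy as np


def scalar_gelu(x):
    """Scalar GELU kernel (tanh form) operating on a single value."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# np.vectorize wraps the scalar kernel so that it accepts arrays.
# It is a convenience loop, not a compiled ufunc.
vectorized_scalar_gelu = np.vectorize(scalar_gelu, otypes=[np.float32])


if __name__ == "__main__":
    batch = np.random.random(100).astype(np.float32)
    out = vectorized_scalar_gelu(batch)
    print(out.shape, out.dtype)  # (100,) float32
```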

In [25]:
if __name__ == "__main__":
    report = Report("Pipeline execution report", enable_nested_metadata=True)
    vectorized_gelu = ufunc_vectorize(
        input_type=Float32,
        ndim=1,
        compiler_config={**compiler_config, "pipeline_report": report},
        extra_ruleset=additional_rules | optimize_rules,
    )(gelu_tanh_forward)
    report.display()
    relclose = lambda x, y: np.allclose(x, y, rtol=1e-6)
    input_val = np.random.random(100).astype(np.float32)

    run_test(
        gelu_tanh_forward,
        vectorized_gelu,
        (input_val,),
        equal=relclose,
        verbose=True,
    )

Pipeline execution report

1. Frontend (9.11ms) ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def gelu_tanh_forward(a):
   2|    dt = np.float32
   3|    result = (
   4|        dt(0.5)
   5|        * a
   6|        * (
   7|            dt(1)
   8|            + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a**3))
   9|        )
  10|    )
  11|    return result
----------------------------------inter source----------------------------------
   1|def transformed_gelu_tanh_forward(a):
   2|    """#file: /tmp/ipykernel_3859/2556971375.py"""
   3|    '#loc: 2:8-2:23'
   4|    dt = np.float32
   5|    '#loc: 3:8-10:9'
   6|    result = dt(0.5) * a * (dt(1) + np.tanh(np.sqrt(dt(2) / dt(np.pi)) * (a + dt(0.044715) * a ** 3)))
   7|    '#loc: 11:8-11:21'
   8|    return result
RVSDG ▶
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[804] <- !io a
{
  $1 = PyLoadGlobal $0[0] 'np'
  $2 = PyAttr $0[0] $1 'float32'
  $3 = DbgValue 'dt' $2[1]
  $4 = PyFloat 0.5
  $5 = PyCall $3 $2[0] $4
  $6 = PyBinOp * $5[0] $5[1], $0[1]
  $7 = PyInt 1
  $8 = PyCall $3 $6[0] $7
  $9 = PyInt 2
  $10 = PyCall $3 $8[0] $9
  $11 = PyLoadGlobal $10[0] 'np'
  $12 = PyAttr $10[0] $11 'pi'
  $13 = PyCall $3 $12[0] $12[1]
  $14 = PyBinOp / $13[0] $10[1], $13[1]
  $15 = PyLoadGlobal $14[0] 'np'
  $16 = PyAttr $14[0] $15 'sqrt'
  $17 = PyCall $16[1] $16[0] $14[1]
  $18 = PyFloat 0.044715
  $19 = PyCall $3 $17[0] $18
  $20 = PyInt 3
  $21 = PyBinOp ** $19[0] $0[1], $20
  $22 = PyBinOp * $21[0] $19[1], $21[1]
  $23 = PyBinOp + $22[0] $0[1], $22[1]
  $24 = PyBinOp * $23[0] $17[1], $23[1]
  $25 = PyLoadGlobal $24[0] 'np'
  $26 = PyAttr $24[0] $25 'tanh'
  $27 = PyCall $26[1] $26[0] $24[1]
  $28 = PyBinOp + $27[0] $8[1], $27[1]
  $29 = PyBinOp * $28[0] $6[1], $28[1]
  $30 = DbgValue 'result' $29[1]
} [1268] -> !io=$29[0] !ret=$30
[metadata] ▶
time elapsed 9.11ms
timing breakdown:
  6.13ms: Debug Info on RVSDG 
  2.98ms: RVSDG               
2. EGraph Conversion (62.41ms) ▶
EGraph Conversion
EGraph ▶
[EGraph visualization: rendered graph of the e-graph built from the RVSDG above; node labels include Py_Call, Py_MulIO, Py_AddIO, Py_DivIO, Py_PowIO, Py_AttrIO, and Py_LoadGlobal.]
function-36-Term_getPort ·.getPort(·, 0) function-1-Port___init__ Port("!ret", ·) function-1-Term_DbgValue Term.DbgValue("result", ·) function-0-PortList___init__ PortList primitive-Vec_Port-0 Vec function-0-Region___init__ Region("804", ·) function-5-Term_getPort ·.getPort(·, 0) function-1-Py_Call Py_Call function-1-Py_AddIO Py_AddIO function-30-Term_getPort ·.getPort(·, 0) function-31-Term_getPort ·.getPort(·, 1) function-32-Term_getPort ·.getPort(·, 1) function-26-Term_getPort ·.getPort(·, 0) function-2-Py_MulIO Py_MulIO function-3-Term_getPort ·.getPort(·, 1) function-0-Py_Call Py_Call function-14-Term_getPort ·.getPort(·, 0) function-2-Py_AttrIO Py_AttrIO(·, ·, "sqrt") function-0-Region_get ·.get(·, 0) function-2-Py_Call Py_Call function-0-Term_DbgValue Term.DbgValue("dt", ·) function-2-TermList___init__ TermList function-3-Py_AttrIO Py_AttrIO(·, ·, "tanh") function-3-Py_LoadGlobal Py_LoadGlobal(·, "np") function-22-Term_getPort ·.getPort(·, 1) function-1-Py_MulIO Py_MulIO function-6-Py_Call Py_Call function-12-Term_getPort ·.getPort(·, 0) function-0-Py_DivIO Py_DivIO function-33-Term_getPort ·.getPort(·, 0) function-18-Term_getPort ·.getPort(·, 0) function-0-Py_PowIO Py_PowIO function-8-Term_getPort ·.getPort(·, 1) function-1-Py_AttrIO Py_AttrIO(·, ·, "pi") function-0-GraphRoot GraphRoot function-0-Term_Func Term.Func("1274", "transformed_gelu_tanh_forward", ·) function-15-Term_getPort ·.getPort(·, 1) function-11-Term_getPort ·.getPort(·, 1) function-3-Py_Call Py_Call function-10-Term_getPort ·.getPort(·, 1) function-19-Term_getPort ·.getPort(·, 1) function-5-Py_Call Py_Call function-0-Term_RegionEnd Term.RegionEnd function-9-Term_getPort ·.getPort(·, 0) function-28-Term_getPort ·.getPort(·, 0) function-37-Term_getPort ·.getPort(·, 1) function-3-Py_MulIO Py_MulIO function-2-Py_LoadGlobal Py_LoadGlobal(·, "np") function-0-Term_LiteralI64 Term.LiteralI64(1) function-7-Term_getPort ·.getPort(·, 0) function-0-Py_MulIO Py_MulIO function-2-Term_getPort ·.getPort(·, 
0) function-1-Region_get ·.get(·, 1) function-1-Term_LiteralF64 Term.LiteralF64(0.044715) function-1-Term_getPort ·.getPort(·, 0) function-0-Py_AttrIO Py_AttrIO(·, ·, "float32") function-0-Py_LoadGlobal Py_LoadGlobal(·, "np") function-16-Term_getPort ·.getPort(·, 0) function-4-Py_Call Py_Call function-1-Py_LoadGlobal Py_LoadGlobal(·, "np") function-6-Term_getPort ·.getPort(·, 0) function-24-Term_getPort ·.getPort(·, 1) function-27-Term_getPort ·.getPort(·, 1) function-29-Term_getPort ·.getPort(·, 1) function-3-TermList___init__ TermList function-0-Term_getPort ·.getPort(·, 1) function-23-Term_getPort ·.getPort(·, 0) function-25-Term_getPort ·.getPort(·, 1) function-1-Term_LiteralI64 Term.LiteralI64(2) function-20-Term_getPort ·.getPort(·, 1) function-6-TermList___init__ TermList function-0-Py_AddIO Py_AddIO function-21-Term_getPort ·.getPort(·, 0) function-34-Term_getPort ·.getPort(·, 1) function-35-Term_getPort ·.getPort(·, 1) function-17-Term_getPort ·.getPort(·, 0) function-2-Term_LiteralI64 Term.LiteralI64(3) function-0-TermList___init__ TermList function-5-TermList___init__ TermList function-13-Term_getPort ·.getPort(·, 1) function-0-Term_LiteralF64 Term.LiteralF64(0.5) function-4-Term_getPort ·.getPort(·, 0) function-1-TermList___init__ TermList function-4-TermList___init__ TermList primitive-Vec_Term-4 Vec primitive-Vec_Term-0 Vec primitive-Vec_Term-5 Vec primitive-Vec_Term-6 Vec primitive-Vec_Term-1 Vec primitive-Vec_Term-2 Vec primitive-Vec_Term-3 Vec
[metadata]
time elapsed 62.41ms
timing breakdown:
  62.41ms: EGraph              
3. Egraph Saturation (0.00ms)
[metadata]
time elapsed 0.00ms
timing breakdown:
4. EGraph Extraction (16.56ms)
Extracted RVSDG
transformed_gelu_tanh_forward = Func (Args (ArgSpec 'a' (PyNone)))
$0 = Region[1575] <- !io a; #attrs (_, Float32)->(_, Float32)
{
  $1 = PyFloat 0.5
  $2 = NbOp_F64_to_F32 $1
  $3 = NbOp_Mul_Float32 $2 $0[1]
  $4 = PyInt 1
  $5 = NbOp_I64_to_F32 $4
  $6 = PyFloat 10.0
  $7 = NbOp_F64_to_F32 $6
  $8 = PyInt 2
  $9 = NbOp_I64_to_F32 $8
  $10 = PyFloat 3.141592653589793
  $11 = NbOp_F64_to_F32 $10
  $12 = NbOp_Div_Float32 $9 $11
  $13 = NpyOp_Sqrt_Float32 $12
  $14 = PyFloat 0.044715
  $15 = NbOp_F64_to_F32 $14
  $16 = PyFloat 1.0
  $17 = NbOp_F64_to_F32 $16
  $18 = NbOp_Mul_Float32 $0[1] $17
  $19 = NbOp_Mul_Float32 $0[1] $18
  $20 = NbOp_Mul_Float32 $0[1] $19
  $21 = NbOp_Mul_Float32 $15 $20
  $22 = NbOp_Add_Float32 $0[1] $21
  $23 = NbOp_Mul_Float32 $13 $22
  $24 = NbOp_Mul_Float32 $23 $17
  $25 = NbOp_Mul_Float32 $23 $24
  $26 = NbOp_Mul_Float32 $23 $25
  $27 = NbOp_Mul_Float32 $7 $26
  $28 = PyFloat 105.0
  $29 = NbOp_F64_to_F32 $28
  $30 = NbOp_Mul_Float32 $29 $23
  $31 = NbOp_Add_Float32 $27 $30
  $32 = NbOp_Mul_Float32 $23 $26
  $33 = PyFloat 45.0
  $34 = NbOp_F64_to_F32 $33
  $35 = NbOp_Mul_Float32 $34 $25
  $36 = NbOp_Add_Float32 $32 $35
  $37 = NbOp_Add_Float32 $36 $29
  $38 = NbOp_Div_Float32 $31 $37
  $39 = NbOp_Add_Float32 $5 $38
  $40 = NbOp_Mul_Float32 $3 $39
} [1782] -> !io=$0[0] !ret=$40
Extracted cost
14747.0
[metadata]
time elapsed 16.56ms
timing breakdown:
  16.55ms: Extracted RVSDG     
  0.01ms: Extracted cost      
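
Reading the extracted RVSDG term by term: with u = sqrt(2/pi) * (a + 0.044715 * a^3) (node $23), nodes $24 through $38 compute (10u^3 + 105u) / (u^4 + 45u^2 + 105), which stands in for tanh(u). The NumPy sketch below restates that dataflow so the approximation can be checked against np.tanh directly. The function names pade44_tanh, gelu_pade44, and gelu_tanh_ref are illustrative and are not part of the notebook's code.

```python
import numpy as np


def pade44_tanh(u):
    """Rational approximation of tanh(u), as read off nodes $24..$38."""
    u2 = u * u
    u3 = u * u2
    u4 = u * u3
    return (10.0 * u3 + 105.0 * u) / (u4 + 45.0 * u2 + 105.0)


def gelu_pade44(a):
    """GELU with tanh replaced by the rational approximation (float32)."""
    a = np.asarray(a, dtype=np.float32)
    k = np.float32(np.sqrt(2.0 / np.pi))            # node $13
    u = k * (a + np.float32(0.044715) * a * a * a)  # nodes $18..$23
    return np.float32(0.5) * a * (np.float32(1.0) + pade44_tanh(u))


def gelu_tanh_ref(a):
    """Reference GELU using np.tanh, for comparison only."""
    a = np.asarray(a, dtype=np.float32)
    k = np.float32(np.sqrt(2.0 / np.pi))
    u = k * (a + np.float32(0.044715) * a * a * a)
    return np.float32(0.5) * a * (np.float32(1.0) + np.tanh(u))


x = np.linspace(0.0, 1.0, 101, dtype=np.float32)
max_err = float(np.max(np.abs(gelu_pade44(x) - gelu_tanh_ref(x))))
print("max abs error on [0, 1]:", max_err)
```

The comparison covers only [0, 1], the range the test inputs are drawn from. A rational approximation of this kind degrades for larger |u|, so the error outside that range should be measured separately before relying on it.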
5. Backend (2.06ms)
Lowered module
module {
  func.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %cst = arith.constant 5.000000e-01 : f64
    %c1_i64 = arith.constant 1 : i64
    %cst_0 = arith.constant 1.000000e+01 : f64
    %c2_i64 = arith.constant 2 : i64
    %cst_1 = arith.constant 3.1415926535897931 : f64
    %cst_2 = arith.constant 4.471500e-02 : f64
    %cst_3 = arith.constant 1.000000e+00 : f64
    %cst_4 = arith.constant 1.050000e+02 : f64
    %cst_5 = arith.constant 4.500000e+01 : f64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.truncf %cst : f64 to f32
    %1 = arith.mulf %0, %arg0 : f32
    %2 = arith.sitofp %c1_i64 : i64 to f32
    %3 = arith.truncf %cst_0 : f64 to f32
    %4 = arith.sitofp %c2_i64 : i64 to f32
    %5 = arith.truncf %cst_1 : f64 to f32
    %6 = arith.divf %4, %5 : f32
    %7 = math.sqrt %6 : f32
    %8 = arith.truncf %cst_2 : f64 to f32
    %9 = arith.truncf %cst_3 : f64 to f32
    %10 = arith.mulf %arg0, %9 : f32
    %11 = arith.mulf %arg0, %10 : f32
    %12 = arith.mulf %arg0, %11 : f32
    %13 = arith.mulf %8, %12 : f32
    %14 = arith.addf %arg0, %13 : f32
    %15 = arith.mulf %7, %14 : f32
    %16 = arith.mulf %15, %9 : f32
    %17 = arith.mulf %15, %16 : f32
    %18 = arith.mulf %15, %17 : f32
    %19 = arith.mulf %3, %18 : f32
    %20 = arith.truncf %cst_4 : f64 to f32
    %21 = arith.mulf %20, %15 : f32
    %22 = arith.addf %19, %21 : f32
    %23 = arith.mulf %15, %18 : f32
    %24 = arith.truncf %cst_5 : f64 to f32
    %25 = arith.mulf %24, %17 : f32
    %26 = arith.addf %23, %25 : f32
    %27 = arith.addf %26, %20 : f32
    %28 = arith.divf %22, %27 : f32
    %29 = arith.addf %2, %28 : f32
    %30 = arith.mulf %1, %29 : f32
    return %30 : f32
  }
}
[metadata]
time elapsed 2.06ms
timing breakdown:
  2.06ms: Lowered module      
6. MLIR passes (3.10ms)
MLIR optimized
module {
  llvm.func @sqrtf(f32) -> f32 attributes {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, sym_visibility = "private"}
  llvm.func @func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(4.500000e+01 : f32) : f32
    %1 = llvm.mlir.constant(1.050000e+02 : f32) : f32
    %2 = llvm.mlir.constant(2.000000e+00 : f32) : f32
    %3 = llvm.mlir.constant(1.000000e+01 : f32) : f32
    %4 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %5 = llvm.mlir.constant(5.000000e-01 : f32) : f32
    %6 = llvm.mlir.constant(3.1415926535897931 : f64) : f64
    %7 = llvm.mlir.constant(4.471500e-02 : f64) : f64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %8 = llvm.fmul %arg0, %5  : f32
    %9 = llvm.fptrunc %6 : f64 to f32
    %10 = llvm.fdiv %2, %9  : f32
    %11 = llvm.call @sqrtf(%10) : (f32) -> f32
    %12 = llvm.fptrunc %7 : f64 to f32
    %13 = llvm.fmul %arg0, %arg0  : f32
    %14 = llvm.fmul %arg0, %13  : f32
    %15 = llvm.fmul %12, %14  : f32
    %16 = llvm.fadd %arg0, %15  : f32
    %17 = llvm.fmul %11, %16  : f32
    %18 = llvm.fmul %17, %17  : f32
    %19 = llvm.fmul %17, %18  : f32
    %20 = llvm.fmul %19, %3  : f32
    %21 = llvm.fmul %17, %1  : f32
    %22 = llvm.fadd %20, %21  : f32
    %23 = llvm.fmul %17, %19  : f32
    %24 = llvm.fmul %18, %0  : f32
    %25 = llvm.fadd %23, %24  : f32
    %26 = llvm.fadd %25, %1  : f32
    %27 = llvm.fdiv %22, %26  : f32
    %28 = llvm.fadd %27, %4  : f32
    %29 = llvm.fmul %8, %28  : f32
    llvm.return %29 : f32
  }
  llvm.func @_mlir_ciface_func(%arg0: f32) -> f32 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0) : (f32) -> f32
    llvm.return %0 : f32
  }
  llvm.func @ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.insertvalue %arg5, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %2 = llvm.insertvalue %arg6, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %3 = llvm.insertvalue %arg7, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %4 = llvm.insertvalue %arg8, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %5 = llvm.insertvalue %arg9, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %11 = llvm.insertvalue %arg4, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %12 = llvm.mlir.constant(1 : index) : i64
    %13 = llvm.mlir.constant(0 : index) : i64
    %14 = llvm.extractvalue %11[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    llvm.br ^bb1(%13 : i64)
  ^bb1(%15: i64):  // 2 preds: ^bb0, ^bb2
    %16 = llvm.icmp "slt" %15, %14 : i64
    llvm.cond_br %16, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %17 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %18 = llvm.getelementptr %17[%15] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    %19 = llvm.load %18 : !llvm.ptr -> f32
    %20 = llvm.call @func(%19) : (f32) -> f32
    %21 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %22 = llvm.getelementptr %21[%15] : (!llvm.ptr, i64) -> !llvm.ptr, f32
    llvm.store %20, %22 : f32, !llvm.ptr
    %23 = llvm.add %15, %12 : i64
    llvm.br ^bb1(%23 : i64)
  ^bb3:  // pred: ^bb1
    llvm.return
  }
  llvm.func @_mlir_ciface_ufunc(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
    %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    llvm.call @ufunc(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
    llvm.return
  }
}
[metadata]
time elapsed 3.10ms
timing breakdown:
  3.10ms: MLIR optimized      
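
The generated @ufunc wrapper in the optimized module is a plain element loop: it reads the extent from the first memref descriptor, loads each input element, calls the scalar @func, and stores the result through the second descriptor. A minimal Python sketch of the same control flow follows. The name ufunc_loop is hypothetical, and the sketch assumes contiguous 1-D arrays, which is what the direct getelementptr indexing in the MLIR appears to assume as well.

```python
import numpy as np


def ufunc_loop(scalar_func, src, dst):
    """Apply scalar_func to each element of 1-D src, writing into dst."""
    n = src.shape[0]                  # extent taken from the input descriptor
    i = 0
    while i < n:                      # ^bb1: index < extent
        dst[i] = scalar_func(src[i])  # ^bb2: load, call @func, store
        i += 1                        # index += 1, branch back to ^bb1
    return dst


src = np.arange(4, dtype=np.float32)
dst = np.zeros_like(src)
ufunc_loop(lambda v: v * np.float32(2.0), src, dst)
print(dst)
```

Writing into a caller-supplied buffer mirrors how the benchmark passes out=out, so no allocation happens inside the loop.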

Testing report

1. Args
(array([0.53754056, 0.43233588, 0.10236798, 0.46508574, 0.58549654,
       0.8923938 , 0.2757588 , 0.08432309, 0.8884488 , 0.8587762 ,
       0.6457952 , 0.11737864, 0.9524025 , 0.9533351 , 0.3317381 ,
       0.1680605 , 0.0680922 , 0.17509076, 0.07749847, 0.972892  ,
       0.43774995, 0.9594856 , 0.3802283 , 0.4063943 , 0.6109139 ,
       0.82032937, 0.96378887, 0.21142137, 0.8676044 , 0.8083702 ,
       0.53161585, 0.36552882, 0.2643556 , 0.00743109, 0.21400109,
       0.45045698, 0.87366974, 0.7190363 , 0.54310775, 0.18760236,
       0.72153026, 0.61276686, 0.9093058 , 0.91447484, 0.5454792 ,
       0.5947745 , 0.9831064 , 0.7109356 , 0.15139958, 0.9206733 ,
       0.8234191 , 0.41100514, 0.445924  , 0.15229666, 0.13950907,
       0.19849318, 0.02099547, 0.2286102 , 0.02923151, 0.2675173 ,
       0.8535502 , 0.43683502, 0.48123544, 0.91435367, 0.8591862 ,
       0.60497624, 0.6497597 , 0.47559726, 0.745626  , 0.068115  ,
       0.93472457, 0.32189476, 0.51888883, 0.08371709, 0.08502585,
       0.24354592, 0.34441778, 0.91207695, 0.6931497 , 0.3731865 ,
       0.9115862 , 0.10682532, 0.09896134, 0.53871924, 0.1919585 ,
       0.24836382, 0.7622515 , 0.9566354 , 0.53254044, 0.7762569 ,
       0.8494233 , 0.5688198 , 0.61417764, 0.8689236 , 0.5286076 ,
       0.07920949, 0.7584067 , 0.72011465, 0.7771502 , 0.89118475],
      dtype=float32),)
2. JIT output
[0.3787034  0.28846663 0.05535727 0.31581023 0.42205015 0.72621244
 0.16783419 0.0449948  0.7220623  0.69101226 0.47835886 0.06417319
 0.78993255 0.7909312  0.2089769  0.09524503 0.03589438 0.09971316
 0.04114288 0.8119259  0.2929388  0.7975227  0.24642435 0.26730743
 0.44554824 0.65123934 0.8021408  0.12341041 0.70021915 0.63897943
 0.37344074 0.23489869 0.15973456 0.00373758 0.12513152 0.30351046
 0.70656013 0.5492407  0.38366732 0.10775949 0.55169916 0.44727504
 0.74406016 0.74953294 0.3857873  0.43058652 0.8229314  0.54127514
 0.0848093  0.75610644 0.6544156  0.27103543 0.2997272  0.08536569
 0.07749382 0.11486163 0.01067358 0.13497381 0.0149566  0.16197063
 0.6855748  0.2921817  0.3295473  0.7494046  0.6914393  0.44002742
 0.482128   0.3247327  0.5755953  0.03590702 0.77105045 0.20157808
 0.36220664 0.04465127 0.04539355 0.14520308 0.2186094  0.74699306
 0.5238929  0.2408843  0.7464735  0.05795657 0.05338125 0.3797528
 0.11058928 0.14853863 0.5922312  0.7944669  0.37426063 0.60633624
 0.68128777 0.40682715 0.44859096 0.7015972  0.37077665 0.04210514
 0.5883735  0.5503034  0.60723865 0.72494   ]
3. Expected output
[0.3787034  0.28846663 0.05535727 0.31581023 0.42205015 0.7262126
 0.16783419 0.0449948  0.72206247 0.6910124  0.4783589  0.06417319
 0.78993297 0.7909316  0.2089769  0.09524503 0.03589438 0.09971316
 0.04114288 0.8119263  0.2929388  0.79752314 0.24642432 0.26730743
 0.44554824 0.6512394  0.8021411  0.12341041 0.7002193  0.63897943
 0.37344074 0.23489869 0.15973456 0.00373758 0.12513152 0.30351046
 0.7065603  0.54924077 0.38366732 0.10775949 0.55169916 0.44727504
 0.74406034 0.74953324 0.3857873  0.43058652 0.82293195 0.54127514
 0.0848093  0.75610673 0.65441567 0.27103543 0.2997272  0.08536569
 0.07749382 0.11486163 0.01067358 0.13497381 0.0149566  0.16197063
 0.68557495 0.2921817  0.3295473  0.7494048  0.6914394  0.44002742
 0.48212805 0.3247327  0.5755953  0.03590702 0.77105075 0.20157808
 0.36220664 0.04465127 0.04539355 0.14520308 0.2186094  0.7469933
 0.5238929  0.2408843  0.7464738  0.05795657 0.05338125 0.37975284
 0.11058928 0.14853863 0.5922312  0.7944673  0.37426063 0.60633624
 0.6812878  0.40682715 0.44859102 0.7015974  0.37077665 0.04210514
 0.58837354 0.55030346 0.6072387  0.7249402 ]
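
The JIT output and the expected output agree to about six decimal places but are not bit-identical, which is what one would expect when tanh is replaced by an approximation. The sketch below takes a few entries copied from the two arrays above and compares them with a tolerance. The tolerance of 1e-5 is an assumption chosen for illustration; the tolerance actually used by run_test is not shown in this section.

```python
import numpy as np

# A few entries copied from the report above (same positions in both arrays).
jit_out = np.array(
    [0.3787034, 0.72621244, 0.78993255, 0.8229314], dtype=np.float32
)
expected = np.array(
    [0.3787034, 0.7262126, 0.78993297, 0.82293195], dtype=np.float32
)

# Largest absolute difference, computed in float64 to avoid extra rounding.
max_abs_diff = float(
    np.max(np.abs(jit_out.astype(np.float64) - expected.astype(np.float64)))
)
print("max abs diff:", max_abs_diff)

# Raises AssertionError if any entry differs by more than the tolerance.
np.testing.assert_allclose(jit_out, expected, rtol=0.0, atol=1e-5)
```

An exact equality check would fail on this data, so a tolerance-based comparison is the appropriate test for the approximated kernel.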

Benchmark

In [26]:
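if __name__ == "__main__":
    # 300,000 random float32 samples in [0, 1) and a preallocated output buffer
    input_val = np.random.random(300000).astype(np.float32)
    out = np.zeros_like(input_val)

    # Baseline: the original gelu_tanh_forward function
    print("original")
    %timeit gelu_tanh_forward(input_val)
    # Optimized version, writing results into the preallocated `out`
    print("superoptimized")
    %timeit vectorized_gelu(input_val, out=out)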
original
4.35 ms ± 16.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
superoptimized
260 μs ± 3.46 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
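
The two timings above can be turned into a speedup and a per-element cost with simple arithmetic. The sketch below uses only the mean values printed by %timeit and the array length from the benchmark cell.

```python
# Mean per-loop times reported by %timeit above, in seconds.
original_s = 4.35e-3   # "original": 4.35 ms
optimized_s = 260e-6   # "superoptimized": 260 μs
n_elements = 300_000   # length of input_val

speedup = original_s / optimized_s
ns_per_elem_original = original_s / n_elements * 1e9
ns_per_elem_optimized = optimized_s / n_elements * 1e9

print(f"speedup: {speedup:.1f}x")
print(f"original: {ns_per_elem_original:.2f} ns/element")
print(f"optimized: {ns_per_elem_optimized:.2f} ns/element")
```

This gives a speedup of roughly 16.7x for this run. The figures come from a single machine and a single input distribution, so they are best read as indicative rather than general.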