Chapter 6: MLIR Backend¶
This chapter introduces the MLIR backend, which replaces the LLVM backend from previous chapters. MLIR provides a more flexible and extensible framework for compiler optimization and code generation. We show how to lower RVSDG-IR to MLIR dialects and use MLIR's pass infrastructure for optimization.
The chapter covers:
- How to implement an MLIR backend for RVSDG-IR
- How to use MLIR dialects for different abstraction levels (see the sketch after this list)
- How to apply MLIR passes for optimization
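To make the idea of abstraction levels concrete, the sketch below (a hedged illustration assuming the upstream MLIR Python bindings this chapter imports) parses the same conditional twice: once with the structured scf dialect that the backend emits, and once with the unstructured cf dialect that the pass pipeline later lowers it to.
import mlir.ir as ir

HIGH_LEVEL = """
func.func @select(%a: i64, %b: i64, %p: i1) -> i64 {
  %r = scf.if %p -> (i64) {
    scf.yield %a : i64
  } else {
    scf.yield %b : i64
  }
  return %r : i64
}
"""

LOW_LEVEL = """
func.func @select(%a: i64, %b: i64, %p: i1) -> i64 {
  cf.cond_br %p, ^bb1, ^bb2
^bb1:
  cf.br ^bb3(%a : i64)
^bb2:
  cf.br ^bb3(%b : i64)
^bb3(%r: i64):
  return %r : i64
}
"""

with ir.Context():
    # Both modules verify; they express the same program at different
    # abstraction levels (structured scf vs. branch-based cf).
    for asm in (HIGH_LEVEL, LOW_LEVEL):
        print(ir.Module.parse(asm))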
Imports and Setup¶
from __future__ import annotations
import ctypes
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, TypedDict
import mlir.dialects.arith as arith
import mlir.dialects.cf as cf
import mlir.dialects.func as func
import mlir.dialects.scf as scf
import mlir.execution_engine as execution_engine
import mlir.ir as ir
import mlir.passmanager as passmanager
import mlir.runtime as runtime
import numpy as np
from sealir import ase
from sealir.rvsdg import grammar as rg
from sealir.rvsdg import internal_prefix
from ch03_egraph_program_rewrites import (
run_test,
)
from ch04_1_typeinfer_ifelse import (
Attributes,
)
from ch04_1_typeinfer_ifelse import (
ExtendEGraphToRVSDG as ConditionalExtendGraphtoRVSDG,
)
from ch04_1_typeinfer_ifelse import (
Int64,
MyCostModel,
NbOp_Add_Float64,
NbOp_Add_Int64,
NbOp_CastI64ToF64,
NbOp_Div_Int64,
NbOp_Gt_Int64,
NbOp_Lt_Int64,
NbOp_Sub_Float64,
NbOp_Sub_Int64,
NbOp_Type,
SExpr,
TypeInt64,
)
from ch04_1_typeinfer_ifelse import base_ruleset as if_else_ruleset
from ch04_1_typeinfer_ifelse import jit_compiler as _ch04_1_jit_compiler
from ch04_1_typeinfer_ifelse import (
ruleset_type_infer_float,
setup_argtypes,
)
from ch04_2_typeinfer_loops import (
ExtendEGraphToRVSDG as LoopExtendEGraphToRVSDG,
)
from ch04_2_typeinfer_loops import (
NbOp_Not_Int64,
)
from ch04_2_typeinfer_loops import base_ruleset as loop_ruleset
from utils import IN_NOTEBOOK, Report, display
MLIR Backend Implementation¶
Define the core MLIR backend class that handles type lowering and expression compilation.
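Before reading the class, it helps to see the shape of the MLIR it produces. The hand-written module below is a hedged sketch (parsed with the same mlir bindings imported above) of the two-block function skeleton that the lower method constructs: an entry block reserved for constants that branches into the block holding the actual body, so every constant dominates its uses.
import mlir.ir as ir

with ir.Context():
    print(
        ir.Module.parse(
            """
            func.func @func(%arg0: i64) -> i64 attributes {llvm.emit_c_interface} {
              // block 0: constants only (filled in during lowering)
              cf.br ^bb1
            ^bb1:  // block 1: the actual function body
              return %arg0 : i64
            }
            """
        )
    )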
_DEBUG = False
@dataclass(frozen=True)
class LowerStates(ase.TraverseState):
push: Callable
get_region_args: Callable
function_block: func.FuncOp
constant_block: ir.Block
function_name = "func"
class Backend:
def __init__(self):
self.context = context = ir.Context()
self.f32 = ir.F32Type.get(context=context)
self.f64 = ir.F64Type.get(context=context)
self.i32 = ir.IntegerType.get_signless(32, context=context)
self.i64 = ir.IntegerType.get_signless(64, context=context)
self.boo = ir.IntegerType.get_signless(1, context=context)
def lower_type(self, ty: NbOp_Type):
"""Type Lowering
Convert SealIR types to MLIR types for compilation.
"""
match ty:
case NbOp_Type("Int64"):
return self.i64
case NbOp_Type("Float64"):
return self.f64
case NbOp_Type("Float32"):
return self.f32
raise NotImplementedError(f"unknown type: {ty}")
def lower(self, root: rg.Func, argtypes):
"""Expression Lowering
Lower RVSDG expressions to MLIR operations, handling control flow
and data flow constructs.
"""
context = self.context
self.loc = loc = ir.Location.unknown(context=context)
self.module = module = ir.Module.create(loc=loc)
        # Get an insertion point for the module body so we can add content
        # to the module.
self.module_body = module_body = ir.InsertionPoint(module.body)
# Convert SealIR types to MLIR types.
input_types = tuple([self.lower_type(x) for x in argtypes])
output_types = (
self.lower_type(
Attributes(root.body.begin.attrs).get_return_type(root.body)
),
)
with context, loc, module_body:
            # Construct a function that exposes a C-callable interface
            # (llvm.emit_c_interface).
fun = func.FuncOp(function_name, (input_types, output_types))
fun.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            # Define two blocks within the function: a constant block that
            # holds all constant definitions, and a function block for the
            # actual body. Keeping constants in the entry block guarantees
            # that every constant dominates its uses, even uses inside
            # regions that may not execute.
            const_block = fun.add_entry_block()
            # Append a second, argument-less block to hold the function body.
            fun.body.blocks.append(*[], arg_locs=None)
            func_block = fun.body.blocks[1]
            # Create insertion points for both blocks.
constant_entry = ir.InsertionPoint(const_block)
function_entry = ir.InsertionPoint(func_block)
region_args = []
@contextmanager
def push(arg_values):
region_args.append(tuple(arg_values))
try:
yield
finally:
region_args.pop()
def get_region_args():
return region_args[-1]
with context, loc, function_entry:
memo = ase.traverse(
root,
self.lower_expr,
LowerStates(
push=push,
get_region_args=get_region_args,
function_block=fun,
constant_block=constant_entry,
),
)
        # Use an unconditional branch to jump from the constant block to the
        # function block. Note that this is inserted at the end of the
        # constant block only after the function body has been built and all
        # constants have been initialized.
with context, loc, constant_entry:
cf.br([], fun.body.blocks[1])
return module
def run_passes(self, module):
"""MLIR Pass Pipeline
Apply MLIR passes for optimization and lowering to LLVM IR.
"""
        if _DEBUG:
            module.dump()
            module.context.enable_multithreading(False)
        pass_man = passmanager.PassManager(context=module.context)
        if _DEBUG and not IN_NOTEBOOK:
            # The notebook may hang if IR printing is enabled and MLIR fails.
            pass_man.enable_ir_printing()
pass_man.add("convert-linalg-to-loops")
pass_man.add("convert-scf-to-cf")
pass_man.add("finalize-memref-to-llvm")
pass_man.add("convert-math-to-libm")
pass_man.add("convert-func-to-llvm")
pass_man.add("convert-index-to-llvm")
pass_man.add("reconcile-unrealized-casts")
pass_man.enable_verifier(True)
pass_man.run(module.operation)
# Output LLVM-dialect MLIR
if _DEBUG:
module.dump()
return module
def lower_expr(self, expr: SExpr, state: LowerStates):
"""Expression Lowering Implementation
Implement the core expression lowering logic for various RVSDG
constructs including functions, regions, control flow, and operations.
"""
match expr:
case rg.Func(args=args, body=body):
names = {
argspec.name: state.function_block.arguments[i]
for i, argspec in enumerate(args.arguments)
}
argvalues = []
                for k in body.begin.inports:
                    if k == internal_prefix("io"):
                        # The IO state has no runtime value; model it with a
                        # dummy i32 constant.
                        v = arith.constant(self.i32, 0)
                    else:
                        v = names[k]
                    argvalues.append(v)
with state.push(argvalues):
outs = yield body
portnames = [p.name for p in body.ports]
retval = outs[portnames.index(internal_prefix("ret"))]
func.ReturnOp([retval])
case rg.RegionBegin(inports=ins):
portvalues = []
for i, k in enumerate(ins):
pv = state.get_region_args()[i]
portvalues.append(pv)
return tuple(portvalues)
case rg.RegionEnd(
begin=rg.RegionBegin() as begin,
ports=ports,
):
yield begin
portvalues = []
for p in ports:
pv = yield p.value
portvalues.append(pv)
return tuple(portvalues)
case rg.ArgRef(idx=int(idx), name=str(name)):
return state.function_block.arguments[idx]
case rg.Unpack(val=source, idx=int(idx)):
ports = yield source
return ports[idx]
case rg.DbgValue(value=value):
val = yield value
return val
case rg.PyInt(int(ival)):
with state.constant_block:
const = arith.constant(self.i64, ival)
return const
case rg.PyBool(int(ival)):
with state.constant_block:
const = arith.constant(self.boo, ival)
return const
case rg.PyFloat(float(fval)):
with state.constant_block:
const = arith.constant(self.f64, fval)
return const
            case NbOp_Gt_Int64(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                # predicate 4 == "sgt" (signed greater-than)
                return arith.cmpi(4, lhs, rhs)
case NbOp_Add_Int64(lhs, rhs):
lhs = yield lhs
rhs = yield rhs
return arith.addi(lhs, rhs)
case NbOp_Sub_Int64(lhs, rhs):
lhs = yield lhs
rhs = yield rhs
return arith.subi(lhs, rhs)
case NbOp_Add_Float64(lhs, rhs):
lhs = yield lhs
rhs = yield rhs
return arith.addf(lhs, rhs)
case NbOp_Sub_Float64(lhs, rhs):
lhs = yield lhs
rhs = yield rhs
return arith.subf(lhs, rhs)
            case NbOp_Lt_Int64(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                # predicate 2 == "slt" (signed less-than)
                return arith.cmpi(2, lhs, rhs)
case NbOp_CastI64ToF64(operand):
val = yield operand
return arith.sitofp(self.f64, val)
            case NbOp_Div_Int64(lhs, rhs):
                lhs = yield lhs
                rhs = yield rhs
                # Python's `/` is true division, so lower to float division.
                return arith.divf(
                    arith.sitofp(self.f64, lhs), arith.sitofp(self.f64, rhs)
                )
            # Operations below were introduced in ch04_2 for loop support.
            case NbOp_Not_Int64(operand):
                # Implement logical not as `operand == 0`
                # (predicate 0 == "eq").
                opval = yield operand
                return arith.cmpi(0, opval, arith.constant(self.i64, 0))
case rg.IfElse(
cond=cond, body=body, orelse=orelse, operands=operands
):
condval = yield cond
                # Determine the result types from the region's port
                # attributes.
rettys = Attributes(body.begin.attrs)
result_tys = []
for i in range(0, rettys.num_output_types() + 1):
out_ty = rettys.get_output_type(i)
if out_ty is not None:
match out_ty.name:
case "Int64":
result_tys.append(self.i64)
case "Float64":
result_tys.append(self.f64)
case "Bool":
result_tys.append(self.boo)
                    else:
                        # Untyped ports (e.g. the IO state) are modelled as a
                        # placeholder i32.
                        result_tys.append(self.i32)
if_op = scf.IfOp(
cond=condval, results_=result_tys, hasElse=bool(orelse)
)
                with ir.InsertionPoint(if_op.then_block):
                    value_then = yield body
                    scf.YieldOp([x for x in value_then])
                with ir.InsertionPoint(if_op.else_block):
                    value_else = yield orelse
                    scf.YieldOp([x for x in value_else])
return if_op.results
case rg.Loop(body=rg.RegionEnd() as body, operands=operands):
rettys = Attributes(body.begin.attrs)
# process operands
ops = []
for op in operands:
ops.append((yield op))
result_tys = []
for i in range(1, rettys.num_output_types() + 1):
out_ty = rettys.get_output_type(i)
if out_ty is not None:
match out_ty.name:
case "Int64":
result_tys.append(self.i64)
case "Float64":
result_tys.append(self.f64)
case "Bool":
result_tys.append(self.boo)
                    else:
                        # Untyped ports (e.g. the IO state) are modelled as a
                        # placeholder i32.
                        result_tys.append(self.i32)
while_op = scf.WhileOp(
results_=result_tys, inits=[op for op in ops]
)
before_block = while_op.before.blocks.append(*result_tys)
after_block = while_op.after.blocks.append(*result_tys)
new_ops = before_block.arguments
# Before Region
with ir.InsertionPoint(before_block), state.push(new_ops):
values = yield body
scf.ConditionOp(
args=[val for val in values[1:]], condition=values[0]
)
                # After Region: simply forwards the loop-carried values back
                # to the before region.
                with ir.InsertionPoint(after_block):
                    scf.YieldOp(after_block.arguments)
while_op_res = scf._get_op_results_or_values(while_op)
return while_op_res
case _:
raise NotImplementedError(expr, type(expr))
def jit_compile(self, llmod, func_node: rg.Func, func_name="func"):
"""JIT Compilation
Convert the MLIR module into a JIT-callable function using the MLIR
execution engine.
"""
attributes = Attributes(func_node.body.begin.attrs)
# Convert SealIR types into MLIR types
with self.loc:
input_types = tuple(
[self.lower_type(x) for x in attributes.input_types()]
)
        output_types = (
            self.lower_type(attributes.get_return_type(func_node.body)),
        )
        return self.jit_compile_extra(
            llmod, input_types, output_types, function_name=func_name
        )
def jit_compile_extra(
self,
llmod,
input_types,
output_types,
function_name="func",
exec_engine=None,
is_ufunc=False,
**execution_engine_params,
):
# Converts the MLIR module into a JIT-callable function.
# Use MLIR's own internal execution engine
if exec_engine is None:
engine = execution_engine.ExecutionEngine(
llmod, **execution_engine_params
)
else:
engine = exec_engine
assert (
len(output_types) == 1
), "Execution of functions with output arguments > 1 not supported"
nout = len(output_types)
# Build a wrapper function
def jit_func(*args):
if is_ufunc:
input_args = args[:-nout]
output_args = args[-nout:]
else:
input_args = args
output_args = [None]
            assert len(input_args) == len(input_types)
            # TODO: check each runtime argument against its declared MLIR
            # type; for now only the arity is checked above.
# Transform the input arguments into C-types
# with their respective values. All inputs to
# the internal execution engine should
# be C-Type pointers.
input_exec_ptrs = [
self.get_exec_ptr(ty, val)[0]
for ty, val in zip(input_types, input_args)
]
            # Invoke the compiled function. Internally this calls
            # _mlir_ciface_<function_name> with the given input pointers;
            # the single result pointer must be appended after all of the
            # input pointers in the invoke call.
res_ptr, res_val = self.get_exec_ptr(
output_types[0], output_args[0]
)
engine.invoke(function_name, *input_exec_ptrs, res_ptr)
return self.get_out_val(res_ptr, res_val)
return jit_func
    @classmethod
    def get_exec_ptr(cls, mlir_ty, val):
"""Get Execution Pointer
Convert MLIR types to C-types and allocate memory for the value.
"""
        if isinstance(mlir_ty, ir.IntegerType):
            # All integer widths are marshalled through a 64-bit slot.
            val = 0 if val is None else val
            ptr = ctypes.pointer(ctypes.c_int64(val))
elif isinstance(mlir_ty, ir.F32Type):
val = 0.0 if val is None else val
ptr = ctypes.pointer(ctypes.c_float(val))
elif isinstance(mlir_ty, ir.F64Type):
val = 0.0 if val is None else val
ptr = ctypes.pointer(ctypes.c_double(val))
elif isinstance(mlir_ty, ir.MemRefType):
if isinstance(mlir_ty.element_type, ir.F64Type):
np_dtype = np.float64
elif isinstance(mlir_ty.element_type, ir.F32Type):
np_dtype = np.float32
else:
raise TypeError(
"The current array element type is not supported"
)
if val is None:
if not mlir_ty.has_static_shape:
raise ValueError(f"{mlir_ty} does not have static shape")
val = np.zeros(mlir_ty.shape, dtype=np_dtype)
ptr = ctypes.pointer(
ctypes.pointer(runtime.get_ranked_memref_descriptor(val))
)
return ptr, val
@classmethod
def get_out_val(cls, res_ptr, res_val):
if isinstance(res_val, np.ndarray):
return res_val
else:
return res_ptr.contents.value
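A minimal usage sketch of the helpers above, runnable without a full RVSDG function (it assumes NbOp_Type can be constructed positionally with its type name, as the match patterns in lower_type suggest): type lowering maps SealIR types onto the MLIR types cached on the backend, and get_exec_ptr wraps Python scalars in the ctypes pointers that ExecutionEngine.invoke expects. The end-to-end path lower, run_passes, jit_compile is exercised by the examples below.
backend = Backend()

# SealIR types map onto the MLIR types cached on the backend.
assert backend.lower_type(NbOp_Type("Int64")) == backend.i64
assert backend.lower_type(NbOp_Type("Float64")) == backend.f64

# Scalars are marshalled through ctypes pointers for the execution engine.
int_ptr, _ = Backend.get_exec_ptr(backend.i64, 42)
assert int_ptr.contents.value == 42
float_ptr, _ = Backend.get_exec_ptr(backend.f64, 1.5)
assert float_ptr.contents.value == 1.5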
Example 1: Simple If-Else¶
Demonstrate the MLIR backend with a simple conditional function. With arguments (10, 33) the else branch runs: z = 33 - 10 = 23, and the function returns 23 + 10 = 33.
def example_1(a, b):
if a > b:
z = a - b
else:
z = b - a
return z + a
compiler_config = dict(
converter_class=LoopExtendEGraphToRVSDG,
backend=Backend(),
cost_model=MyCostModel(),
verbose=True,
)
class RunBEPassOutput(TypedDict):
module: Any
def pipeline_run_be_passes(
backend, module, pipeline_report=Report.Sink()
) -> RunBEPassOutput:
with pipeline_report.nest("MLIR passes") as report:
backend.run_passes(module)
report.append("MLIR optimized", module)
return dict(module=module)
jit_compiler = _ch04_1_jit_compiler.insert(-1, pipeline_run_be_passes)
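Stages in this pipeline are plain functions whose keyword arguments are supplied by earlier stages and whose returned dict feeds the later ones; pipeline_run_be_passes above follows that shape. As a hedged illustration, a hypothetical stage that dumps the module before the backend passes run could be inserted the same way (pipeline_dump_module and its exact injection behavior are assumptions, not part of the original pipeline):
def pipeline_dump_module(
    module, pipeline_report=Report.Sink()
) -> RunBEPassOutput:
    # Hypothetical debugging stage: print the MLIR module and pass it
    # through unchanged.
    print(module)
    return dict(module=module)

# Inserted just like pipeline_run_be_passes:
# debug_jit_compiler = jit_compiler.insert(-1, pipeline_dump_module)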
if __name__ == "__main__":
display(jit_compiler.visualize())
report = Report("Pipeline execution report", enable_nested_metadata=True)
jit_func = jit_compiler(
fn=example_1,
argtypes=(Int64, Int64),
ruleset=(if_else_ruleset | setup_argtypes(TypeInt64, TypeInt64)),
pipeline_report=report,
**compiler_config,
).jit_func
report.display()
args = (10, 33)
run_test(example_1, jit_func, args, verbose=True)
args = (7, 3)
run_test(example_1, jit_func, args, verbose=True)
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: i64, %arg1: i64) -> i64 attributes {llvm.emit_c_interface} {
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.cmpi sgt, %arg0, %arg1 : i64
    %1:4 = scf.if %0 -> (i32, i64, i64, i64) {
      %3 = arith.subi %arg0, %arg1 : i64
      scf.yield %c0_i32, %arg0, %arg1, %3 : i32, i64, i64, i64
    } else {
      %3 = arith.subi %arg1, %arg0 : i64
      scf.yield %c0_i32, %arg0, %arg1, %3 : i32, i64, i64, i64
    }
    %2 = arith.addi %1#3, %1#1 : i64
    return %2 : i64
  }
}
time elapsed 2.36ms timing breakdown: 2.36ms: Lowered module
module {
  llvm.func @func(%arg0: i64, %arg1: i64) -> i64 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0 : i32) : i32
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %1 = llvm.icmp "sgt" %arg0, %arg1 : i64
    llvm.cond_br %1, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %2 = llvm.sub %arg0, %arg1 : i64
    llvm.br ^bb4(%0, %arg0, %arg1, %2 : i32, i64, i64, i64)
  ^bb3:  // pred: ^bb1
    %3 = llvm.sub %arg1, %arg0 : i64
    llvm.br ^bb4(%0, %arg0, %arg1, %3 : i32, i64, i64, i64)
  ^bb4(%4: i32, %5: i64, %6: i64, %7: i64):  // 2 preds: ^bb2, ^bb3
    llvm.br ^bb5
  ^bb5:  // pred: ^bb4
    %8 = llvm.add %7, %5 : i64
    llvm.return %8 : i64
  }
  llvm.func @_mlir_ciface_func(%arg0: i64, %arg1: i64) -> i64 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0, %arg1) : (i64, i64) -> i64
    llvm.return %0 : i64
  }
}
time elapsed 1.41ms timing breakdown: 1.41ms: MLIR optimized
Testing report
(10, 33)
33
33
Testing report
(7, 3)
11
11
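The IfElse case in lower_expr builds scf.if imperatively through the Python bindings. The standalone sketch below repeats that construction pattern outside the Backend class, building a max(a, b) function; the function name and module are illustrative only.
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.scf as scf
import mlir.ir as ir

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    i64 = ir.IntegerType.get_signless(64)
    with ir.InsertionPoint(module.body):
        fun = func.FuncOp("maxval", ((i64, i64), (i64,)))
        entry = fun.add_entry_block()
        with ir.InsertionPoint(entry):
            a, b = entry.arguments
            cond = arith.cmpi(4, a, b)  # predicate 4 == "sgt"
            if_op = scf.IfOp(cond=cond, results_=[i64], hasElse=True)
            with ir.InsertionPoint(if_op.then_block):
                scf.YieldOp([a])
            with ir.InsertionPoint(if_op.else_block):
                scf.YieldOp([b])
            func.ReturnOp([if_op.results[0]])
    print(module)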
Example 2: Float Operations¶
Test the MLIR backend with float operations and type conversion. With arguments (10, 33) the else branch runs: z = 33.0 - 1/10 = 32.9, and the function returns 32.9 - 10.0 = 22.9.
def example_2(a, b):
if a > b:
z = float(a - b)
else:
z = float(b) - 1 / a
return z - float(a)
Add the type-inference rules for float() by including ruleset_type_infer_float in the ruleset.
if __name__ == "__main__":
report = Report("Pipeline execution report", enable_nested_metadata=True)
jit_func = jit_compiler(
fn=example_2,
argtypes=(Int64, Int64),
ruleset=(
if_else_ruleset
| setup_argtypes(TypeInt64, TypeInt64)
| ruleset_type_infer_float # < --- added for float()
),
pipeline_report=report,
**compiler_config,
).jit_func
report.display()
args = (10, 33)
run_test(example_2, jit_func, args, verbose=True)
args = (7, 3)
run_test(example_2, jit_func, args, verbose=True)
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %c1_i64 = arith.constant 1 : i64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.cmpi sgt, %arg0, %arg1 : i64
    %1:4 = scf.if %0 -> (i32, i64, i64, f64) {
      %4 = arith.subi %arg0, %arg1 : i64
      %5 = arith.sitofp %4 : i64 to f64
      scf.yield %c0_i32, %arg0, %arg1, %5 : i32, i64, i64, f64
    } else {
      %4 = arith.sitofp %arg1 : i64 to f64
      %5 = arith.sitofp %c1_i64 : i64 to f64
      %6 = arith.sitofp %arg0 : i64 to f64
      %7 = arith.divf %5, %6 : f64
      %8 = arith.subf %4, %7 : f64
      scf.yield %c0_i32, %arg0, %arg1, %8 : i32, i64, i64, f64
    }
    %2 = arith.sitofp %1#1 : i64 to f64
    %3 = arith.subf %1#3, %2 : f64
    return %3 : f64
  }
}
time elapsed 1.46ms timing breakdown: 1.46ms: Lowered module
module {
  llvm.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
    %1 = llvm.mlir.constant(0 : i32) : i32
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %2 = llvm.icmp "sgt" %arg0, %arg1 : i64
    llvm.cond_br %2, ^bb2, ^bb3
  ^bb2:  // pred: ^bb1
    %3 = llvm.sub %arg0, %arg1 : i64
    %4 = llvm.sitofp %3 : i64 to f64
    llvm.br ^bb4(%1, %arg0, %arg1, %4 : i32, i64, i64, f64)
  ^bb3:  // pred: ^bb1
    %5 = llvm.sitofp %arg1 : i64 to f64
    %6 = llvm.sitofp %arg0 : i64 to f64
    %7 = llvm.fdiv %0, %6 : f64
    %8 = llvm.fsub %5, %7 : f64
    llvm.br ^bb4(%1, %arg0, %arg1, %8 : i32, i64, i64, f64)
  ^bb4(%9: i32, %10: i64, %11: i64, %12: f64):  // 2 preds: ^bb2, ^bb3
    llvm.br ^bb5
  ^bb5:  // pred: ^bb4
    %13 = llvm.sitofp %10 : i64 to f64
    %14 = llvm.fsub %12, %13 : f64
    llvm.return %14 : f64
  }
  llvm.func @_mlir_ciface_func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0, %arg1) : (i64, i64) -> f64
    llvm.return %0 : f64
  }
}
time elapsed 1.23ms timing breakdown: 1.23ms: MLIR optimized
Testing report
(10, 33)
22.9
22.9
Testing report
(7, 3)
-3.0
-3.0
Example 3: Simple While Loop¶
Demonstrate loop compilation with the MLIR backend. example_3(10, 7) computes 10 + (0 + 1 + ... + 6) = 31.0.
def example_3(init, n):
c = float(init)
i = 0
while i < n:
c = c + float(i)
i = i + 1
return c
if __name__ == "__main__":
report = Report("Pipeline execution report", enable_nested_metadata=True)
jit_func = jit_compiler(
fn=example_3,
argtypes=(Int64, Int64),
ruleset=(loop_ruleset | setup_argtypes(TypeInt64, TypeInt64)),
pipeline_report=report,
**compiler_config,
).jit_func
report.display()
run_test(example_3, jit_func, (10, 7), verbose=True)
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %c0_i64 = arith.constant 0 : i64
    %true = arith.constant true
    %c1_i64 = arith.constant 1 : i64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %arg0 : i64 to f64
    %1:7 = scf.while (%arg2 = %c0_i32, %arg3 = %c0_i64, %arg4 = %true, %arg5 = %0, %arg6 = %c0_i64, %arg7 = %arg0, %arg8 = %arg1) : (i32, i64, i1, f64, i64, i64, i64) -> (i32, i64, i1, f64, i64, i64, i64) {
      %2 = arith.cmpi slt, %arg6, %arg8 : i64
      %3:7 = scf.if %2 -> (i32, i64, i1, f64, i64, i64, i64) {
        %5 = arith.sitofp %arg6 : i64 to f64
        %6 = arith.addf %arg5, %5 : f64
        %7 = arith.addi %arg6, %c1_i64 : i64
        scf.yield %arg2, %c0_i64, %arg4, %6, %7, %arg7, %arg8 : i32, i64, i1, f64, i64, i64, i64
      } else {
        scf.yield %arg2, %c1_i64, %arg4, %arg5, %arg6, %arg7, %arg8 : i32, i64, i1, f64, i64, i64, i64
      }
      %c0_i64_0 = arith.constant 0 : i64
      %4 = arith.cmpi eq, %3#1, %c0_i64_0 : i64
      scf.condition(%4) %3#0, %3#1, %4, %3#3, %3#4, %3#5, %3#6 : i32, i64, i1, f64, i64, i64, i64
    } do {
    ^bb0(%arg2: i32, %arg3: i64, %arg4: i1, %arg5: f64, %arg6: i64, %arg7: i64, %arg8: i64):
      scf.yield %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8 : i32, i64, i1, f64, i64, i64, i64
    }
    return %1#3 : f64
  }
}
time elapsed 1.93ms timing breakdown: 1.93ms: Lowered module
module {
  llvm.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(0 : i64) : i64
    %2 = llvm.mlir.constant(true) : i1
    %3 = llvm.mlir.constant(1 : i64) : i64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %4 = llvm.sitofp %arg0 : i64 to f64
    llvm.br ^bb2(%0, %1, %2, %4, %1, %arg0, %arg1 : i32, i64, i1, f64, i64, i64, i64)
  ^bb2(%5: i32, %6: i64, %7: i1, %8: f64, %9: i64, %10: i64, %11: i64):  // 2 preds: ^bb1, ^bb6
    %12 = llvm.icmp "slt" %9, %11 : i64
    llvm.cond_br %12, ^bb3, ^bb4
  ^bb3:  // pred: ^bb2
    %13 = llvm.sitofp %9 : i64 to f64
    %14 = llvm.fadd %8, %13 : f64
    %15 = llvm.add %9, %3 : i64
    llvm.br ^bb5(%5, %1, %7, %14, %15, %10, %11 : i32, i64, i1, f64, i64, i64, i64)
  ^bb4:  // pred: ^bb2
    llvm.br ^bb5(%5, %3, %7, %8, %9, %10, %11 : i32, i64, i1, f64, i64, i64, i64)
  ^bb5(%16: i32, %17: i64, %18: i1, %19: f64, %20: i64, %21: i64, %22: i64):  // 2 preds: ^bb3, ^bb4
    llvm.br ^bb6
  ^bb6:  // pred: ^bb5
    %23 = llvm.icmp "eq" %17, %1 : i64
    llvm.cond_br %23, ^bb2(%16, %17, %23, %19, %20, %21, %22 : i32, i64, i1, f64, i64, i64, i64), ^bb7
  ^bb7:  // pred: ^bb6
    llvm.return %19 : f64
  }
  llvm.func @_mlir_ciface_func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0, %arg1) : (i64, i64) -> f64
    llvm.return %0 : f64
  }
}
time elapsed 1.35ms timing breakdown: 1.35ms: MLIR optimized
Testing report
(10, 7)
31.0
31.0
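The Loop case in lower_expr packs the entire RVSDG loop body into the before region of scf.while, with the after region merely forwarding values. The hedged sketch below uses the same construction calls (WhileOp, ConditionOp, YieldOp) in their more conventional arrangement, a counting loop that returns n; the function name and module are illustrative only.
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.scf as scf
import mlir.ir as ir

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    i64 = ir.IntegerType.get_signless(64)
    with ir.InsertionPoint(module.body):
        fun = func.FuncOp("count_to", ((i64,), (i64,)))
        entry = fun.add_entry_block()
        with ir.InsertionPoint(entry):
            (n,) = entry.arguments
            zero = arith.constant(i64, 0)
            one = arith.constant(i64, 1)
            while_op = scf.WhileOp(results_=[i64], inits=[zero])
            before = while_op.before.blocks.append(i64)
            after = while_op.after.blocks.append(i64)
            # Before region: test the condition, forward the carried value.
            with ir.InsertionPoint(before):
                (i,) = before.arguments
                cond = arith.cmpi(2, i, n)  # predicate 2 == "slt"
                scf.ConditionOp(args=[i], condition=cond)
            # After region: advance the induction variable.
            with ir.InsertionPoint(after):
                (i,) = after.arguments
                scf.YieldOp([arith.addi(i, one)])
            func.ReturnOp([while_op.results[0]])
    print(module)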
Example 4: Nested Loop¶
Test nested loop compilation with the MLIR backend. example_4(10, 7) adds, for each i < 7, the sum of j for j < i, i.e. 0 + 0 + 1 + 3 + 6 + 10 + 15 = 35, onto the initial 10.0, giving 45.0.
def example_4(init, n):
c = float(init)
i = 0
while i < n:
j = 0
while j < i:
c = c + float(j)
j = j + 1
i = i + 1
return c
if __name__ == "__main__":
report = Report("Pipeline execution report", enable_nested_metadata=True)
jit_func = jit_compiler(
fn=example_4,
argtypes=(Int64, Int64),
ruleset=(loop_ruleset | setup_argtypes(TypeInt64, TypeInt64)),
pipeline_report=report,
**compiler_config,
).jit_func
report.display()
run_test(example_4, jit_func, (10, 7), verbose=True)
Pipeline execution report
time elapsed 0.00ms timing breakdown:
module {
  func.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %c0_i64 = arith.constant 0 : i64
    %true = arith.constant true
    %false = arith.constant false
    %c1_i64 = arith.constant 1 : i64
    cf.br ^bb1
  ^bb1:  // pred: ^bb0
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %arg0 : i64 to f64
    %1:10 = scf.while (%arg2 = %c0_i32, %arg3 = %c0_i64, %arg4 = %c0_i64, %arg5 = %true, %arg6 = %false, %arg7 = %0, %arg8 = %c0_i64, %arg9 = %arg0, %arg10 = %c0_i64, %arg11 = %arg1) : (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) -> (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) {
      %2 = arith.cmpi slt, %arg8, %arg11 : i64
      %3:10 = scf.if %2 -> (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) {
        %5:10 = scf.while (%arg12 = %arg2, %arg13 = %arg3, %arg14 = %arg4, %arg15 = %arg5, %arg16 = %true, %arg17 = %arg7, %arg18 = %arg8, %arg19 = %arg9, %arg20 = %c0_i64, %arg21 = %arg11) : (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) -> (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) {
          %7 = arith.cmpi slt, %arg20, %arg18 : i64
          %8:10 = scf.if %7 -> (i32, i64, i64, i1, i1, f64, i64, i64, i64, i64) {
            %10 = arith.sitofp %arg20 : i64 to f64
            %11 = arith.addf %arg17, %10 : f64
            %12 = arith.addi %arg20, %c1_i64 : i64
            scf.yield %arg12, %arg13, %c0_i64, %arg15, %arg16, %11, %arg18, %arg19, %12, %arg21 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
          } else {
            scf.yield %arg12, %arg13, %c1_i64, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
          }
          %c0_i64_1 = arith.constant 0 : i64
          %9 = arith.cmpi eq, %8#2, %c0_i64_1 : i64
          scf.condition(%9) %8#0, %8#1, %8#2, %8#3, %9, %8#5, %8#6, %8#7, %8#8, %8#9 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
        } do {
        ^bb0(%arg12: i32, %arg13: i64, %arg14: i64, %arg15: i1, %arg16: i1, %arg17: f64, %arg18: i64, %arg19: i64, %arg20: i64, %arg21: i64):
          scf.yield %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
        }
        %6 = arith.addi %5#6, %c1_i64 : i64
        scf.yield %5#0, %c0_i64, %5#2, %5#3, %5#4, %5#5, %6, %5#7, %5#8, %5#9 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
      } else {
        scf.yield %arg2, %c1_i64, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
      }
      %c0_i64_0 = arith.constant 0 : i64
      %4 = arith.cmpi eq, %3#1, %c0_i64_0 : i64
      scf.condition(%4) %3#0, %3#1, %3#2, %4, %3#4, %3#5, %3#6, %3#7, %3#8, %3#9 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
    } do {
    ^bb0(%arg2: i32, %arg3: i64, %arg4: i64, %arg5: i1, %arg6: i1, %arg7: f64, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64):
      scf.yield %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64
    }
    return %1#5 : f64
  }
}
time elapsed 3.26ms timing breakdown: 3.26ms: Lowered module
module {
  llvm.func @func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(0 : i64) : i64
    %2 = llvm.mlir.constant(true) : i1
    %3 = llvm.mlir.constant(false) : i1
    %4 = llvm.mlir.constant(1 : i64) : i64
    llvm.br ^bb1
  ^bb1:  // pred: ^bb0
    %5 = llvm.sitofp %arg0 : i64 to f64
    llvm.br ^bb2(%0, %1, %1, %2, %3, %5, %1, %arg0, %1, %arg1 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb2(%6: i32, %7: i64, %8: i64, %9: i1, %10: i1, %11: f64, %12: i64, %13: i64, %14: i64, %15: i64):  // 2 preds: ^bb1, ^bb12
    %16 = llvm.icmp "slt" %12, %15 : i64
    llvm.cond_br %16, ^bb3, ^bb10
  ^bb3:  // pred: ^bb2
    llvm.br ^bb4(%6, %7, %8, %9, %2, %11, %12, %13, %1, %15 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb4(%17: i32, %18: i64, %19: i64, %20: i1, %21: i1, %22: f64, %23: i64, %24: i64, %25: i64, %26: i64):  // 2 preds: ^bb3, ^bb8
    %27 = llvm.icmp "slt" %25, %23 : i64
    llvm.cond_br %27, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %28 = llvm.sitofp %25 : i64 to f64
    %29 = llvm.fadd %22, %28 : f64
    %30 = llvm.add %25, %4 : i64
    llvm.br ^bb7(%17, %18, %1, %20, %21, %29, %23, %24, %30, %26 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb6:  // pred: ^bb4
    llvm.br ^bb7(%17, %18, %4, %20, %21, %22, %23, %24, %25, %26 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb7(%31: i32, %32: i64, %33: i64, %34: i1, %35: i1, %36: f64, %37: i64, %38: i64, %39: i64, %40: i64):  // 2 preds: ^bb5, ^bb6
    llvm.br ^bb8
  ^bb8:  // pred: ^bb7
    %41 = llvm.icmp "eq" %33, %1 : i64
    llvm.cond_br %41, ^bb4(%31, %32, %33, %34, %41, %36, %37, %38, %39, %40 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64), ^bb9
  ^bb9:  // pred: ^bb8
    %42 = llvm.add %37, %4 : i64
    llvm.br ^bb11(%31, %1, %33, %34, %41, %36, %42, %38, %39, %40 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb10:  // pred: ^bb2
    llvm.br ^bb11(%6, %4, %8, %9, %10, %11, %12, %13, %14, %15 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64)
  ^bb11(%43: i32, %44: i64, %45: i64, %46: i1, %47: i1, %48: f64, %49: i64, %50: i64, %51: i64, %52: i64):  // 2 preds: ^bb9, ^bb10
    llvm.br ^bb12
  ^bb12:  // pred: ^bb11
    %53 = llvm.icmp "eq" %44, %1 : i64
    llvm.cond_br %53, ^bb2(%43, %44, %45, %53, %47, %48, %49, %50, %51, %52 : i32, i64, i64, i1, i1, f64, i64, i64, i64, i64), ^bb13
  ^bb13:  // pred: ^bb12
    llvm.return %48 : f64
  }
  llvm.func @_mlir_ciface_func(%arg0: i64, %arg1: i64) -> f64 attributes {llvm.emit_c_interface} {
    %0 = llvm.call @func(%arg0, %arg1) : (i64, i64) -> f64
    llvm.return %0 : f64
  }
}
time elapsed 1.74ms timing breakdown: 1.74ms: MLIR optimized
Testing report
(10, 7)
45.0
45.0