Chapter 4 Part 0: Type Inference Prelude¶

This chapter introduces the basics of type inference in the compiler pipeline. We show how to add type inference logic for scalar operations and how to extend the compiler pipeline to support type-aware rewrites.

The chapter covers:

  • How to represent types in the e-graph
  • How to add type inference rules for basic operations
  • How to extend the compiler pipeline for type inference

Imports and Setup¶

In [1]:
from egglog import (
    EGraph,
    Expr,
    String,
    StringLike,
    function,
    rewrite,
    rule,
    ruleset,
    union,
)
from sealir import grammar, rvsdg
from sealir.eqsat import rvsdg_eqsat
from sealir.eqsat.py_eqsat import Py_AddIO
from sealir.eqsat.rvsdg_convert import egraph_conversion
from sealir.eqsat.rvsdg_eqsat import GraphRoot, PortList, Region, Term
from sealir.eqsat.rvsdg_extract import (
    CostModel,
    EGraphToRVSDG,
    egraph_extraction,
)
from sealir.llvm_pyapi_backend import SSAValue
In [2]:
from ch02_egraph_basic import (
    BackendOutput,
    EGraphExtractionOutput,
    backend,
    jit_compile,
    pipeline_egraph_extraction,
)
from ch03_egraph_program_rewrites import (
    compiler_pipeline as _ch03_compiler_pipeline,
)
from ch03_egraph_program_rewrites import (
    ruleset_const_propagate,
    run_test,
)
from utils import IN_NOTEBOOK, Report, display

First, the compiler pipeline needs some modifications. The middle-end gains two parameters:

  • converter_class customizes the EGraph-to-RVSDG conversion. We need it because we will introduce new RVSDG nodes for the typed operations.
  • cost_model customizes the cost of the new operations.
In [3]:
def pipeline_egraph_extraction(
    egraph,
    rvsdg_expr,
    converter_class,
    cost_model,
    pipeline_report=Report.Sink(),
) -> EGraphExtractionOutput:
    with pipeline_report.nest(
        "EGraph Extraction", default_expanded=True
    ) as report:
        cost, extracted = egraph_extraction(
            egraph,
            rvsdg_expr,
            converter_class=converter_class,  # <---- new
            cost_model=cost_model,  # <-------------- new
        )
        report.append("Cost", cost)
        report.append("Extracted", rvsdg.format_rvsdg(extracted))
        return {"cost": cost, "extracted": extracted}
In [4]:
pipeline_new_extract = _ch03_compiler_pipeline.replace(
    "pipeline_egraph_extraction", pipeline_egraph_extraction
)

The backend stage gains a codegen_extension parameter, which defines LLVM code generation for the new operations.

In [5]:
def extended_backend(
    extracted, codegen_extension, pipeline_report=Report.Sink()
) -> BackendOutput:
    with pipeline_report.nest("Backend", default_expanded=True) as report:
        llmod = backend(extracted, codegen_extension=codegen_extension)
        report.append("LLVM", llmod)
        jt = jit_compile(llmod, extracted)
        return {"jit_func": jt, "llmod": llmod}
In [6]:
# extend the pipeline with the new backend
pipeline_backend = pipeline_new_extract.replace(
    "pipeline_backend", extended_backend
)
compiler_pipeline = pipeline_backend
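
The stage replacement used above can be pictured with a small stand-in. This is only a sketch: MiniPipeline and its stage names are hypothetical and are not the pipeline class imported from the earlier chapters. It assumes that replace() returns a new pipeline and leaves the original untouched, which is how the cells above use it.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class MiniPipeline:
    """Hypothetical stand-in: an ordered tuple of (name, stage) pairs."""

    stages: tuple = ()

    def replace(self, name: str, stage: Callable) -> "MiniPipeline":
        # Build a new pipeline; the original object is not modified.
        if name not in dict(self.stages):
            raise KeyError(name)
        return MiniPipeline(
            tuple((n, stage if n == name else s) for n, s in self.stages)
        )

    def __call__(self, value):
        # Feed the output of each stage into the next one.
        for _, stage in self.stages:
            value = stage(value)
        return value


base = MiniPipeline((("parse", str.strip), ("emit", str.upper)))
extended = base.replace("emit", lambda s: s.upper() + "!")
```

Here base still runs the original "emit" stage, while extended runs the replacement, mirroring how pipeline_new_extract and pipeline_backend are derived step by step.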
In [7]:
# visualize the pipeline
if __name__ == "__main__":
    display(compiler_pipeline.visualize())
[Pipeline diagram: Stage 1 pipeline_frontend → Stage 2 pipeline_egraph_conversion → Stage 3 egraph_saturation → Stage 4 pipeline_egraph_extraction → Stage 5 extended_backend. Inputs: fn, ruleset, converter_class, cost_model, codegen_extension, and the optional pipeline_report. Outputs: rvsdg_expr, dbginfo, egraph, egraph_root, cost, extracted, jit_func, llmod.]

A Simple Type Inference Example¶

First, we will start with a simple binary add operation.

In [8]:
def add_x_y(x, y):
    return x + y

We will start with the same ruleset as in chapter 3.

In [9]:
basic_ruleset = rvsdg_eqsat.ruleset_rvsdg_basic | ruleset_const_propagate

We first run the base compiler (the chapter 3 behavior) on this function to establish a baseline. At this stage, no type inference happens.

In [10]:
if __name__ == "__main__":
    # start with previous compiler pipeline
    report = Report("Compiler Pipeline", default_expanded=True)
    jt = compiler_pipeline(
        fn=add_x_y,
        ruleset=basic_ruleset,
        converter_class=EGraphToRVSDG,
        codegen_extension=None,
        cost_model=None,
        pipeline_report=report,
    ).jit_func
    report.display()
    run_test(add_x_y, jt, (123, 321), verbose=True)

Compiler Pipeline

1. Frontend ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def add_x_y(x, y):
   2|    return x + y
----------------------------------inter source----------------------------------
   1|def transformed_add_x_y(x, y):
   2|    """#file: /tmp/ipykernel_3518/1710880194.py"""
   3|    '#loc: 2:8-2:20'
   4|    return x + y
RVSDG ▶
transformed_add_x_y = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[239] <- !io x y
{
  $1 = PyBinOp + $0[0] $0[1], $0[2]
} [314] -> !io=$1[0] !ret=$1[1]
2. EGraph Conversion ▶
EGraph Conversion
EGraph ▶
[E-graph diagram: GraphRoot → Term.Func("320", "transformed_add_x_y") → Term.RegionEnd over Region("239") with input ports Vec("!io", "x", "y"). A single Py_AddIO node reads Region.get(0), Region.get(1) and Region.get(2); its ports 0 and 1 feed the "!io" and "!ret" output ports.]
3. EGraph Saturated ▶
[E-graph diagram after saturation: the same GraphRoot, Term.Func, Term.RegionEnd, Region and Py_AddIO nodes, plus PortList.getValue, PortList[...], Port.name and Port.value nodes. No typed operations appear.]
4. EGraph Extraction ▶
EGraph Extraction
Cost ▶
84347.0
Extracted ▶
transformed_add_x_y = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[401] <- !io x y
{
  $1 = PyBinOp + $0[0] $0[1], $0[2]
} [450] -> !io=$1[0] !ret=$1[1]
5. Backend ▶
Backend
LLVM ▶
; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""

define ptr @"foo"(ptr %".1", ptr %".2")
{
.4:
  %".5" = alloca ptr
  store ptr null, ptr %".5"
  br label %".7"
.7:
  br label %".9"
.9:
  %".11" = call ptr @"PyNumber_Add"(ptr %".1", ptr %".2")
  ret ptr %".11"
}

declare ptr @"PyNumber_Add"(ptr %".1", ptr %".2")

Testing report

1. Args ▶
(123, 321)
2. JIT output ▶
444
3. Expected output ▶
444

Adding type inference¶

We add a new EGraph expression class (a subclass of Expr) to represent a type:

In [11]:
class Type(Expr):
    def __init__(self, name: StringLike): ...

Then, we add an EGraph function that gives the type of a Term:

In [12]:
@function
def TypeOf(x: Term) -> Type: ...

Next, we define functions for the new operations:

  • Nb_Unbox_Int64 unboxes a PyObject into an Int64.
  • Nb_Box_Int64 boxes an Int64 into a PyObject.
  • Nb_Unboxed_Add_Int64 performs an Int64 addition on unboxed operands.
In [13]:
@function
def Nb_Unbox_Int64(val: Term) -> Term: ...


@function
def Nb_Box_Int64(val: Term) -> Term: ...


@function
def Nb_Unboxed_Add_Int64(lhs: Term, rhs: Term) -> Term: ...

Now, we define the first type-inference rule:

If a Py_AddIO() (a Python binary add operation) is applied to operands that are both known to be Int64, convert it into the unboxed add. The output type is Int64. The IO state entering Py_AddIO() is passed through unchanged.

In [14]:
TypeInt64 = Type("Int64")


@ruleset
def ruleset_type_infer_add(io: Term, x: Term, y: Term, add: Term):
    yield rule(
        add == Py_AddIO(io, x, y),
        TypeOf(x) == TypeInt64,
        TypeOf(y) == TypeInt64,
    ).then(
        # convert to a typed operation
        union(add.getPort(1)).with_(
            Nb_Box_Int64(
                Nb_Unboxed_Add_Int64(Nb_Unbox_Int64(x), Nb_Unbox_Int64(y))
            )
        ),
        # shortcut io
        union(add.getPort(0)).with_(io),
        # output type
        union(TypeOf(add.getPort(1))).with_(TypeInt64),
    )
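
To make the effect of this rule concrete, here is a plain-Python sketch of the same rewrite on nested tuples. It is not egglog code: the tuple encoding, the infer_add name and the type_of dictionary are assumptions made for illustration. It also replaces the term outright, whereas the e-graph rule unions the typed form with the original so that both remain available.

```python
INT64 = "Int64"


def infer_add(term, type_of):
    """Rewrite ("Py_AddIO", io, x, y) when both operands are known Int64.

    Returns (io_out, value_out, value_type), or None if the rule does not fire.
    """
    op, io, x, y = term
    if op != "Py_AddIO":
        return None
    if type_of.get(x) != INT64 or type_of.get(y) != INT64:
        return None
    # Port 1 (the value): unbox both operands, add natively, box the result.
    value = (
        "Nb_Box_Int64",
        ("Nb_Unboxed_Add_Int64", ("Nb_Unbox_Int64", x), ("Nb_Unbox_Int64", y)),
    )
    # Port 0 (the IO state) is simply the incoming io.
    return io, value, INT64


typed = infer_add(("Py_AddIO", "io", "x", "y"), {"x": INT64, "y": INT64})
untyped = infer_add(("Py_AddIO", "io", "x", "y"), {"x": INT64})
```

With both operand types known, typed holds the boxed native add and the Int64 output type. With the type of y missing, untyped is None and the Python-level add is left alone.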

The following rule states facts about the function being compiled: it declares that both arguments are Int64.

In [15]:
@ruleset
def facts_argument_types(
    outports: PortList,
    func_uid: String,
    fname: String,
    region: Region,
    arg_x: Term,
    arg_y: Term,
):
    yield rule(
        GraphRoot(
            Term.Func(
                body=Term.RegionEnd(region=region, ports=outports),
                uid=func_uid,
                fname=fname,
            )
        ),
        arg_x == region.get(1),
        arg_y == region.get(2),
    ).then(
        union(TypeOf(arg_x)).with_(TypeInt64),
        union(TypeOf(arg_y)).with_(TypeInt64),
    )
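
Seeding argument types and then applying the add rule until nothing changes can be sketched in plain Python. This is a simplified model and not what egglog does internally: the propagate_types helper and the names t1 and t2 are hypothetical, and each addition is reduced to a result name with two operand names.

```python
def propagate_types(adds, arg_types):
    """Propagate Int64 through additions until a fixed point is reached.

    adds maps a result name to its (lhs, rhs) operand names.
    """
    types = dict(arg_types)
    changed = True
    while changed:
        changed = False
        for out, (lhs, rhs) in adds.items():
            known = types.get(lhs) == "Int64" and types.get(rhs) == "Int64"
            if out not in types and known:
                types[out] = "Int64"
                changed = True
    return types


# t1 = x + y and t2 = t1 + y, deliberately listed in the "wrong" order.
adds = {"t2": ("t1", "y"), "t1": ("x", "y")}
inferred = propagate_types(adds, {"x": "Int64", "y": "Int64"})
```

Although t2 is listed before t1, the loop still types both results, because it keeps iterating until no new fact is added. If only x is declared Int64, no addition is typed.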

Defining conversion into RVSDG¶

We will expand the RVSDG grammar with the typed operations.

Each of the new typed operations will require a corresponding grammar rule.

In [16]:
SExpr = rvsdg.grammar.SExpr


class NbOp_Base(grammar.Rule):
    pass


class NbOp_Unboxed_Add_Int64(NbOp_Base):
    lhs: SExpr
    rhs: SExpr


class NbOp_Unbox_Int64(NbOp_Base):
    val: SExpr


class NbOp_Box_Int64(NbOp_Base):
    val: SExpr

The new grammar for our IR is a combination of the new typed-operation grammar and the base RVSDG grammar.

In [17]:
class Grammar(grammar.Grammar):
    start = rvsdg.Grammar.start | NbOp_Base

Now, we define an EGraph-to-RVSDG conversion class that is extended to handle the new grammar.

In [18]:
class ExtendEGraphToRVSDG(EGraphToRVSDG):
    grammar = Grammar

    def handle_Term(self, op: str, children: dict | list, grm: Grammar):
        match op, children:
            case "Nb_Unboxed_Add_Int64", {"lhs": lhs, "rhs": rhs}:
                return grm.write(NbOp_Unboxed_Add_Int64(lhs=lhs, rhs=rhs))
            case "Nb_Unbox_Int64", {"val": val}:
                return grm.write(NbOp_Unbox_Int64(val=val))
            case "Nb_Box_Int64", {"val": val}:
                return grm.write(NbOp_Box_Int64(val=val))
            case _:
                # Use parent's implementation for other terms.
                return super().handle_Term(op, children, grm)

The LLVM code-generation also needs an extension:

In [19]:
def codegen_extension(expr, args, builder, pyapi):
    match expr._head, args:
        case "NbOp_Unboxed_Add_Int64", (lhs, rhs):
            return SSAValue(builder.add(lhs.value, rhs.value))
        case "NbOp_Unbox_Int64", (val,):
            return SSAValue(pyapi.long_as_longlong(val.value))
        case "NbOp_Box_Int64", (val,):
            return SSAValue(pyapi.long_from_longlong(val.value))
    return NotImplemented
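
The three typed operations can be modelled in plain Python to see what the generated code computes. This sketch assumes the documented behavior of PyLong_AsLongLong (OverflowError for values outside the long long range, taken here to be 64 bits) and the wrapping semantics of an LLVM add on i64 without overflow flags. The function names are hypothetical, and the model says nothing about how the extension above reacts to a failed conversion.

```python
INT64_MIN = -(2**63)
INT64_MAX = 2**63 - 1


def unbox_int64(obj: int) -> int:
    # Model of PyLong_AsLongLong: values outside int64 are an error.
    if not INT64_MIN <= obj <= INT64_MAX:
        raise OverflowError("Python int too large to convert to int64")
    return obj


def unboxed_add_int64(lhs: int, rhs: int) -> int:
    # Two's-complement wraparound, as with a plain `add i64`.
    total = (lhs + rhs) & (2**64 - 1)
    return total - 2**64 if total > INT64_MAX else total


def box_int64(val: int) -> int:
    # Model of PyLong_FromLongLong: every int64 fits in a Python int.
    return val


def typed_add(x: int, y: int) -> int:
    return box_int64(unboxed_add_int64(unbox_int64(x), unbox_int64(y)))
```

For small values such as typed_add(123, 321) the result matches Python's x + y. Near the int64 limits it does not: Python integers have arbitrary precision, while the native add wraps around.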

We also define a new cost model that makes the typed operations cheap, so that extraction prefers them:

In [20]:
class MyCostModel(CostModel):
    def get_cost_function(self, nodename, op, ty, cost, children):
        self_cost = None
        match op:
            case "Nb_Unboxed_Add_Int64":
                self_cost = 0.1

            case "Nb_Unbox_Int64":
                self_cost = 0.1

            case "Nb_Box_Int64":
                self_cost = 0.1

        if self_cost is not None:
            return self.get_simple(self_cost)

        # Fallthrough to parent's cost function
        return super().get_cost_function(nodename, op, ty, cost, children)
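
Why low per-node costs steer extraction toward the typed form can be shown with a toy cost function. This is a sketch under assumptions: it treats the cost of a term as its own cost plus the cost of its children, the 0.1 values mirror MyCostModel, and the 100.0 for Py_AddIO is an invented number, not the cost used by the parent CostModel.

```python
# Per-node costs. Only the 0.1 entries come from MyCostModel above;
# the Py_AddIO value is a made-up placeholder for "expensive".
COSTS = {
    "Nb_Unbox_Int64": 0.1,
    "Nb_Unboxed_Add_Int64": 0.1,
    "Nb_Box_Int64": 0.1,
    "Py_AddIO": 100.0,
}


def tree_cost(term):
    """Cost of a nested-tuple term: own cost plus the cost of all children."""
    if isinstance(term, str):  # leaves such as "x" or "io" are free here
        return 0.0
    op, *children = term
    return COSTS[op] + sum(tree_cost(c) for c in children)


boxed = ("Py_AddIO", "io", "x", "y")
typed = (
    "Nb_Box_Int64",
    ("Nb_Unboxed_Add_Int64", ("Nb_Unbox_Int64", "x"), ("Nb_Unbox_Int64", "y")),
)
cheapest = min([boxed, typed], key=tree_cost)
```

Even though the typed form has four nodes instead of one, its total stays far below the single Python-level add, so cheapest is the typed term.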

The new ruleset with the type inference logic and facts about the compiled function:

In [21]:
typeinfer_ruleset = (
    basic_ruleset | ruleset_type_infer_add | facts_argument_types
)

We are now ready to run the compiler:

In [22]:
if __name__ == "__main__":
    report = Report("Compiler Pipeline", default_expanded=True)
    jt = compiler_pipeline(
        fn=add_x_y,
        ruleset=typeinfer_ruleset,
        converter_class=ExtendEGraphToRVSDG,
        codegen_extension=codegen_extension,
        cost_model=MyCostModel(),
        pipeline_report=report,
    ).jit_func
    report.display()
    run_test(add_x_y, jt, (123, 321), verbose=True)

Compiler Pipeline

1. Frontend ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def add_x_y(x, y):
   2|    return x + y
----------------------------------inter source----------------------------------
   1|def transformed_add_x_y(x, y):
   2|    """#file: /tmp/ipykernel_3518/1710880194.py"""
   3|    '#loc: 2:8-2:20'
   4|    return x + y
RVSDG ▶
transformed_add_x_y = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[239] <- !io x y
{
  $1 = PyBinOp + $0[0] $0[1], $0[2]
} [314] -> !io=$1[0] !ret=$1[1]
2. EGraph Conversion ▶
EGraph Conversion
EGraph ▶
[E-graph diagram: GraphRoot → Term.Func("320", "transformed_add_x_y") → Term.RegionEnd over Region("239") with input ports Vec("!io", "x", "y"). A single Py_AddIO node reads Region.get(0), Region.get(1) and Region.get(2); its ports 0 and 1 feed the "!io" and "!ret" output ports.]
3. EGraph Saturated ▶
[E-graph diagram after saturation: in addition to the Py_AddIO node, the graph now contains Nb_Box_Int64 → Nb_Unboxed_Add_Int64 → two Nb_Unbox_Int64 nodes reading Region.get(1) and Region.get(2), three TypeOf nodes (on Region.get(1), Region.get(2) and PortList.getValue(1)), and Type("Int64").]
4. EGraph Extraction ▶
EGraph Extraction
Cost ▶
10964.6
Extracted ▶
transformed_add_x_y = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[401] <- !io x y
{
  $1 = NbOp_Unbox_Int64 $0[1]
  $2 = NbOp_Unbox_Int64 $0[2]
  $3 = NbOp_Unboxed_Add_Int64 $1 $2
  $4 = NbOp_Box_Int64 $3
} [450] -> !io=$0[0] !ret=$4
5. Backend ▶
Backend
LLVM ▶
; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""

define ptr @"foo"(ptr %".1", ptr %".2")
{
.4:
  %".5" = alloca ptr
  store ptr null, ptr %".5"
  br label %".7"
.7:
  br label %".9"
.9:
  %".11" = call i64 @"PyLong_AsLongLong"(ptr %".1")
  %".12" = call i64 @"PyLong_AsLongLong"(ptr %".2")
  %".13" = add i64 %".11", %".12"
  %".14" = call ptr @"PyLong_FromLongLong"(i64 %".13")
  ret ptr %".14"
}

declare i64 @"PyLong_AsLongLong"(ptr %".1")

declare ptr @"PyLong_FromLongLong"(i64 %".1")

Testing report

1. Args ▶
(123, 321)
2. JIT output ▶
444
3. Expected output ▶
444

Observations:

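  • In the e-graph, the new typed operations (Nb_Unbox_Int64, Nb_Unboxed_Add_Int64, Nb_Box_Int64) and the TypeOf facts now appear.
  • In the extracted RVSDG, the PyBinOp (Py_AddIO() in the e-graph) is gone, replaced by the NbOp_* operations.
  • In the LLVM IR, the addition is now a native i64 add instead of a call to PyNumber_Add.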

Optimize boxing logic¶

A key benefit of the EGraph approach is that we do not need to specify an ordering of "compiler passes". To demonstrate this, we add optimization rules for the boxing and unboxing operations: unbox(box(x)) is a no-op, so redundant box/unbox pairs can be removed.

We will need more than one addition to showcase the optimization:

In [23]:
def chained_additions(x, y):
    return x + y + y
In [24]:
if __name__ == "__main__":
    report = Report("Compiler Pipeline", default_expanded=True)
    jt = compiler_pipeline(
        fn=chained_additions,
        ruleset=typeinfer_ruleset,
        converter_class=ExtendEGraphToRVSDG,
        codegen_extension=codegen_extension,
        cost_model=MyCostModel(),
    ).jit_func
    report.display()
    run_test(chained_additions, jt, (123, 321), verbose=True)

Compiler Pipeline

Testing report

1. Args ▶
(123, 321)
2. JIT output ▶
765
3. Expected output ▶
765

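Observations:

The report above is empty because this call does not pass pipeline_report=report. In the extracted RVSDG, the result of the first addition is boxed and then immediately unboxed again:

  $4 = NbOp_Box_Int64 $3
  $5 = NbOp_Unbox_Int64 $4

This box-unbox chain is redundant: $5 holds the same value as $3.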

Box/Unbox optimization rules¶

The optimization rule we need is simple: any chained box-unbox or unbox-box pair is redundant.

(We use subsume=True to subsume the original EGraph node (enode), so that it is no longer matched or extracted. This keeps the redundant form out of the way early.)

In [25]:
@ruleset
def ruleset_optimize_boxing(x: Term):
    yield rewrite(Nb_Box_Int64(Nb_Unbox_Int64(x)), subsume=True).to(x)
    yield rewrite(Nb_Unbox_Int64(Nb_Box_Int64(x)), subsume=True).to(x)
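
The effect of these two rewrites on the chained addition can be sketched on nested tuples. This is a plain-Python illustration and not egglog: simplify_boxing is a hypothetical helper that rewrites bottom-up in a fixed order, whereas the e-graph applies the rules wherever they match without any ordering.

```python
def simplify_boxing(term):
    """Bottom-up removal of Box(Unbox(x)) and Unbox(Box(x)) pairs."""
    if isinstance(term, str):
        return term
    op, *children = term
    children = [simplify_boxing(c) for c in children]
    if len(children) == 1 and not isinstance(children[0], str):
        inner_op, *inner = children[0]
        pair = {"Nb_Box_Int64", "Nb_Unbox_Int64"}
        # A box directly wrapping an unbox (or the reverse) cancels out.
        if {op, inner_op} == pair and len(inner) == 1:
            return inner[0]
    return (op, *children)


# x + y, typed: Box(Add(Unbox x, Unbox y))
first = (
    "Nb_Box_Int64",
    ("Nb_Unboxed_Add_Int64", ("Nb_Unbox_Int64", "x"), ("Nb_Unbox_Int64", "y")),
)
# (x + y) + y, typed: the second add unboxes the boxed result of the first.
chained = (
    "Nb_Box_Int64",
    ("Nb_Unboxed_Add_Int64", ("Nb_Unbox_Int64", first), ("Nb_Unbox_Int64", "y")),
)
simplified = simplify_boxing(chained)
```

After simplification, the outer Nb_Unboxed_Add_Int64 takes the inner Nb_Unboxed_Add_Int64 directly as its left operand, and only the final result is boxed.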
In [26]:
optimized_ruleset = typeinfer_ruleset | ruleset_optimize_boxing
In [27]:
if __name__ == "__main__":
    report = Report("Compiler Pipeline", default_expanded=True)
    jt = compiler_pipeline(
        fn=chained_additions,
        ruleset=optimized_ruleset,
        converter_class=ExtendEGraphToRVSDG,
        codegen_extension=codegen_extension,
        cost_model=MyCostModel(),
        pipeline_report=report,
    ).jit_func
    report.display()
    run_test(chained_additions, jt, (123, 321), verbose=True)

Compiler Pipeline

1. Frontend ▶
Frontend
Debug Info on RVSDG ▶
--------------------------------original source---------------------------------
   1|def chained_additions(x, y):
   2|    return x + y + y
----------------------------------inter source----------------------------------
   1|def transformed_chained_additions(x, y):
   2|    """#file: /tmp/ipykernel_3518/3462865332.py"""
   3|    '#loc: 2:8-2:24'
   4|    return x + y + y
RVSDG ▶
transformed_chained_additions = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[266] <- !io x y
{
  $1 = PyBinOp + $0[0] $0[1], $0[2]
  $2 = PyBinOp + $1[0] $1[1], $0[2]
} [358] -> !io=$2[0] !ret=$2[1]
2. EGraph Conversion ▶
EGraph Conversion
EGraph ▶
[E-graph diagram: GraphRoot → Term.Func("364", "transformed_chained_additions") → Term.RegionEnd over Region("266") with input ports Vec("!io", "x", "y"). Two Py_AddIO nodes are chained: the first reads Region.get(0), Region.get(1) and Region.get(2); the second reads ports 0 and 1 of the first plus Region.get(2), and its ports feed the "!io" and "!ret" output ports.]
3. EGraph Saturated ▶
[Figure: rendering of the saturated e-graph. Alongside the RVSDG nodes (GraphRoot, Term.Func, Term.RegionEnd, Region, PortList, Py_AddIO), it now contains Nb_Unbox_Int64, Nb_Unboxed_Add_Int64, Nb_Box_Int64, and TypeOf nodes, plus the type Type("Int64").]
4. EGraph Extraction ▶
Cost ▶
10971.0
Extracted ▶
transformed_chained_additions = Func (Args (ArgSpec 'x' (PyNone)) (ArgSpec 'y' (PyNone)))
$0 = Region[457] <- !io x y
{
  $1 = NbOp_Unbox_Int64 $0[1]
  $2 = NbOp_Unbox_Int64 $0[2]
  $3 = NbOp_Unboxed_Add_Int64 $1 $2
  $4 = NbOp_Unboxed_Add_Int64 $3 $2
  $5 = NbOp_Box_Int64 $4
} [511] -> !io=$0[0] !ret=$5
5. Backend ▶
LLVM ▶
; ModuleID = ""
target triple = "unknown-unknown-unknown"
target datalayout = ""

define ptr @"foo"(ptr %".1", ptr %".2")
{
.4:
  %".5" = alloca ptr
  store ptr null, ptr %".5"
  br label %".7"
.7:
  br label %".9"
.9:
  %".11" = call i64 @"PyLong_AsLongLong"(ptr %".1")
  %".12" = call i64 @"PyLong_AsLongLong"(ptr %".2")
  %".13" = add i64 %".11", %".12"
  %".14" = add i64 %".13", %".12"
  %".15" = call ptr @"PyLong_FromLongLong"(i64 %".14")
  ret ptr %".15"
}

declare i64 @"PyLong_AsLongLong"(ptr %".1")

declare ptr @"PyLong_FromLongLong"(i64 %".1")

Testing report

1. Args ▶
(123, 321)
2. JIT output ▶
765
3. Expected output ▶
765
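
As a cross-check on the report above, the extracted region can be read as straight-line Python. The sketch below is a hand-written model, not output of the pipeline: `model_extracted` is a hypothetical name, and it assumes both inputs fit in a signed 64-bit integer so that Python's `int` addition agrees with the `add i64` instructions in the LLVM listing.

```python
def model_extracted(x: int, y: int) -> int:
    """Hand-written model of the extracted region (hypothetical helper)."""
    a = int(x)  # $1 = NbOp_Unbox_Int64 $0[1]
    b = int(y)  # $2 = NbOp_Unbox_Int64 $0[2]
    c = a + b   # $3 = NbOp_Unboxed_Add_Int64 $1 $2
    d = c + b   # $4 = NbOp_Unboxed_Add_Int64 $3 $2
    return d    # $5 = NbOp_Box_Int64 $4


result = model_extracted(123, 321)
print(result)  # 765, matching the JIT output in the testing report
```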

Observations:

  $1 = NbOp_Unbox_Int64 $0[1]
  $2 = NbOp_Unbox_Int64 $0[2]
  $3 = NbOp_Unboxed_Add_Int64 $1 $2
  $4 = NbOp_Unboxed_Add_Int64 $3 $2
  $5 = NbOp_Box_Int64 $4

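The extracted program no longer contains a redundant box-unbox pair between the two unboxed additions: the result $3 of the first add feeds directly into the second add $4, and only the final result is boxed.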
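
To make the effect concrete, here is a minimal sketch of the box-unbox cancellation on a toy term representation. It is plain Python over nested tuples, not the egglog ruleset used in this chapter, and it applies a single bottom-up pass rather than equality saturation. The `simplify` function, the lowercase operation names, and the assumed pre-rewrite shape (each add unboxes its inputs and boxes its result) are illustrative assumptions.

```python
def simplify(term):
    """Bottom-up rewrite of unbox(box(t)) to t; terms are nested tuples."""
    if not isinstance(term, tuple):
        return term  # leaf, e.g. an argument name
    op, *args = term
    args = [simplify(a) for a in args]
    # Cancel an unbox applied directly to a box.
    if op == "unbox" and isinstance(args[0], tuple) and args[0][0] == "box":
        return args[0][1]
    return (op, *args)


# Assumed pre-rewrite form: the first add is boxed, then unboxed again
# before the second add.
inner = ("box", ("add", ("unbox", "x"), ("unbox", "y")))
before = ("box", ("add", ("unbox", inner), ("unbox", "y")))

after = simplify(before)
print(after)
```

The resulting term keeps one unbox per argument and a single box around the final add, which is the same shape as the extracted listing.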