Tipragot
628be439b8
Cela permet de ne pas avoir de problèmes de compatibilité car python est dans le git.
487 lines
17 KiB
Python
487 lines
17 KiB
Python
"""Utilities for pretty-printing IR in a human-readable form."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from typing import Any, Final, Sequence, Union
|
|
|
|
from mypyc.common import short_name
|
|
from mypyc.ir.func_ir import FuncIR, all_values_full
|
|
from mypyc.ir.module_ir import ModuleIRs
|
|
from mypyc.ir.ops import (
|
|
ERR_NEVER,
|
|
Assign,
|
|
AssignMulti,
|
|
BasicBlock,
|
|
Box,
|
|
Branch,
|
|
Call,
|
|
CallC,
|
|
Cast,
|
|
ComparisonOp,
|
|
ControlOp,
|
|
DecRef,
|
|
Extend,
|
|
Float,
|
|
FloatComparisonOp,
|
|
FloatNeg,
|
|
FloatOp,
|
|
GetAttr,
|
|
GetElementPtr,
|
|
Goto,
|
|
IncRef,
|
|
InitStatic,
|
|
Integer,
|
|
IntOp,
|
|
KeepAlive,
|
|
LoadAddress,
|
|
LoadErrorValue,
|
|
LoadGlobal,
|
|
LoadLiteral,
|
|
LoadMem,
|
|
LoadStatic,
|
|
MethodCall,
|
|
Op,
|
|
OpVisitor,
|
|
RaiseStandardError,
|
|
Register,
|
|
Return,
|
|
SetAttr,
|
|
SetMem,
|
|
Truncate,
|
|
TupleGet,
|
|
TupleSet,
|
|
Unbox,
|
|
Unreachable,
|
|
Value,
|
|
)
|
|
from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive
|
|
|
|
ErrorSource = Union[BasicBlock, Op]
|
|
|
|
|
|
class IRPrettyPrintVisitor(OpVisitor[str]):
|
|
"""Internal visitor that pretty-prints ops."""
|
|
|
|
def __init__(self, names: dict[Value, str]) -> None:
|
|
# This should contain a name for all values that are shown as
|
|
# registers in the output. This is not just for Register
|
|
# instances -- all Ops that produce values need (generated) names.
|
|
self.names = names
|
|
|
|
def visit_goto(self, op: Goto) -> str:
|
|
return self.format("goto %l", op.label)
|
|
|
|
branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")}
|
|
|
|
def visit_branch(self, op: Branch) -> str:
|
|
fmt, typ = self.branch_op_names[op.op]
|
|
if op.negated:
|
|
fmt = f"not {fmt}"
|
|
|
|
cond = self.format(fmt, op.value)
|
|
tb = ""
|
|
if op.traceback_entry:
|
|
tb = " (error at %s:%d)" % op.traceback_entry
|
|
fmt = f"if {cond} goto %l{tb} else goto %l"
|
|
if typ:
|
|
fmt += f" :: {typ}"
|
|
return self.format(fmt, op.true, op.false)
|
|
|
|
def visit_return(self, op: Return) -> str:
|
|
return self.format("return %r", op.value)
|
|
|
|
def visit_unreachable(self, op: Unreachable) -> str:
|
|
return "unreachable"
|
|
|
|
def visit_assign(self, op: Assign) -> str:
|
|
return self.format("%r = %r", op.dest, op.src)
|
|
|
|
def visit_assign_multi(self, op: AssignMulti) -> str:
|
|
return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src))
|
|
|
|
def visit_load_error_value(self, op: LoadErrorValue) -> str:
|
|
return self.format("%r = <error> :: %s", op, op.type)
|
|
|
|
def visit_load_literal(self, op: LoadLiteral) -> str:
|
|
prefix = ""
|
|
# For values that have a potential unboxed representation, make
|
|
# it explicit that this is a Python object.
|
|
if isinstance(op.value, int):
|
|
prefix = "object "
|
|
|
|
rvalue = repr(op.value)
|
|
if isinstance(op.value, frozenset):
|
|
# We need to generate a string representation that won't vary
|
|
# run-to-run because sets are unordered, otherwise we may get
|
|
# spurious irbuild test failures.
|
|
#
|
|
# Sorting by the item's string representation is a bit of a
|
|
# hack, but it's stable and won't cause TypeErrors.
|
|
formatted_items = [repr(i) for i in sorted(op.value, key=str)]
|
|
rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
|
|
return self.format("%r = %s%s", op, prefix, rvalue)
|
|
|
|
def visit_get_attr(self, op: GetAttr) -> str:
|
|
return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
|
|
|
|
def borrow_prefix(self, op: Op) -> str:
|
|
if op.is_borrowed:
|
|
return "borrow "
|
|
return ""
|
|
|
|
def visit_set_attr(self, op: SetAttr) -> str:
|
|
if op.is_init:
|
|
assert op.error_kind == ERR_NEVER
|
|
if op.error_kind == ERR_NEVER:
|
|
# Initialization and direct struct access can never fail
|
|
return self.format("%r.%s = %r", op.obj, op.attr, op.src)
|
|
else:
|
|
return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op)
|
|
|
|
def visit_load_static(self, op: LoadStatic) -> str:
|
|
ann = f" ({repr(op.ann)})" if op.ann else ""
|
|
name = op.identifier
|
|
if op.module_name is not None:
|
|
name = f"{op.module_name}.{name}"
|
|
return self.format("%r = %s :: %s%s", op, name, op.namespace, ann)
|
|
|
|
def visit_init_static(self, op: InitStatic) -> str:
|
|
name = op.identifier
|
|
if op.module_name is not None:
|
|
name = f"{op.module_name}.{name}"
|
|
return self.format("%s = %r :: %s", name, op.value, op.namespace)
|
|
|
|
def visit_tuple_get(self, op: TupleGet) -> str:
|
|
return self.format("%r = %r[%d]", op, op.src, op.index)
|
|
|
|
def visit_tuple_set(self, op: TupleSet) -> str:
|
|
item_str = ", ".join(self.format("%r", item) for item in op.items)
|
|
return self.format("%r = (%s)", op, item_str)
|
|
|
|
def visit_inc_ref(self, op: IncRef) -> str:
|
|
s = self.format("inc_ref %r", op.src)
|
|
# TODO: Remove bool check (it's unboxed)
|
|
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
|
|
s += f" :: {short_name(op.src.type.name)}"
|
|
return s
|
|
|
|
def visit_dec_ref(self, op: DecRef) -> str:
|
|
s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src)
|
|
# TODO: Remove bool check (it's unboxed)
|
|
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
|
|
s += f" :: {short_name(op.src.type.name)}"
|
|
return s
|
|
|
|
def visit_call(self, op: Call) -> str:
|
|
args = ", ".join(self.format("%r", arg) for arg in op.args)
|
|
# TODO: Display long name?
|
|
short_name = op.fn.shortname
|
|
s = f"{short_name}({args})"
|
|
if not op.is_void:
|
|
s = self.format("%r = ", op) + s
|
|
return s
|
|
|
|
def visit_method_call(self, op: MethodCall) -> str:
|
|
args = ", ".join(self.format("%r", arg) for arg in op.args)
|
|
s = self.format("%r.%s(%s)", op.obj, op.method, args)
|
|
if not op.is_void:
|
|
s = self.format("%r = ", op) + s
|
|
return s
|
|
|
|
def visit_cast(self, op: Cast) -> str:
|
|
return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
|
|
|
|
def visit_box(self, op: Box) -> str:
|
|
return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
|
|
|
|
def visit_unbox(self, op: Unbox) -> str:
|
|
return self.format("%r = unbox(%s, %r)", op, op.type, op.src)
|
|
|
|
def visit_raise_standard_error(self, op: RaiseStandardError) -> str:
|
|
if op.value is not None:
|
|
if isinstance(op.value, str):
|
|
return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value))
|
|
elif isinstance(op.value, Value):
|
|
return self.format("%r = raise %s(%r)", op, op.class_name, op.value)
|
|
else:
|
|
assert False, "value type must be either str or Value"
|
|
else:
|
|
return self.format("%r = raise %s", op, op.class_name)
|
|
|
|
def visit_call_c(self, op: CallC) -> str:
|
|
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
|
|
if op.is_void:
|
|
return self.format("%s(%s)", op.function_name, args_str)
|
|
else:
|
|
return self.format("%r = %s(%s)", op, op.function_name, args_str)
|
|
|
|
def visit_truncate(self, op: Truncate) -> str:
|
|
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
|
|
|
|
def visit_extend(self, op: Extend) -> str:
|
|
if op.signed:
|
|
extra = " signed"
|
|
else:
|
|
extra = ""
|
|
return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type)
|
|
|
|
def visit_load_global(self, op: LoadGlobal) -> str:
|
|
ann = f" ({repr(op.ann)})" if op.ann else ""
|
|
return self.format("%r = load_global %s :: static%s", op, op.identifier, ann)
|
|
|
|
def visit_int_op(self, op: IntOp) -> str:
|
|
return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs)
|
|
|
|
def visit_comparison_op(self, op: ComparisonOp) -> str:
|
|
if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
|
|
sign_format = " :: signed"
|
|
elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
|
|
sign_format = " :: unsigned"
|
|
else:
|
|
sign_format = ""
|
|
return self.format(
|
|
"%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format
|
|
)
|
|
|
|
def visit_float_op(self, op: FloatOp) -> str:
|
|
return self.format("%r = %r %s %r", op, op.lhs, FloatOp.op_str[op.op], op.rhs)
|
|
|
|
def visit_float_neg(self, op: FloatNeg) -> str:
|
|
return self.format("%r = -%r", op, op.src)
|
|
|
|
def visit_float_comparison_op(self, op: FloatComparisonOp) -> str:
|
|
return self.format("%r = %r %s %r", op, op.lhs, op.op_str[op.op], op.rhs)
|
|
|
|
def visit_load_mem(self, op: LoadMem) -> str:
|
|
return self.format("%r = load_mem %r :: %t*", op, op.src, op.type)
|
|
|
|
def visit_set_mem(self, op: SetMem) -> str:
|
|
return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type)
|
|
|
|
def visit_get_element_ptr(self, op: GetElementPtr) -> str:
|
|
return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)
|
|
|
|
def visit_load_address(self, op: LoadAddress) -> str:
|
|
if isinstance(op.src, Register):
|
|
return self.format("%r = load_address %r", op, op.src)
|
|
elif isinstance(op.src, LoadStatic):
|
|
name = op.src.identifier
|
|
if op.src.module_name is not None:
|
|
name = f"{op.src.module_name}.{name}"
|
|
return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
|
|
else:
|
|
return self.format("%r = load_address %s", op, op.src)
|
|
|
|
def visit_keep_alive(self, op: KeepAlive) -> str:
|
|
return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
|
|
|
|
# Helpers
|
|
|
|
def format(self, fmt: str, *args: Any) -> str:
|
|
"""Helper for formatting strings.
|
|
|
|
These format sequences are supported in fmt:
|
|
|
|
%s: arbitrary object converted to string using str()
|
|
%r: name of IR value/register
|
|
%d: int
|
|
%f: float
|
|
%l: BasicBlock (formatted as label 'Ln')
|
|
%t: RType
|
|
"""
|
|
result = []
|
|
i = 0
|
|
arglist = list(args)
|
|
while i < len(fmt):
|
|
n = fmt.find("%", i)
|
|
if n < 0:
|
|
n = len(fmt)
|
|
result.append(fmt[i:n])
|
|
if n < len(fmt):
|
|
typespec = fmt[n + 1]
|
|
arg = arglist.pop(0)
|
|
if typespec == "r":
|
|
# Register/value
|
|
assert isinstance(arg, Value)
|
|
if isinstance(arg, Integer):
|
|
result.append(str(arg.value))
|
|
elif isinstance(arg, Float):
|
|
result.append(repr(arg.value))
|
|
else:
|
|
result.append(self.names[arg])
|
|
elif typespec == "d":
|
|
# Integer
|
|
result.append("%d" % arg)
|
|
elif typespec == "f":
|
|
# Float
|
|
result.append("%f" % arg)
|
|
elif typespec == "l":
|
|
# Basic block (label)
|
|
assert isinstance(arg, BasicBlock)
|
|
result.append("L%s" % arg.label)
|
|
elif typespec == "t":
|
|
# RType
|
|
assert isinstance(arg, RType)
|
|
result.append(arg.name)
|
|
elif typespec == "s":
|
|
# String
|
|
result.append(str(arg))
|
|
else:
|
|
raise ValueError(f"Invalid format sequence %{typespec}")
|
|
i = n + 2
|
|
else:
|
|
i = n
|
|
return "".join(result)
|
|
|
|
|
|
def format_registers(func_ir: FuncIR, names: dict[Value, str]) -> list[str]:
|
|
result = []
|
|
i = 0
|
|
regs = all_values_full(func_ir.arg_regs, func_ir.blocks)
|
|
while i < len(regs):
|
|
i0 = i
|
|
group = [names[regs[i0]]]
|
|
while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type:
|
|
i += 1
|
|
group.append(names[regs[i]])
|
|
i += 1
|
|
result.append("{} :: {}".format(", ".join(group), regs[i0].type))
|
|
return result
|
|
|
|
|
|
def format_blocks(
|
|
blocks: list[BasicBlock],
|
|
names: dict[Value, str],
|
|
source_to_error: dict[ErrorSource, list[str]],
|
|
) -> list[str]:
|
|
"""Format a list of IR basic blocks into a human-readable form."""
|
|
# First label all of the blocks
|
|
for i, block in enumerate(blocks):
|
|
block.label = i
|
|
|
|
handler_map: dict[BasicBlock, list[BasicBlock]] = {}
|
|
for b in blocks:
|
|
if b.error_handler:
|
|
handler_map.setdefault(b.error_handler, []).append(b)
|
|
|
|
visitor = IRPrettyPrintVisitor(names)
|
|
|
|
lines = []
|
|
for i, block in enumerate(blocks):
|
|
handler_msg = ""
|
|
if block in handler_map:
|
|
labels = sorted("L%d" % b.label for b in handler_map[block])
|
|
handler_msg = " (handler for {})".format(", ".join(labels))
|
|
|
|
lines.append("L%d:%s" % (block.label, handler_msg))
|
|
if block in source_to_error:
|
|
for error in source_to_error[block]:
|
|
lines.append(f" ERR: {error}")
|
|
ops = block.ops
|
|
if (
|
|
isinstance(ops[-1], Goto)
|
|
and i + 1 < len(blocks)
|
|
and ops[-1].label == blocks[i + 1]
|
|
and not source_to_error.get(ops[-1], [])
|
|
):
|
|
# Hide the last goto if it just goes to the next basic block,
|
|
# and there are no assocatiated errors with the op.
|
|
ops = ops[:-1]
|
|
for op in ops:
|
|
line = " " + op.accept(visitor)
|
|
lines.append(line)
|
|
if op in source_to_error:
|
|
for error in source_to_error[op]:
|
|
lines.append(f" ERR: {error}")
|
|
|
|
if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
|
|
# Each basic block needs to exit somewhere.
|
|
lines.append(" [MISSING BLOCK EXIT OPCODE]")
|
|
return lines
|
|
|
|
|
|
def format_func(fn: FuncIR, errors: Sequence[tuple[ErrorSource, str]] = ()) -> list[str]:
|
|
lines = []
|
|
cls_prefix = fn.class_name + "." if fn.class_name else ""
|
|
lines.append(
|
|
"def {}{}({}):".format(cls_prefix, fn.name, ", ".join(arg.name for arg in fn.args))
|
|
)
|
|
names = generate_names_for_ir(fn.arg_regs, fn.blocks)
|
|
for line in format_registers(fn, names):
|
|
lines.append(" " + line)
|
|
|
|
source_to_error = defaultdict(list)
|
|
for source, error in errors:
|
|
source_to_error[source].append(error)
|
|
|
|
code = format_blocks(fn.blocks, names, source_to_error)
|
|
lines.extend(code)
|
|
return lines
|
|
|
|
|
|
def format_modules(modules: ModuleIRs) -> list[str]:
|
|
ops = []
|
|
for module in modules.values():
|
|
for fn in module.functions:
|
|
ops.extend(format_func(fn))
|
|
ops.append("")
|
|
return ops
|
|
|
|
|
|
def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dict[Value, str]:
|
|
"""Generate unique names for IR values.
|
|
|
|
Give names such as 'r5' to temp values in IR which are useful when
|
|
pretty-printing or generating C. Ensure generated names are unique.
|
|
"""
|
|
names: dict[Value, str] = {}
|
|
used_names = set()
|
|
|
|
temp_index = 0
|
|
|
|
for arg in args:
|
|
names[arg] = arg.name
|
|
used_names.add(arg.name)
|
|
|
|
for block in blocks:
|
|
for op in block.ops:
|
|
values = []
|
|
|
|
for source in op.sources():
|
|
if source not in names:
|
|
values.append(source)
|
|
|
|
if isinstance(op, (Assign, AssignMulti)):
|
|
values.append(op.dest)
|
|
elif isinstance(op, ControlOp) or op.is_void:
|
|
continue
|
|
elif op not in names:
|
|
values.append(op)
|
|
|
|
for value in values:
|
|
if value in names:
|
|
continue
|
|
if isinstance(value, Register) and value.name:
|
|
name = value.name
|
|
elif isinstance(value, (Integer, Float)):
|
|
continue
|
|
else:
|
|
name = "r%d" % temp_index
|
|
temp_index += 1
|
|
|
|
# Append _2, _3, ... if needed to make the name unique.
|
|
if name in used_names:
|
|
n = 2
|
|
while True:
|
|
candidate = "%s_%d" % (name, n)
|
|
if candidate not in used_names:
|
|
name = candidate
|
|
break
|
|
n += 1
|
|
|
|
names[value] = name
|
|
used_names.add(name)
|
|
|
|
return names
|