gtn/.venv/Lib/site-packages/mypyc/ir/pprint.py
Tipragot 628be439b8 Ajout d'un environement de développement.
Cela permet de ne pas avoir de problèmes de compatibilité
car python est dans le git.
2023-10-26 15:33:03 +02:00

487 lines
17 KiB
Python

"""Utilities for pretty-printing IR in a human-readable form."""
from __future__ import annotations
from collections import defaultdict
from typing import Any, Final, Sequence, Union
from mypyc.common import short_name
from mypyc.ir.func_ir import FuncIR, all_values_full
from mypyc.ir.module_ir import ModuleIRs
from mypyc.ir.ops import (
ERR_NEVER,
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
RaiseStandardError,
Register,
Return,
SetAttr,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unbox,
Unreachable,
Value,
)
from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive
ErrorSource = Union[BasicBlock, Op]
class IRPrettyPrintVisitor(OpVisitor[str]):
"""Internal visitor that pretty-prints ops."""
def __init__(self, names: dict[Value, str]) -> None:
# This should contain a name for all values that are shown as
# registers in the output. This is not just for Register
# instances -- all Ops that produce values need (generated) names.
self.names = names
def visit_goto(self, op: Goto) -> str:
return self.format("goto %l", op.label)
branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")}
def visit_branch(self, op: Branch) -> str:
fmt, typ = self.branch_op_names[op.op]
if op.negated:
fmt = f"not {fmt}"
cond = self.format(fmt, op.value)
tb = ""
if op.traceback_entry:
tb = " (error at %s:%d)" % op.traceback_entry
fmt = f"if {cond} goto %l{tb} else goto %l"
if typ:
fmt += f" :: {typ}"
return self.format(fmt, op.true, op.false)
def visit_return(self, op: Return) -> str:
return self.format("return %r", op.value)
def visit_unreachable(self, op: Unreachable) -> str:
return "unreachable"
def visit_assign(self, op: Assign) -> str:
return self.format("%r = %r", op.dest, op.src)
def visit_assign_multi(self, op: AssignMulti) -> str:
return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src))
def visit_load_error_value(self, op: LoadErrorValue) -> str:
return self.format("%r = <error> :: %s", op, op.type)
def visit_load_literal(self, op: LoadLiteral) -> str:
prefix = ""
# For values that have a potential unboxed representation, make
# it explicit that this is a Python object.
if isinstance(op.value, int):
prefix = "object "
rvalue = repr(op.value)
if isinstance(op.value, frozenset):
# We need to generate a string representation that won't vary
# run-to-run because sets are unordered, otherwise we may get
# spurious irbuild test failures.
#
# Sorting by the item's string representation is a bit of a
# hack, but it's stable and won't cause TypeErrors.
formatted_items = [repr(i) for i in sorted(op.value, key=str)]
rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
return self.format("%r = %s%s", op, prefix, rvalue)
def visit_get_attr(self, op: GetAttr) -> str:
return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
def borrow_prefix(self, op: Op) -> str:
if op.is_borrowed:
return "borrow "
return ""
def visit_set_attr(self, op: SetAttr) -> str:
if op.is_init:
assert op.error_kind == ERR_NEVER
if op.error_kind == ERR_NEVER:
# Initialization and direct struct access can never fail
return self.format("%r.%s = %r", op.obj, op.attr, op.src)
else:
return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op)
def visit_load_static(self, op: LoadStatic) -> str:
ann = f" ({repr(op.ann)})" if op.ann else ""
name = op.identifier
if op.module_name is not None:
name = f"{op.module_name}.{name}"
return self.format("%r = %s :: %s%s", op, name, op.namespace, ann)
def visit_init_static(self, op: InitStatic) -> str:
name = op.identifier
if op.module_name is not None:
name = f"{op.module_name}.{name}"
return self.format("%s = %r :: %s", name, op.value, op.namespace)
def visit_tuple_get(self, op: TupleGet) -> str:
return self.format("%r = %r[%d]", op, op.src, op.index)
def visit_tuple_set(self, op: TupleSet) -> str:
item_str = ", ".join(self.format("%r", item) for item in op.items)
return self.format("%r = (%s)", op, item_str)
def visit_inc_ref(self, op: IncRef) -> str:
s = self.format("inc_ref %r", op.src)
# TODO: Remove bool check (it's unboxed)
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
s += f" :: {short_name(op.src.type.name)}"
return s
def visit_dec_ref(self, op: DecRef) -> str:
s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src)
# TODO: Remove bool check (it's unboxed)
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
s += f" :: {short_name(op.src.type.name)}"
return s
def visit_call(self, op: Call) -> str:
args = ", ".join(self.format("%r", arg) for arg in op.args)
# TODO: Display long name?
short_name = op.fn.shortname
s = f"{short_name}({args})"
if not op.is_void:
s = self.format("%r = ", op) + s
return s
def visit_method_call(self, op: MethodCall) -> str:
args = ", ".join(self.format("%r", arg) for arg in op.args)
s = self.format("%r.%s(%s)", op.obj, op.method, args)
if not op.is_void:
s = self.format("%r = ", op) + s
return s
def visit_cast(self, op: Cast) -> str:
return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src)
def visit_box(self, op: Box) -> str:
return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
def visit_unbox(self, op: Unbox) -> str:
return self.format("%r = unbox(%s, %r)", op, op.type, op.src)
def visit_raise_standard_error(self, op: RaiseStandardError) -> str:
if op.value is not None:
if isinstance(op.value, str):
return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value))
elif isinstance(op.value, Value):
return self.format("%r = raise %s(%r)", op, op.class_name, op.value)
else:
assert False, "value type must be either str or Value"
else:
return self.format("%r = raise %s", op, op.class_name)
def visit_call_c(self, op: CallC) -> str:
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
if op.is_void:
return self.format("%s(%s)", op.function_name, args_str)
else:
return self.format("%r = %s(%s)", op, op.function_name, args_str)
def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
def visit_extend(self, op: Extend) -> str:
if op.signed:
extra = " signed"
else:
extra = ""
return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type)
def visit_load_global(self, op: LoadGlobal) -> str:
ann = f" ({repr(op.ann)})" if op.ann else ""
return self.format("%r = load_global %s :: static%s", op, op.identifier, ann)
def visit_int_op(self, op: IntOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs)
def visit_comparison_op(self, op: ComparisonOp) -> str:
if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
sign_format = " :: signed"
elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
sign_format = " :: unsigned"
else:
sign_format = ""
return self.format(
"%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format
)
def visit_float_op(self, op: FloatOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, FloatOp.op_str[op.op], op.rhs)
def visit_float_neg(self, op: FloatNeg) -> str:
return self.format("%r = -%r", op, op.src)
def visit_float_comparison_op(self, op: FloatComparisonOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, op.op_str[op.op], op.rhs)
def visit_load_mem(self, op: LoadMem) -> str:
return self.format("%r = load_mem %r :: %t*", op, op.src, op.type)
def visit_set_mem(self, op: SetMem) -> str:
return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type)
def visit_get_element_ptr(self, op: GetElementPtr) -> str:
return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)
def visit_load_address(self, op: LoadAddress) -> str:
if isinstance(op.src, Register):
return self.format("%r = load_address %r", op, op.src)
elif isinstance(op.src, LoadStatic):
name = op.src.identifier
if op.src.module_name is not None:
name = f"{op.src.module_name}.{name}"
return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
else:
return self.format("%r = load_address %s", op, op.src)
def visit_keep_alive(self, op: KeepAlive) -> str:
return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
# Helpers
def format(self, fmt: str, *args: Any) -> str:
"""Helper for formatting strings.
These format sequences are supported in fmt:
%s: arbitrary object converted to string using str()
%r: name of IR value/register
%d: int
%f: float
%l: BasicBlock (formatted as label 'Ln')
%t: RType
"""
result = []
i = 0
arglist = list(args)
while i < len(fmt):
n = fmt.find("%", i)
if n < 0:
n = len(fmt)
result.append(fmt[i:n])
if n < len(fmt):
typespec = fmt[n + 1]
arg = arglist.pop(0)
if typespec == "r":
# Register/value
assert isinstance(arg, Value)
if isinstance(arg, Integer):
result.append(str(arg.value))
elif isinstance(arg, Float):
result.append(repr(arg.value))
else:
result.append(self.names[arg])
elif typespec == "d":
# Integer
result.append("%d" % arg)
elif typespec == "f":
# Float
result.append("%f" % arg)
elif typespec == "l":
# Basic block (label)
assert isinstance(arg, BasicBlock)
result.append("L%s" % arg.label)
elif typespec == "t":
# RType
assert isinstance(arg, RType)
result.append(arg.name)
elif typespec == "s":
# String
result.append(str(arg))
else:
raise ValueError(f"Invalid format sequence %{typespec}")
i = n + 2
else:
i = n
return "".join(result)
def format_registers(func_ir: FuncIR, names: dict[Value, str]) -> list[str]:
result = []
i = 0
regs = all_values_full(func_ir.arg_regs, func_ir.blocks)
while i < len(regs):
i0 = i
group = [names[regs[i0]]]
while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type:
i += 1
group.append(names[regs[i]])
i += 1
result.append("{} :: {}".format(", ".join(group), regs[i0].type))
return result
def format_blocks(
blocks: list[BasicBlock],
names: dict[Value, str],
source_to_error: dict[ErrorSource, list[str]],
) -> list[str]:
"""Format a list of IR basic blocks into a human-readable form."""
# First label all of the blocks
for i, block in enumerate(blocks):
block.label = i
handler_map: dict[BasicBlock, list[BasicBlock]] = {}
for b in blocks:
if b.error_handler:
handler_map.setdefault(b.error_handler, []).append(b)
visitor = IRPrettyPrintVisitor(names)
lines = []
for i, block in enumerate(blocks):
handler_msg = ""
if block in handler_map:
labels = sorted("L%d" % b.label for b in handler_map[block])
handler_msg = " (handler for {})".format(", ".join(labels))
lines.append("L%d:%s" % (block.label, handler_msg))
if block in source_to_error:
for error in source_to_error[block]:
lines.append(f" ERR: {error}")
ops = block.ops
if (
isinstance(ops[-1], Goto)
and i + 1 < len(blocks)
and ops[-1].label == blocks[i + 1]
and not source_to_error.get(ops[-1], [])
):
# Hide the last goto if it just goes to the next basic block,
# and there are no assocatiated errors with the op.
ops = ops[:-1]
for op in ops:
line = " " + op.accept(visitor)
lines.append(line)
if op in source_to_error:
for error in source_to_error[op]:
lines.append(f" ERR: {error}")
if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
# Each basic block needs to exit somewhere.
lines.append(" [MISSING BLOCK EXIT OPCODE]")
return lines
def format_func(fn: FuncIR, errors: Sequence[tuple[ErrorSource, str]] = ()) -> list[str]:
lines = []
cls_prefix = fn.class_name + "." if fn.class_name else ""
lines.append(
"def {}{}({}):".format(cls_prefix, fn.name, ", ".join(arg.name for arg in fn.args))
)
names = generate_names_for_ir(fn.arg_regs, fn.blocks)
for line in format_registers(fn, names):
lines.append(" " + line)
source_to_error = defaultdict(list)
for source, error in errors:
source_to_error[source].append(error)
code = format_blocks(fn.blocks, names, source_to_error)
lines.extend(code)
return lines
def format_modules(modules: ModuleIRs) -> list[str]:
ops = []
for module in modules.values():
for fn in module.functions:
ops.extend(format_func(fn))
ops.append("")
return ops
def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dict[Value, str]:
"""Generate unique names for IR values.
Give names such as 'r5' to temp values in IR which are useful when
pretty-printing or generating C. Ensure generated names are unique.
"""
names: dict[Value, str] = {}
used_names = set()
temp_index = 0
for arg in args:
names[arg] = arg.name
used_names.add(arg.name)
for block in blocks:
for op in block.ops:
values = []
for source in op.sources():
if source not in names:
values.append(source)
if isinstance(op, (Assign, AssignMulti)):
values.append(op.dest)
elif isinstance(op, ControlOp) or op.is_void:
continue
elif op not in names:
values.append(op)
for value in values:
if value in names:
continue
if isinstance(value, Register) and value.name:
name = value.name
elif isinstance(value, (Integer, Float)):
continue
else:
name = "r%d" % temp_index
temp_index += 1
# Append _2, _3, ... if needed to make the name unique.
if name in used_names:
n = 2
while True:
candidate = "%s_%d" % (name, n)
if candidate not in used_names:
name = candidate
break
n += 1
names[value] = name
used_names.add(name)
return names