425 lines
13 KiB
Python
425 lines
13 KiB
Python
|
"""Utilities for checking that internal ir is valid and consistent."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR
|
||
|
from mypyc.ir.ops import (
|
||
|
Assign,
|
||
|
AssignMulti,
|
||
|
BaseAssign,
|
||
|
BasicBlock,
|
||
|
Box,
|
||
|
Branch,
|
||
|
Call,
|
||
|
CallC,
|
||
|
Cast,
|
||
|
ComparisonOp,
|
||
|
ControlOp,
|
||
|
DecRef,
|
||
|
Extend,
|
||
|
FloatComparisonOp,
|
||
|
FloatNeg,
|
||
|
FloatOp,
|
||
|
GetAttr,
|
||
|
GetElementPtr,
|
||
|
Goto,
|
||
|
IncRef,
|
||
|
InitStatic,
|
||
|
Integer,
|
||
|
IntOp,
|
||
|
KeepAlive,
|
||
|
LoadAddress,
|
||
|
LoadErrorValue,
|
||
|
LoadGlobal,
|
||
|
LoadLiteral,
|
||
|
LoadMem,
|
||
|
LoadStatic,
|
||
|
MethodCall,
|
||
|
Op,
|
||
|
OpVisitor,
|
||
|
RaiseStandardError,
|
||
|
Register,
|
||
|
Return,
|
||
|
SetAttr,
|
||
|
SetMem,
|
||
|
Truncate,
|
||
|
TupleGet,
|
||
|
TupleSet,
|
||
|
Unbox,
|
||
|
Unreachable,
|
||
|
Value,
|
||
|
)
|
||
|
from mypyc.ir.pprint import format_func
|
||
|
from mypyc.ir.rtypes import (
|
||
|
RArray,
|
||
|
RInstance,
|
||
|
RPrimitive,
|
||
|
RType,
|
||
|
RUnion,
|
||
|
bytes_rprimitive,
|
||
|
dict_rprimitive,
|
||
|
int_rprimitive,
|
||
|
is_float_rprimitive,
|
||
|
is_object_rprimitive,
|
||
|
list_rprimitive,
|
||
|
range_rprimitive,
|
||
|
set_rprimitive,
|
||
|
str_rprimitive,
|
||
|
tuple_rprimitive,
|
||
|
)
|
||
|
|
||
|
|
||
|
class FnError:
    """A single problem found while checking a function's IR.

    Instances compare equal when both their source and their description
    compare equal, and they hash consistently with that equality so they
    can be stored in sets or used as dict keys.
    """

    def __init__(self, source: Op | BasicBlock, desc: str) -> None:
        # The op or basic block the error is attached to.
        self.source = source
        # Human-readable description of what is wrong.
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FnError) and self.source == other.source and self.desc == other.desc
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None and make
        # instances unhashable; hash the same fields that __eq__ compares.
        return hash((self.source, self.desc))

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"
|
||
|
|
||
|
|
||
|
def check_func_ir(fn: FuncIR) -> list[FnError]:
    """Validate a function's IR and return the list of problems found.

    Structural checks run first (block termination, misplaced control ops,
    duplicate ops, and op source validity). The per-op checks done by
    OpChecker only run if the structural checks found nothing.
    """
    found = []
    seen_ops = set()

    for bb in fn.blocks:
        if not bb.terminated:
            # Attach the error to the last op when there is one, otherwise
            # to the (empty) block itself.
            culprit = bb.ops[-1] if bb.ops else bb
            found.append(FnError(source=culprit, desc="Block not terminated"))

        # Every op except the final one in the block.
        for inner in bb.ops[:-1]:
            if isinstance(inner, ControlOp):
                found.append(
                    FnError(source=inner, desc="Block has operations after control op")
                )

            if inner in seen_ops:
                found.append(FnError(source=inner, desc="Func has a duplicate op"))
            seen_ops.add(inner)

    found.extend(check_op_sources_valid(fn))
    if not found:
        # Structure is sound, so run the per-op checks.
        visitor = OpChecker(fn)
        for bb in fn.blocks:
            for current in bb.ops:
                current.accept(visitor)
        return visitor.errors

    return found
|
||
|
|
||
|
|
||
|
class IrCheckException(Exception):
    """Raised by assert_func_ir_valid when a function's IR fails validation."""
|
||
|
|
||
|
|
||
|
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir reports any error for fn.

    The exception message contains the function as rendered by format_func,
    which is given each error's (source, description) pair.
    """
    problems = check_func_ir(fn)
    if not problems:
        return

    annotations = [(problem.source, problem.desc) for problem in problems]
    rendered = "\n".join(format_func(fn, annotations))
    raise IrCheckException("Internal error: Generated invalid IR: \n" + rendered)
|
||
|
|
||
|
|
||
|
def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
    """Check that every op only reads from ops and registers known to fn.

    An op source is valid if it is an Integer, an op contained in one of the
    function's blocks, or a register that is a function argument, the
    destination of an assignment, or the source of a LoadAddress.
    """
    known_ops: set[Op] = set()
    known_registers: set[Register] = set()

    # First pass: collect everything that may legitimately be referenced.
    for bb in fn.blocks:
        known_ops.update(bb.ops)

        for candidate in bb.ops:
            if isinstance(candidate, BaseAssign):
                known_registers.add(candidate.dest)
            elif isinstance(candidate, LoadAddress) and isinstance(candidate.src, Register):
                known_registers.add(candidate.src)

    known_registers.update(fn.arg_regs)

    # Second pass: report every source that is not in the collected sets.
    problems = []
    for bb in fn.blocks:
        for user in bb.ops:
            for operand in user.sources():
                if isinstance(operand, Integer):
                    continue

                if isinstance(operand, Op):
                    if operand not in known_ops:
                        message = (
                            f"Invalid op reference to op of type {type(operand).__name__}"
                        )
                        problems.append(FnError(source=user, desc=message))
                elif isinstance(operand, Register):
                    if operand not in known_registers:
                        message = f"Invalid op reference to register {operand.name!r}"
                        problems.append(FnError(source=user, desc=message))

    return problems
|
||
|
|
||
|
|
||
|
# Names of primitive types that can_coerce_to treats as mutually exclusive:
# when both sides of a coercion are in this set, their names must match.
disjoint_types = {
    rprimitive.name
    for rprimitive in (
        int_rprimitive,
        bytes_rprimitive,
        str_rprimitive,
        dict_rprimitive,
        list_rprimitive,
        set_rprimitive,
        tuple_rprimitive,
        range_rprimitive,
    )
}
|
||
|
|
||
|
|
||
|
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Return whether a value of type src may be assigned to type dest.

    The check is deliberately permissive: it is currently acceptable for it
    to report True for some pairs that are not actually coercible.
    """
    if isinstance(dest, RUnion):
        # Being assignable to any one member of the union is enough.
        for member in dest.items:
            if can_coerce_to(src, member):
                return True
        return False

    if not isinstance(dest, RPrimitive):
        # No checks are performed for other kinds of destination.
        return True

    if isinstance(src, RPrimitive):
        # If either src or dest is a disjoint type, then they must both be.
        both_disjoint = src.name in disjoint_types and dest.name in disjoint_types
        if both_disjoint:
            return src.name == dest.name
        return src.size == dest.size

    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)

    if isinstance(src, RUnion):
        # The IR cannot narrow unions based on control flow, so requiring
        # every member to be coercible (a strict all()) would be too strong.
        return any(can_coerce_to(member, dest) for member in src.items)

    return False
|
||
|
|
||
|
|
||
|
class OpChecker(OpVisitor[None]):
    """Visitor that validates individual ops of a function.

    Each visit_* method inspects one op and records any problem through
    fail(); the collected FnError objects are available in self.errors.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        # The function whose ops are being checked; used to look up its
        # blocks and its declared return type.
        self.parent_fn = parent_fn
        self.errors: list[FnError] = []

    def fail(self, source: Op, desc: str) -> None:
        """Record an error attached to the given op."""
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        """Report every target of op that is not a block of the parent function."""
        for target in op.targets():
            if target not in self.parent_fn.blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
        """Report an error on op if src cannot be coerced to dest."""
        if not can_coerce_to(src, dest):
            self.fail(
                source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
            )

    def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
        """Report an error on op unless t and s are coercible in both directions."""
        if not can_coerce_to(t, s) or not can_coerce_to(s, t):
            self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

    def expect_float(self, op: Op, v: Value) -> None:
        """Report an error on op if v does not have the float primitive type."""
        if not is_float_rprimitive(v.type):
            self.fail(op, f"Float expected (actual type is {v.type})")

    def expect_non_float(self, op: Op, v: Value) -> None:
        """Report an error on op if v has the float primitive type."""
        if is_float_rprimitive(v.type):
            self.fail(op, "Float not expected")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    def visit_return(self, op: Return) -> None:
        # The returned value must fit the function's declared return type.
        self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Unreachables are checked at a higher level since validation
        # requires access to the entire basic block.
        pass

    def visit_assign(self, op: Assign) -> None:
        self.check_type_coercion(op, op.src.type, op.dest.type)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        # Each source value must fit the item type of the destination array.
        for src in op.src:
            assert isinstance(op.dest.type, RArray)
            self.check_type_coercion(op, src.type, op.dest.type.item_type)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Currently it is assumed that all types have an error value.
        # Once this is fixed we can validate that the rtype here actually
        # has an error value.
        pass

    def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
        """Report items of t (recursing into nested tuples) that are not valid literals."""
        for x in t:
            if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
                self.fail(op, f"Invalid type for item of tuple literal: {type(x)})")
            if isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)

    def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
        """Report items of s that are not valid literals; tuple items are checked recursively."""
        for x in s:
            if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
                pass
            elif isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)
            else:
                self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Check that the op's type matches the Python type of its literal value.

        The op type may always be "builtins.object"; otherwise it must be the
        type name expected for the value.
        """
        expected_type = None
        if op.value is None:
            expected_type = "builtins.object"
        elif isinstance(op.value, bool):
            # bool is a subclass of int, so it must be tested before int;
            # otherwise this branch can never be reached.
            expected_type = "builtins.object"
        elif isinstance(op.value, int):
            expected_type = "builtins.int"
        elif isinstance(op.value, str):
            expected_type = "builtins.str"
        elif isinstance(op.value, bytes):
            expected_type = "builtins.bytes"
        elif isinstance(op.value, float):
            expected_type = "builtins.float"
        elif isinstance(op.value, complex):
            expected_type = "builtins.object"
        elif isinstance(op.value, tuple):
            expected_type = "builtins.tuple"
            self.check_tuple_items_valid_literals(op, op.value)
        elif isinstance(op.value, frozenset):
            # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
            # it's a set (when it's really a frozenset).
            expected_type = "builtins.set"
            self.check_frozenset_items_valid_literals(op, op.value)

        assert expected_type is not None, "Missed a case for LoadLiteral check"

        if op.type.name not in [expected_type, "builtins.object"]:
            self.fail(
                op,
                f"Invalid literal value for type: value has "
                f"type {expected_type}, but op has type {op.type.name}",
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        # Nothing to do.
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        # Nothing to do.
        pass

    # Static operations cannot be checked at the function level.
    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        # Nothing to do.
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        # Nothing to do.
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        # Nothing to do.
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        # Nothing to do.
        pass

    def visit_call(self, op: Call) -> None:
        # Length is checked in constructor, and return type is set
        # in a way that can't be incorrect
        for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_method_call(self, op: MethodCall) -> None:
        # Similar to above, but we must look up method first.
        method_decl = op.receiver_type.class_ir.method_decl(op.method)
        if method_decl.kind == FUNC_STATICMETHOD:
            decl_index = 0
        else:
            decl_index = 1

        if len(op.args) + decl_index != len(method_decl.sig.args):
            self.fail(op, "Incorrect number of args for method call.")

        # Skip the receiver argument (self)
        for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_extend(self, op: Extend) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        self.check_compatibility(op, op.lhs.type, op.rhs.type)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        self.expect_float(op, op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass
|