Tipragot
628be439b8
Cela permet de ne pas avoir de problèmes de compatibilité car python est dans le git.
425 lines
13 KiB
Python
425 lines
13 KiB
Python
"""Utilities for checking that internal ir is valid and consistent."""
|
|
from __future__ import annotations
|
|
|
|
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR
|
|
from mypyc.ir.ops import (
|
|
Assign,
|
|
AssignMulti,
|
|
BaseAssign,
|
|
BasicBlock,
|
|
Box,
|
|
Branch,
|
|
Call,
|
|
CallC,
|
|
Cast,
|
|
ComparisonOp,
|
|
ControlOp,
|
|
DecRef,
|
|
Extend,
|
|
FloatComparisonOp,
|
|
FloatNeg,
|
|
FloatOp,
|
|
GetAttr,
|
|
GetElementPtr,
|
|
Goto,
|
|
IncRef,
|
|
InitStatic,
|
|
Integer,
|
|
IntOp,
|
|
KeepAlive,
|
|
LoadAddress,
|
|
LoadErrorValue,
|
|
LoadGlobal,
|
|
LoadLiteral,
|
|
LoadMem,
|
|
LoadStatic,
|
|
MethodCall,
|
|
Op,
|
|
OpVisitor,
|
|
RaiseStandardError,
|
|
Register,
|
|
Return,
|
|
SetAttr,
|
|
SetMem,
|
|
Truncate,
|
|
TupleGet,
|
|
TupleSet,
|
|
Unbox,
|
|
Unreachable,
|
|
Value,
|
|
)
|
|
from mypyc.ir.pprint import format_func
|
|
from mypyc.ir.rtypes import (
|
|
RArray,
|
|
RInstance,
|
|
RPrimitive,
|
|
RType,
|
|
RUnion,
|
|
bytes_rprimitive,
|
|
dict_rprimitive,
|
|
int_rprimitive,
|
|
is_float_rprimitive,
|
|
is_object_rprimitive,
|
|
list_rprimitive,
|
|
range_rprimitive,
|
|
set_rprimitive,
|
|
str_rprimitive,
|
|
tuple_rprimitive,
|
|
)
|
|
|
|
|
|
class FnError:
    """A single problem found while checking a function's IR.

    Instances compare equal when both their source and their description
    are equal, and they hash consistently with that equality so they can
    be stored in sets or used as dict keys.
    """

    def __init__(self, source: Op | BasicBlock, desc: str) -> None:
        # The op (or basic block) the error is attributed to.
        self.source = source
        # Human-readable description of the problem.
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FnError) and self.source == other.source and self.desc == other.desc
        )

    def __hash__(self) -> int:
        # Defining __eq__ alone would set __hash__ to None and make
        # instances unhashable; hash the same fields __eq__ compares.
        return hash((self.source, self.desc))

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"
|
|
|
|
|
|
def check_func_ir(fn: FuncIR) -> list[FnError]:
    """Validate the IR of a function and return the list of errors found.

    Structural checks (block termination, control op placement, duplicate
    ops and validity of op sources) run first. The per-op checks performed
    by OpChecker only run when the structural checks found nothing.
    """
    problems: list[FnError] = []
    seen_ops = set()

    for block in fn.blocks:
        if not block.terminated:
            # Attribute the error to the last op, or to the block if it is empty.
            culprit = block.ops[-1] if block.ops else block
            problems.append(FnError(source=culprit, desc="Block not terminated"))

        # NOTE(review): the slice skips each block's final op, so that op
        # takes no part in the duplicate check below.
        for op in block.ops[:-1]:
            if isinstance(op, ControlOp):
                problems.append(FnError(source=op, desc="Block has operations after control op"))

            if op in seen_ops:
                problems.append(FnError(source=op, desc="Func has a duplicate op"))
            seen_ops.add(op)

    problems.extend(check_op_sources_valid(fn))
    if problems:
        return problems

    checker = OpChecker(fn)
    for block in fn.blocks:
        for op in block.ops:
            op.accept(checker)

    return checker.errors
|
|
|
|
|
|
class IrCheckException(Exception):
    """Raised by assert_func_ir_valid when a function's IR fails validation."""
|
|
|
|
|
|
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir reports any error for fn.

    The exception message contains the function formatted by format_func,
    which is given the (source, description) pair of every error.
    """
    errors = check_func_ir(fn)
    if not errors:
        return

    annotations = [(err.source, err.desc) for err in errors]
    report = "\n".join(format_func(fn, annotations))
    raise IrCheckException("Internal error: Generated invalid IR: \n" + report)
|
|
|
|
|
|
def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
    """Check that every source of every op refers to something defined in fn.

    An op source is valid when it is an op contained in one of the
    function's blocks. A register source is valid when it is an argument
    register, the destination of an assignment op, or the register source
    of a LoadAddress op. Integer sources are always accepted.
    """
    known_ops: set[Op] = set()
    # Argument registers are valid without any op defining them.
    known_registers: set[Register] = set(fn.arg_regs)

    # First pass: collect everything that may legitimately be referenced.
    for block in fn.blocks:
        known_ops.update(block.ops)

        for op in block.ops:
            if isinstance(op, BaseAssign):
                known_registers.add(op.dest)
            elif isinstance(op, LoadAddress) and isinstance(op.src, Register):
                known_registers.add(op.src)

    # Second pass: report each reference that was not collected above.
    problems = []
    for block in fn.blocks:
        for op in block.ops:
            for source in op.sources():
                if isinstance(source, Integer):
                    continue

                if isinstance(source, Op):
                    if source not in known_ops:
                        kind = type(source).__name__
                        problems.append(
                            FnError(source=op, desc=f"Invalid op reference to op of type {kind}")
                        )
                elif isinstance(source, Register):
                    if source not in known_registers:
                        problems.append(
                            FnError(
                                source=op, desc=f"Invalid op reference to register {source.name!r}"
                            )
                        )

    return problems
|
|
|
|
|
|
# Names of the primitive types that can_coerce_to only treats as
# interchangeable when their names are identical.
disjoint_types = {
    rprimitive.name
    for rprimitive in (
        int_rprimitive,
        bytes_rprimitive,
        str_rprimitive,
        dict_rprimitive,
        list_rprimitive,
        set_rprimitive,
        tuple_rprimitive,
        range_rprimitive,
    )
}
|
|
|
|
|
|
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check if a value of type src can be assigned to type dest.

    It is currently okay for this check to have false positives.
    """
    if isinstance(dest, RUnion):
        # Being assignable to any single member of the union is enough.
        return any(can_coerce_to(src, item) for item in dest.items)

    if not isinstance(dest, RPrimitive):
        # No check is performed for other kinds of destination type.
        return True

    if isinstance(src, RPrimitive):
        # When both types are in the disjoint set they must be the same
        # type; otherwise only their sizes are compared.
        both_disjoint = src.name in disjoint_types and dest.name in disjoint_types
        if both_disjoint:
            return src.name == dest.name
        return src.size == dest.size

    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)

    if isinstance(src, RUnion):
        # IR doesn't have the ability to narrow unions based on
        # control flow, so cannot be a strict all() here.
        return any(can_coerce_to(item, dest) for item in src.items)

    return False
|
|
|
|
|
|
class OpChecker(OpVisitor[None]):
    """Op visitor that validates individual ops of a single function.

    Every problem found is appended to ``self.errors`` as an FnError; the
    visit methods themselves return nothing.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        # Function whose ops are being checked; used to look up its blocks
        # and its declared return type.
        self.parent_fn = parent_fn
        self.errors: list[FnError] = []

    def fail(self, source: Op, desc: str) -> None:
        """Record an error with the given description against an op."""
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        """Fail for each target of op that is not a block of the parent function."""
        for target in op.targets():
            if target not in self.parent_fn.blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
        """Fail if can_coerce_to rejects assigning src to dest."""
        if not can_coerce_to(src, dest):
            self.fail(
                source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
            )

    def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
        """Fail unless t and s can be coerced to each other in both directions."""
        if not can_coerce_to(t, s) or not can_coerce_to(s, t):
            self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

    def expect_float(self, op: Op, v: Value) -> None:
        """Fail if the type of v is not the float primitive."""
        if not is_float_rprimitive(v.type):
            self.fail(op, f"Float expected (actual type is {v.type})")

    def expect_non_float(self, op: Op, v: Value) -> None:
        """Fail if the type of v is the float primitive."""
        if is_float_rprimitive(v.type):
            self.fail(op, "Float not expected")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    def visit_return(self, op: Return) -> None:
        # The returned value must be coercible to the declared return type
        # of the function being checked.
        self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Unreachables are checked at a higher level since validation
        # requires access to the entire basic block.
        pass

    def visit_assign(self, op: Assign) -> None:
        self.check_type_coercion(op, op.src.type, op.dest.type)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        # Each source value must be coercible to the item type of the
        # destination array.
        for src in op.src:
            assert isinstance(op.dest.type, RArray)
            self.check_type_coercion(op, src.type, op.dest.type.item_type)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Currently it is assumed that all types have an error value.
        # Once this is fixed we can validate that the rtype here actually
        # has an error value.
        pass

    def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
        """Fail for each item of t (recursing into nested tuples) that is not a valid literal."""
        for x in t:
            if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
                self.fail(op, f"Invalid type for item of tuple literal: {type(x)})")
            if isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)

    def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
        """Fail for each item of s that is not a valid literal; tuple items are checked recursively."""
        for x in s:
            if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
                pass
            elif isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)
            else:
                self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Check that the type of op agrees with the Python type of its literal value.

        A type named "builtins.object" is accepted for every literal value.
        """
        expected_type = None
        if op.value is None:
            expected_type = "builtins.object"
        elif isinstance(op.value, bool):
            # bool is a subclass of int, so it has to be tested before int;
            # otherwise this branch could never be reached and bool
            # literals would be validated as "builtins.int".
            expected_type = "builtins.object"
        elif isinstance(op.value, int):
            expected_type = "builtins.int"
        elif isinstance(op.value, str):
            expected_type = "builtins.str"
        elif isinstance(op.value, bytes):
            expected_type = "builtins.bytes"
        elif isinstance(op.value, float):
            expected_type = "builtins.float"
        elif isinstance(op.value, complex):
            expected_type = "builtins.object"
        elif isinstance(op.value, tuple):
            expected_type = "builtins.tuple"
            self.check_tuple_items_valid_literals(op, op.value)
        elif isinstance(op.value, frozenset):
            # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
            # it's a set (when it's really a frozenset).
            expected_type = "builtins.set"
            self.check_frozenset_items_valid_literals(op, op.value)

        assert expected_type is not None, "Missed a case for LoadLiteral check"

        if op.type.name not in [expected_type, "builtins.object"]:
            self.fail(
                op,
                f"Invalid literal value for type: value has "
                f"type {expected_type}, but op has type {op.type.name}",
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        # Nothing to do.
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        # Nothing to do.
        pass

    # Static operations cannot be checked at the function level.
    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        # Nothing to do.
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        # Nothing to do.
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        # Nothing to do.
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        # Nothing to do.
        pass

    def visit_call(self, op: Call) -> None:
        # Length is checked in constructor, and return type is set
        # in a way that can't be incorrect
        for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_method_call(self, op: MethodCall) -> None:
        # Similar to above, but we must look up method first.
        method_decl = op.receiver_type.class_ir.method_decl(op.method)
        if method_decl.kind == FUNC_STATICMETHOD:
            decl_index = 0
        else:
            decl_index = 1

        if len(op.args) + decl_index != len(method_decl.sig.args):
            self.fail(op, "Incorrect number of args for method call.")

        # Skip the receiver argument (self)
        for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_extend(self, op: Extend) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        self.check_compatibility(op, op.lhs.type, op.rhs.type)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        self.expect_float(op, op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass
|