96 lines
3.2 KiB
Python
96 lines
3.2 KiB
Python
|
"""Constant folding of IR values.
|
||
|
|
||
|
For example, 3 + 5 can be constant folded into 8.

This is mostly like mypy.constant_fold, but we can also fold some
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we constant fold only references
to other compiled modules in the same compilation unit. Unlike mypy, we
also fold operations on bytes values.
|
||
|
"""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from typing import Final, Union
|
||
|
|
||
|
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
|
||
|
from mypy.nodes import (
|
||
|
BytesExpr,
|
||
|
ComplexExpr,
|
||
|
Expression,
|
||
|
FloatExpr,
|
||
|
IntExpr,
|
||
|
MemberExpr,
|
||
|
NameExpr,
|
||
|
OpExpr,
|
||
|
StrExpr,
|
||
|
UnaryExpr,
|
||
|
Var,
|
||
|
)
|
||
|
from mypyc.irbuild.builder import IRBuilder
|
||
|
from mypyc.irbuild.util import bytes_from_str
|
||
|
|
||
|
# Every Python type a successful constant fold can produce. ConstantValue is
# the static (type-checker) view; CONST_TYPES is the same set of types as a
# runtime tuple, usable directly with isinstance(). bool values also pass the
# isinstance() check because bool is a subclass of int.
ConstantValue = Union[
    int,
    float,
    complex,
    str,
    bytes,
]
CONST_TYPES: Final = (
    int,
    float,
    complex,
    str,
    bytes,
)
|
||
|
|
||
|
|
||
|
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
    """Try to reduce an expression to a compile-time constant value.

    Supported forms:
      * int, float, str, bytes and complex literals;
      * names and member references that resolve to a Final ``Var`` whose
        recorded ``final_value`` is an instance of one of CONST_TYPES (member
        references are resolved through ``builder.get_final_ref``);
      * binary operations whose operands both fold (combined by
        constant_fold_binary_op_extended);
      * unary operations on a folded, non-bytes operand.

    Returns the folded value, or None when the expression is not foldable.
    """
    # Literal nodes carry their value directly; bytes literals store their
    # contents as a str that bytes_from_str converts to real bytes.
    if isinstance(expr, BytesExpr):
        return bytes_from_str(expr.value)
    if isinstance(expr, (IntExpr, FloatExpr, StrExpr, ComplexExpr)):
        return expr.value

    # Plain names and attribute accesses both end up at a Var; only Final
    # variables with a constant-typed recorded value are folded.
    final_var: Var | None = None
    if isinstance(expr, NameExpr):
        node = expr.node
        if isinstance(node, Var):
            final_var = node
    elif isinstance(expr, MemberExpr):
        ref = builder.get_final_ref(expr)
        if ref is not None:
            _fullname, final_var, _native = ref
    if final_var is not None:
        if not final_var.is_final:
            return None
        final_value = final_var.final_value
        if isinstance(final_value, CONST_TYPES):
            return final_value
        return None

    if isinstance(expr, OpExpr):
        # Both operands are always folded before deciding.
        lhs = constant_fold_expr(builder, expr.left)
        rhs = constant_fold_expr(builder, expr.right)
        if lhs is None or rhs is None:
            return None
        return constant_fold_binary_op_extended(expr.op, lhs, rhs)

    if isinstance(expr, UnaryExpr):
        operand = constant_fold_expr(builder, expr.expr)
        # bytes operands are never handed to the unary folder.
        if operand is None or isinstance(operand, bytes):
            return None
        return constant_fold_unary_op(expr.op, operand)

    return None
|
||
|
|
||
|
|
||
|
def constant_fold_binary_op_extended(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Fold ``left <op> right`` like mypy's constant_fold_binary_op(), plus bytes.

    mypy cannot use constant folded bytes easily, so bytes support lives only
    here in mypyc. When neither operand is bytes, the work is delegated to
    mypy's folder unchanged. With a bytes operand, only these are folded:
      * ``bytes + bytes``  (concatenation)
      * ``bytes * int`` and ``int * bytes``  (repetition)

    Returns the folded value, or None if the combination is not supported.
    """
    if not (isinstance(left, bytes) or isinstance(right, bytes)):
        return constant_fold_binary_op(op, left, right)

    # At least one side is bytes from here on.
    if op == "+":
        if isinstance(left, bytes) and isinstance(right, bytes):
            return left + right
    elif op == "*":
        if isinstance(left, bytes) and isinstance(right, int):
            return left * right
        if isinstance(left, int) and isinstance(right, bytes):
            return left * right
    return None
|