356 lines
12 KiB
Python
356 lines
12 KiB
Python
|
from contextlib import contextmanager
|
||
|
from typing import Generator, List, Optional, Tuple
|
||
|
|
||
|
from mypy.nodes import MatchStmt, NameExpr, TypeInfo
|
||
|
from mypy.patterns import (
|
||
|
AsPattern,
|
||
|
ClassPattern,
|
||
|
MappingPattern,
|
||
|
OrPattern,
|
||
|
Pattern,
|
||
|
SequencePattern,
|
||
|
SingletonPattern,
|
||
|
StarredPattern,
|
||
|
ValuePattern,
|
||
|
)
|
||
|
from mypy.traverser import TraverserVisitor
|
||
|
from mypy.types import Instance, TupleType, get_proper_type
|
||
|
from mypyc.ir.ops import BasicBlock, Value
|
||
|
from mypyc.ir.rtypes import object_rprimitive
|
||
|
from mypyc.irbuild.builder import IRBuilder
|
||
|
from mypyc.primitives.dict_ops import (
|
||
|
dict_copy,
|
||
|
dict_del_item,
|
||
|
mapping_has_key,
|
||
|
supports_mapping_protocol,
|
||
|
)
|
||
|
from mypyc.primitives.generic_ops import generic_ssize_t_len_op
|
||
|
from mypyc.primitives.list_ops import (
|
||
|
sequence_get_item,
|
||
|
sequence_get_slice,
|
||
|
supports_sequence_protocol,
|
||
|
)
|
||
|
from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op
|
||
|
|
||
|
# Fully qualified names of the builtin classes that get special treatment in
# class patterns: a single positional subpattern is matched against the whole
# subject instead of an attribute looked up through ``__match_args__``.
# List taken from https://peps.python.org/pep-0634/#class-patterns
MATCHABLE_BUILTINS = {
    # Numeric types
    "builtins.bool", "builtins.int", "builtins.float",
    # Text and binary data
    "builtins.str", "builtins.bytes", "builtins.bytearray",
    # Containers
    "builtins.dict", "builtins.list", "builtins.tuple",
    "builtins.set", "builtins.frozenset",
}
|
||
|
|
||
|
|
||
|
class MatchVisitor(TraverserVisitor):
    """Generate IR for a ``match`` statement by visiting its case patterns.

    Every ``visit_*_pattern`` method emits the test for one kind of pattern
    against ``self.subject``. The shared convention is that a successful match
    ends up in ``self.code_block`` and a failed one in ``self.next_block``.
    Patterns that need several steps replace ``self.code_block`` with a fresh
    block after activating the previous one.
    """

    builder: IRBuilder
    # Block reached when the pattern currently being visited matches.
    code_block: BasicBlock
    # Block reached when the pattern currently being visited does not match.
    next_block: BasicBlock
    # Block every case body jumps to after it has run.
    final_block: BasicBlock
    # Value the current (sub)pattern is tested against.
    subject: Value
    # The match statement being compiled.
    match: MatchStmt

    # Enclosing "<pattern> as <name>" whose name still has to be bound.
    # bind_as_pattern() consumes it by resetting it to None.
    as_pattern: Optional[AsPattern] = None

    def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
        """Create the initial blocks and evaluate the subject of ``match_node``."""
        self.builder = builder

        self.code_block = BasicBlock()
        self.next_block = BasicBlock()
        self.final_block = BasicBlock()

        self.match = match_node
        self.subject = builder.accept(match_node.subject)

    def build_match_body(self, index: int) -> None:
        """Emit the guard (if present) and the body of the case at ``index``.

        The current success block is activated first. A failing guard branches
        to ``self.next_block``; after the body, control jumps to
        ``self.final_block``.
        """
        self.builder.activate_block(self.code_block)

        guard = self.match.guards[index]

        if guard:
            # The guard is evaluated in the block the pattern jumped to; the
            # body gets a block of its own that is only reached if it holds.
            self.code_block = BasicBlock()

            cond = self.builder.accept(guard)
            self.builder.add_bool_branch(cond, self.code_block, self.next_block)

            self.builder.activate_block(self.code_block)

        self.builder.accept(self.match.bodies[index])
        self.builder.goto(self.final_block)

    def visit_match_stmt(self, m: MatchStmt) -> None:
        """Emit every case of ``m`` in order, chaining them via failure blocks."""
        for i, pattern in enumerate(m.patterns):
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()

            pattern.accept(self)

            self.build_match_body(i)
            # The failure block of this case is where the next case starts.
            self.builder.activate_block(self.next_block)

        # Reached when no case matched: fall through to the end.
        self.builder.goto_and_activate(self.final_block)

    def visit_value_pattern(self, pattern: ValuePattern) -> None:
        """Match when the subject compares equal (``==``) to the pattern value."""
        value = self.builder.accept(pattern.expr)

        cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)

        # NOTE(review): this binds the evaluated pattern value rather than
        # self.subject, and does so before branching on cond -- confirm this
        # is the intended as-pattern behaviour.
        self.bind_as_pattern(value)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_or_pattern(self, pattern: OrPattern) -> None:
        """Try each alternative in turn; the first one that matches wins.

        Each alternative gets its own failure block, which is where the next
        alternative starts. When the last alternative fails, control goes to
        the failure block that was current on entry.
        """
        backup_block = self.next_block
        self.next_block = BasicBlock()

        for p in pattern.patterns:
            # Hack to ensure the as pattern is bound to each pattern in the
            # "or" pattern, but not every subpattern
            backup = self.as_pattern
            p.accept(self)
            self.as_pattern = backup

            self.builder.activate_block(self.next_block)
            self.next_block = BasicBlock()

        self.next_block = backup_block
        self.builder.goto(self.next_block)

    def visit_class_pattern(self, pattern: ClassPattern) -> None:
        """Match an instance check followed by positional and keyword subpatterns.

        For classes in ``MATCHABLE_BUILTINS`` the single positional subpattern
        is matched against the subject itself. For other classes, positional
        subpatterns are matched against the attributes named by the class's
        ``__match_args__``. Keyword subpatterns are matched against the
        attribute of the same name.
        """
        # TODO: use faster instance check for native classes (while still
        # making sure to account for inheritance)
        isinstance_op = (
            fast_isinstance_op
            if self.builder.is_builtin_ref_expr(pattern.class_ref)
            else slow_isinstance_op
        )

        cond = self.builder.call_c(
            isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
        )

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

        self.bind_as_pattern(self.subject, new_block=True)

        if pattern.positionals:
            if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

                # The subject is left unchanged here: the subpattern is
                # matched against the object itself, not one of its attributes.
                # No early return, so any keyword subpatterns are still
                # processed by the loop at the end of this method.
                pattern.positionals[0].accept(self)

            else:
                node = pattern.class_ref.node
                assert isinstance(node, TypeInfo)

                ty = node.names.get("__match_args__")
                assert ty

                match_args_type = get_proper_type(ty.type)
                assert isinstance(match_args_type, TupleType)

                # Attribute names, in positional order, read from the literal
                # string types that make up the __match_args__ tuple type.
                match_args: List[str] = []

                for item in match_args_type.items:
                    proper_item = get_proper_type(item)
                    assert isinstance(proper_item, Instance) and proper_item.last_known_value

                    match_arg = proper_item.last_known_value.value
                    assert isinstance(match_arg, str)

                    match_args.append(match_arg)

                for i, expr in enumerate(pattern.positionals):
                    self.builder.activate_block(self.code_block)
                    self.code_block = BasicBlock()

                    # TODO: use faster "get_attr" method instead when calling on native or
                    # builtin objects
                    positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)

                    with self.enter_subpattern(positional):
                        expr.accept(self)

        for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # TODO: same as above "get_attr" comment
            attr = self.builder.py_get_attr(self.subject, key, value.line)

            with self.enter_subpattern(attr):
                value.accept(self)

    def visit_as_pattern(self, pattern: AsPattern) -> None:
        """Handle ``<pattern> as <name>``, a bare capture, or a nameless pattern.

        With a subpattern, this as-pattern is recorded in ``self.as_pattern``
        while the subpattern is visited, so that it can bind the name through
        ``bind_as_pattern``. Without a subpattern but with a name, the subject
        is assigned to the name directly.
        """
        if pattern.pattern:
            old_pattern = self.as_pattern
            self.as_pattern = pattern
            pattern.pattern.accept(self)
            self.as_pattern = old_pattern

        elif pattern.name:
            target = self.builder.get_assignment_target(pattern.name)

            self.builder.assign(target, self.subject, pattern.line)

        self.builder.goto(self.code_block)

    def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
        """Match when the subject ``is`` the singleton ``None``, ``True`` or ``False``."""
        if pattern.value is None:
            obj = self.builder.none_object()
        elif pattern.value is True:
            obj = self.builder.true()
        else:
            obj = self.builder.false()

        cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
        """Match a mapping, its required keys and optionally capture the rest.

        For each key the generated code checks that the key is present, reads
        the item with ``__getitem__`` and matches the value subpattern against
        it. With a ``**rest`` capture, a copy of the subject is assigned to the
        target and every explicitly matched key is deleted from that copy.
        """
        is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)

        self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)

        # Evaluated key values, kept so they can be removed from the rest dict.
        keys: List[Value] = []

        for key, value in zip(pattern.keys, pattern.values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            key_value = self.builder.accept(key)
            keys.append(key_value)

            exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)

            self.builder.add_bool_branch(exists, self.code_block, self.next_block)
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            item = self.builder.gen_method_call(
                self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
            )

            with self.enter_subpattern(item):
                value.accept(self)

        if pattern.rest:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)

            target = self.builder.get_assignment_target(pattern.rest)

            self.builder.assign(target, rest, pattern.rest.line)

            for key_expr, key_name in zip(pattern.keys, keys):
                self.builder.call_c(dict_del_item, [rest, key_name], key_expr.line)

            self.builder.goto(self.code_block)

    def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
        """Match a sequence: protocol check, length check, then each item.

        Without a star subpattern the length must equal the number of item
        patterns; with one it must be at least that number. Items located at
        or after the star are indexed relative to the end of the sequence. A
        named star capture receives the slice between the leading and the
        trailing items.
        """
        star_index, capture, patterns = prep_sequence_pattern(seq_pattern)

        is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)

        self.builder.add_bool_branch(is_list, self.code_block, self.next_block)

        self.builder.activate_block(self.code_block)
        self.code_block = BasicBlock()

        actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
        min_len = len(patterns)

        is_long_enough = self.builder.binary_op(
            actual_len,
            self.builder.load_int(min_len),
            "==" if star_index is None else ">=",
            seq_pattern.line,
        )

        self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)

        for i, pattern in enumerate(patterns):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            if star_index is not None and i >= star_index:
                # Items after the star sit at a fixed distance from the end:
                # index = actual_len - (number of item patterns from i onwards).
                current = self.builder.binary_op(
                    actual_len, self.builder.load_int(min_len - i), "-", pattern.line
                )
            else:
                current = self.builder.load_int(i)

            item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)

            with self.enter_subpattern(item):
                pattern.accept(self)

        if capture and star_index is not None:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # The slice ends where the trailing (post-star) items begin.
            capture_end = self.builder.binary_op(
                actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
            )

            rest = self.builder.call_c(
                sequence_get_slice,
                [self.subject, self.builder.load_int(star_index), capture_end],
                capture.line,
            )

            target = self.builder.get_assignment_target(capture)
            self.builder.assign(target, rest, capture.line)

            self.builder.goto(self.code_block)

    def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
        """Bind ``value`` to the pending as-pattern name, if there is one.

        Does nothing unless ``self.as_pattern`` is set and has both a
        subpattern and a name. After binding, ``self.as_pattern`` is reset to
        ``None`` so nested patterns do not bind the name again.

        With ``new_block=True`` the assignment is emitted in the current
        success block, which is activated first, and control then continues in
        a fresh success block.

        NOTE(review): within this class only the value and class pattern
        visitors call this method -- confirm how as-patterns that wrap other
        pattern kinds are expected to be bound.
        """
        if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
            if new_block:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

            target = self.builder.get_assignment_target(self.as_pattern.name)
            self.builder.assign(target, value, self.as_pattern.pattern.line)

            self.as_pattern = None

            if new_block:
                self.builder.goto(self.code_block)

    @contextmanager
    def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
        """Temporarily make ``subject`` the value that patterns are matched against."""
        old_subject = self.subject
        self.subject = subject
        try:
            yield
        finally:
            # Restore even if visiting the subpattern raises, so the visitor
            # is never left pointing at a nested subject.
            self.subject = old_subject
|
||
|
|
||
|
|
||
|
def prep_sequence_pattern(
    seq_pattern: SequencePattern,
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
    """Separate the star subpattern of ``seq_pattern`` from its item patterns.

    Returns a ``(star_index, capture, patterns)`` tuple: the position of the
    starred subpattern within ``seq_pattern.patterns`` (``None`` if there is
    none), that subpattern's ``capture`` attribute, and the remaining
    subpatterns in their original order.
    """
    star_index: Optional[int] = None
    capture: Optional[NameExpr] = None
    patterns: List[Pattern] = []

    for position, subpattern in enumerate(seq_pattern.patterns):
        if not isinstance(subpattern, StarredPattern):
            patterns.append(subpattern)
            continue

        star_index, capture = position, subpattern.capture

    return star_index, capture, patterns
|