203 lines
7.9 KiB
Python
203 lines
7.9 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from mypy.nodes import (
|
||
|
Block,
|
||
|
Decorator,
|
||
|
Expression,
|
||
|
FuncDef,
|
||
|
FuncItem,
|
||
|
Import,
|
||
|
LambdaExpr,
|
||
|
MemberExpr,
|
||
|
MypyFile,
|
||
|
NameExpr,
|
||
|
Node,
|
||
|
SymbolNode,
|
||
|
Var,
|
||
|
)
|
||
|
from mypy.traverser import ExtendedTraverserVisitor
|
||
|
from mypyc.errors import Errors
|
||
|
|
||
|
|
||
|
class PreBuildVisitor(ExtendedTraverserVisitor):
    """Mypy file AST visitor run before building the IR.

    It collects information that the main IR build pass uses, including:

    * relationships between nested functions and the functions that
      contain them
    * non-local variables (free variables)
    * property setters
    * the (non-special) decorators of functions
    * module import groups
    """

    def __init__(
        self,
        errors: Errors,
        current_file: MypyFile,
        decorators_to_remove: dict[FuncDef, list[int]],
    ) -> None:
        """Create a visitor with empty analysis state.

        Args:
            errors: Error reporter, stored as ``self.errors``.
            current_file: The module being analyzed, stored as ``self.current_file``.
            decorators_to_remove: Map from a function to the indices (into its
                decorator list) of decorators that must not be recorded in
                ``funcs_to_decorators``.
        """
        super().__init__()
        # Dict from a function to symbols defined directly in the
        # function that are used as non-local (free) variables within a
        # nested function.
        self.free_variables: dict[FuncItem, set[SymbolNode]] = {}

        # Intermediate data structure used to find the function where
        # a SymbolNode is declared. Initially this may point to a
        # function nested inside the function with the declaration,
        # but we'll eventually update this to refer to the function
        # with the declaration.
        self.symbols_to_funcs: dict[SymbolNode, FuncItem] = {}

        # Stack representing current function nesting.
        self.funcs: list[FuncItem] = []

        # All property setters encountered so far.
        self.prop_setters: set[FuncDef] = set()

        # A map from any function that contains nested functions to
        # the list of functions nested *directly* within it, in visit order.
        self.encapsulating_funcs: dict[FuncItem, list[FuncItem]] = {}

        # Map nested function to its parent/encapsulating function.
        self.nested_funcs: dict[FuncItem, FuncItem] = {}

        # Map function to its non-special decorators.
        self.funcs_to_decorators: dict[FuncDef, list[Expression]] = {}

        # Map function to indices of decorators to remove
        self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove

        # A mapping of import groups (a series of Import nodes with
        # nothing in between) where each group is keyed by its first
        # import node.
        self.module_import_groups: dict[Import, list[Import]] = {}
        # First Import of the group currently being extended, or None when
        # the next Import should start a new group.
        self._current_import_group: Import | None = None

        self.errors: Errors = errors

        self.current_file: MypyFile = current_file

    def visit(self, o: Node) -> bool:
        """Pre-visit hook: end the current import group on any non-Import node.

        Always returns True.
        """
        # NOTE(review): assumes ExtendedTraverserVisitor invokes this hook
        # before descending into a node and that True means "keep
        # traversing" -- confirm against mypy.traverser.
        if not isinstance(o, Import):
            self._current_import_group = None
        return True

    def visit_block(self, block: Block) -> None:
        """Traverse a block, ensuring import groups never span block boundaries."""
        # Reset on entry and on exit so imports inside and outside the block
        # (or in two sibling blocks) never end up in the same group.
        self._current_import_group = None
        super().visit_block(block)
        self._current_import_group = None

    def visit_decorator(self, dec: Decorator) -> None:
        """Record property setters and ordinary decorators, then traverse ``dec``.

        Decorators whose indices are listed in ``decorators_to_remove`` for
        ``dec.func`` are dropped. If nothing is left, the function is not
        treated as decorated. Its body is still traversed, so nested
        functions and free variables inside it are recorded.
        """
        if dec.decorators:
            # Only add the function being decorated if there exist
            # (ordinary) decorators in the decorator list. Certain
            # decorators (such as @property, @abstractmethod) are
            # special cased and removed from this list by
            # mypy. Functions decorated only by special decorators
            # (and property setters) are not treated as decorated
            # functions by the IR builder.
            first = dec.decorators[0]
            if isinstance(first, MemberExpr) and first.name == "setter":
                # Property setters are not treated as decorated methods.
                self.prop_setters.add(dec.func)
            else:
                # Filter by index set rather than deleting in place, so the
                # result does not depend on the order of the index list.
                removed = set(self.decorators_to_remove.get(dec.func, ()))
                decorators_to_store = [
                    d for i, d in enumerate(dec.decorators) if i not in removed
                ]
                # If all of the decorators are removed, there is nothing to
                # apply, so this is not treated as a decorated function. We
                # still fall through to the traversal below: returning early
                # here would skip the function body entirely.
                if decorators_to_store:
                    self.funcs_to_decorators[dec.func] = decorators_to_store
        super().visit_decorator(dec)

    def visit_func_def(self, fdef: FuncItem) -> None:
        """Handle a function definition via the shared ``visit_func`` logic."""
        # TODO: What about overloaded functions?
        self.visit_func(fdef)

    def visit_lambda_expr(self, expr: LambdaExpr) -> None:
        """Handle a lambda via the shared ``visit_func`` logic."""
        self.visit_func(expr)

    def visit_func(self, func: FuncItem) -> None:
        """Record ``func``'s nesting relationship and traverse it on the function stack."""
        if self.funcs:
            # There is an enclosing function or lambda: ``func`` is nested
            # directly within the innermost one. Record the link both ways.
            parent = self.funcs[-1]
            self.encapsulating_funcs.setdefault(parent, []).append(func)
            self.nested_funcs[func] = parent

        self.funcs.append(func)
        super().visit_func(func)
        self.funcs.pop()

    def visit_import(self, imp: Import) -> None:
        """Add ``imp`` to the current import group, or start a new group keyed by it."""
        if self._current_import_group is not None:
            self.module_import_groups[self._current_import_group].append(imp)
        else:
            self.module_import_groups[imp] = [imp]
            self._current_import_group = imp
        super().visit_import(imp)

    def visit_name_expr(self, expr: NameExpr) -> None:
        """Track references to variables and functions for free-variable analysis."""
        if isinstance(expr.node, (Var, FuncDef)):
            self.visit_symbol_node(expr.node)

    def visit_var(self, var: Var) -> None:
        """Track a variable node for free-variable analysis."""
        self.visit_symbol_node(var)

    def visit_symbol_node(self, symbol: SymbolNode) -> None:
        """Update the believed owner of ``symbol`` and detect non-local uses.

        ``symbols_to_funcs`` maps each symbol to the function it is believed
        to be declared in. That is initially the first function it is seen
        in, and it is moved outward when the symbol is later seen in an
        enclosing function.
        """
        if not self.funcs:
            # We are not inside a function and hence do not need to do
            # anything regarding free variables.
            return

        current = self.funcs[-1]
        orig_func = self.symbols_to_funcs.get(symbol)

        if orig_func is None:
            # First sighting: note the function where it was first visited.
            self.symbols_to_funcs[symbol] = current
        elif self.is_parent(current, orig_func):
            # The function in which the symbol was previously seen is
            # nested within the function currently being visited. Thus
            # the current function is a better candidate to contain the
            # declaration, and the earlier (nested) use was non-local.
            self.symbols_to_funcs[symbol] = current
            # TODO: Remove from the orig_func free_variables set?
            self.free_variables.setdefault(current, set()).add(symbol)
        elif self.is_parent(orig_func, current):
            # The SymbolNode has already been seen in an enclosing
            # function, so this use is non-local.
            self.add_free_variable(symbol)

    def is_parent(self, fitem: FuncItem, child: FuncItem) -> bool:
        """Return True if ``child`` is nested within ``fitem``, directly or indirectly.

        A function is not considered its own parent.
        """
        # Walk up the parent chain recorded in nested_funcs.
        current = child
        while current in self.nested_funcs:
            current = self.nested_funcs[current]
            if current == fitem:
                return True
        return False

    def add_free_variable(self, symbol: SymbolNode) -> None:
        """Mark ``symbol`` as a free variable of the function believed to declare it."""
        # Find the function where the symbol was (likely) first declared,
        # and mark it as a non-local symbol within that function.
        func = self.symbols_to_funcs[symbol]
        self.free_variables.setdefault(func, set()).add(symbol)