from __future__ import annotations from mypy.nodes import ( Block, Decorator, Expression, FuncDef, FuncItem, Import, LambdaExpr, MemberExpr, MypyFile, NameExpr, Node, SymbolNode, Var, ) from mypy.traverser import ExtendedTraverserVisitor from mypyc.errors import Errors class PreBuildVisitor(ExtendedTraverserVisitor): """Mypy file AST visitor run before building the IR. This collects various things, including: * Determine relationships between nested functions and functions that contain nested functions * Find non-local variables (free variables) * Find property setters * Find decorators of functions * Find module import groups The main IR build pass uses this information. """ def __init__( self, errors: Errors, current_file: MypyFile, decorators_to_remove: dict[FuncDef, list[int]], ) -> None: super().__init__() # Dict from a function to symbols defined directly in the # function that are used as non-local (free) variables within a # nested function. self.free_variables: dict[FuncItem, set[SymbolNode]] = {} # Intermediate data structure used to find the function where # a SymbolNode is declared. Initially this may point to a # function nested inside the function with the declaration, # but we'll eventually update this to refer to the function # with the declaration. self.symbols_to_funcs: dict[SymbolNode, FuncItem] = {} # Stack representing current function nesting. self.funcs: list[FuncItem] = [] # All property setters encountered so far. self.prop_setters: set[FuncDef] = set() # A map from any function that contains nested functions to # a set of all the functions that are nested within it. self.encapsulating_funcs: dict[FuncItem, list[FuncItem]] = {} # Map nested function to its parent/encapsulating function. self.nested_funcs: dict[FuncItem, FuncItem] = {} # Map function to its non-special decorators. self.funcs_to_decorators: dict[FuncDef, list[Expression]] = {} # Map function to indices of decorators to remove self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove # A mapping of import groups (a series of Import nodes with # nothing inbetween) where each group is keyed by its first # import node. self.module_import_groups: dict[Import, list[Import]] = {} self._current_import_group: Import | None = None self.errors: Errors = errors self.current_file: MypyFile = current_file def visit(self, o: Node) -> bool: if not isinstance(o, Import): self._current_import_group = None return True def visit_block(self, block: Block) -> None: self._current_import_group = None super().visit_block(block) self._current_import_group = None def visit_decorator(self, dec: Decorator) -> None: if dec.decorators: # Only add the function being decorated if there exist # (ordinary) decorators in the decorator list. Certain # decorators (such as @property, @abstractmethod) are # special cased and removed from this list by # mypy. Functions decorated only by special decorators # (and property setters) are not treated as decorated # functions by the IR builder. if isinstance(dec.decorators[0], MemberExpr) and dec.decorators[0].name == "setter": # Property setters are not treated as decorated methods. self.prop_setters.add(dec.func) else: decorators_to_store = dec.decorators.copy() if dec.func in self.decorators_to_remove: to_remove = self.decorators_to_remove[dec.func] for i in reversed(to_remove): del decorators_to_store[i] # if all of the decorators are removed, we shouldn't treat this as a decorated # function because there aren't any decorators to apply if not decorators_to_store: return self.funcs_to_decorators[dec.func] = decorators_to_store super().visit_decorator(dec) def visit_func_def(self, fdef: FuncItem) -> None: # TODO: What about overloaded functions? self.visit_func(fdef) def visit_lambda_expr(self, expr: LambdaExpr) -> None: self.visit_func(expr) def visit_func(self, func: FuncItem) -> None: # If there were already functions or lambda expressions # defined in the function stack, then note the previous # FuncItem as containing a nested function and the current # FuncItem as being a nested function. if self.funcs: # Add the new func to the set of nested funcs within the # func at top of the func stack. self.encapsulating_funcs.setdefault(self.funcs[-1], []).append(func) # Add the func at top of the func stack as the parent of # new func. self.nested_funcs[func] = self.funcs[-1] self.funcs.append(func) super().visit_func(func) self.funcs.pop() def visit_import(self, imp: Import) -> None: if self._current_import_group is not None: self.module_import_groups[self._current_import_group].append(imp) else: self.module_import_groups[imp] = [imp] self._current_import_group = imp super().visit_import(imp) def visit_name_expr(self, expr: NameExpr) -> None: if isinstance(expr.node, (Var, FuncDef)): self.visit_symbol_node(expr.node) def visit_var(self, var: Var) -> None: self.visit_symbol_node(var) def visit_symbol_node(self, symbol: SymbolNode) -> None: if not self.funcs: # We are not inside a function and hence do not need to do # anything regarding free variables. return if symbol in self.symbols_to_funcs: orig_func = self.symbols_to_funcs[symbol] if self.is_parent(self.funcs[-1], orig_func): # The function in which the symbol was previously seen is # nested within the function currently being visited. Thus # the current function is a better candidate to contain the # declaration. self.symbols_to_funcs[symbol] = self.funcs[-1] # TODO: Remove from the orig_func free_variables set? self.free_variables.setdefault(self.funcs[-1], set()).add(symbol) elif self.is_parent(orig_func, self.funcs[-1]): # The SymbolNode instance has already been visited # before in a parent function, thus it's a non-local # symbol. self.add_free_variable(symbol) else: # This is the first time the SymbolNode is being # visited. We map the SymbolNode to the current FuncDef # being visited to note where it was first visited. self.symbols_to_funcs[symbol] = self.funcs[-1] def is_parent(self, fitem: FuncItem, child: FuncItem) -> bool: # Check if child is nested within fdef (possibly indirectly # within multiple nested functions). if child not in self.nested_funcs: return False parent = self.nested_funcs[child] return parent == fitem or self.is_parent(fitem, parent) def add_free_variable(self, symbol: SymbolNode) -> None: # Find the function where the symbol was (likely) first declared, # and mark is as a non-local symbol within that function. func = self.symbols_to_funcs[symbol] self.free_variables.setdefault(func, set()).add(symbol)