from __future__ import annotations from enum import Enum from mypy import checker, errorcodes from mypy.messages import MessageBuilder from mypy.nodes import ( AssertStmt, AssignmentExpr, AssignmentStmt, BreakStmt, ClassDef, Context, ContinueStmt, DictionaryComprehension, Expression, ExpressionStmt, ForStmt, FuncDef, FuncItem, GeneratorExpr, GlobalDecl, IfStmt, Import, ImportFrom, LambdaExpr, ListExpr, Lvalue, MatchStmt, MypyFile, NameExpr, NonlocalDecl, RaiseStmt, ReturnStmt, StarExpr, SymbolTable, TryStmt, TupleExpr, WhileStmt, WithStmt, implicit_module_attrs, ) from mypy.options import Options from mypy.patterns import AsPattern, StarredPattern from mypy.reachability import ALWAYS_TRUE, infer_pattern_value from mypy.traverser import ExtendedTraverserVisitor from mypy.types import Type, UninhabitedType class BranchState: """BranchState contains information about variable definition at the end of a branching statement. `if` and `match` are examples of branching statements. `may_be_defined` contains variables that were defined in only some branches. `must_be_defined` contains variables that were defined in all branches. """ def __init__( self, must_be_defined: set[str] | None = None, may_be_defined: set[str] | None = None, skipped: bool = False, ) -> None: if may_be_defined is None: may_be_defined = set() if must_be_defined is None: must_be_defined = set() self.may_be_defined = set(may_be_defined) self.must_be_defined = set(must_be_defined) self.skipped = skipped def copy(self) -> BranchState: return BranchState( must_be_defined=set(self.must_be_defined), may_be_defined=set(self.may_be_defined), skipped=self.skipped, ) class BranchStatement: def __init__(self, initial_state: BranchState | None = None) -> None: if initial_state is None: initial_state = BranchState() self.initial_state = initial_state self.branches: list[BranchState] = [ BranchState( must_be_defined=self.initial_state.must_be_defined, may_be_defined=self.initial_state.may_be_defined, ) ] def copy(self) -> BranchStatement: result = BranchStatement(self.initial_state) result.branches = [b.copy() for b in self.branches] return result def next_branch(self) -> None: self.branches.append( BranchState( must_be_defined=self.initial_state.must_be_defined, may_be_defined=self.initial_state.may_be_defined, ) ) def record_definition(self, name: str) -> None: assert len(self.branches) > 0 self.branches[-1].must_be_defined.add(name) self.branches[-1].may_be_defined.discard(name) def delete_var(self, name: str) -> None: assert len(self.branches) > 0 self.branches[-1].must_be_defined.discard(name) self.branches[-1].may_be_defined.discard(name) def record_nested_branch(self, state: BranchState) -> None: assert len(self.branches) > 0 current_branch = self.branches[-1] if state.skipped: current_branch.skipped = True return current_branch.must_be_defined.update(state.must_be_defined) current_branch.may_be_defined.update(state.may_be_defined) current_branch.may_be_defined.difference_update(current_branch.must_be_defined) def skip_branch(self) -> None: assert len(self.branches) > 0 self.branches[-1].skipped = True def is_possibly_undefined(self, name: str) -> bool: assert len(self.branches) > 0 return name in self.branches[-1].may_be_defined def is_undefined(self, name: str) -> bool: assert len(self.branches) > 0 branch = self.branches[-1] return name not in branch.may_be_defined and name not in branch.must_be_defined def is_defined_in_a_branch(self, name: str) -> bool: assert len(self.branches) > 0 for b in self.branches: if name in b.must_be_defined or name in b.may_be_defined: return True return False def done(self) -> BranchState: # First, compute all vars, including skipped branches. We include skipped branches # because our goal is to capture all variables that semantic analyzer would # consider defined. all_vars = set() for b in self.branches: all_vars.update(b.may_be_defined) all_vars.update(b.must_be_defined) # For the rest of the things, we only care about branches that weren't skipped. non_skipped_branches = [b for b in self.branches if not b.skipped] if non_skipped_branches: must_be_defined = non_skipped_branches[0].must_be_defined for b in non_skipped_branches[1:]: must_be_defined.intersection_update(b.must_be_defined) else: must_be_defined = set() # Everything that wasn't defined in all branches but was defined # in at least one branch should be in `may_be_defined`! may_be_defined = all_vars.difference(must_be_defined) return BranchState( must_be_defined=must_be_defined, may_be_defined=may_be_defined, skipped=len(non_skipped_branches) == 0, ) class ScopeType(Enum): Global = 1 Class = 2 Func = 3 Generator = 4 class Scope: def __init__(self, stmts: list[BranchStatement], scope_type: ScopeType) -> None: self.branch_stmts: list[BranchStatement] = stmts self.scope_type = scope_type self.undefined_refs: dict[str, set[NameExpr]] = {} def copy(self) -> Scope: result = Scope([s.copy() for s in self.branch_stmts], self.scope_type) result.undefined_refs = self.undefined_refs.copy() return result def record_undefined_ref(self, o: NameExpr) -> None: if o.name not in self.undefined_refs: self.undefined_refs[o.name] = set() self.undefined_refs[o.name].add(o) def pop_undefined_ref(self, name: str) -> set[NameExpr]: return self.undefined_refs.pop(name, set()) class DefinedVariableTracker: """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor.""" def __init__(self) -> None: # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement. self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)] # disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful # in things like try/except/finally statements. self.disable_branch_skip = False def copy(self) -> DefinedVariableTracker: result = DefinedVariableTracker() result.scopes = [s.copy() for s in self.scopes] result.disable_branch_skip = self.disable_branch_skip return result def _scope(self) -> Scope: assert len(self.scopes) > 0 return self.scopes[-1] def enter_scope(self, scope_type: ScopeType) -> None: assert len(self._scope().branch_stmts) > 0 initial_state = None if scope_type == ScopeType.Generator: # Generators are special because they inherit the outer scope. initial_state = self._scope().branch_stmts[-1].branches[-1] self.scopes.append(Scope([BranchStatement(initial_state)], scope_type)) def exit_scope(self) -> None: self.scopes.pop() def in_scope(self, scope_type: ScopeType) -> bool: return self._scope().scope_type == scope_type def start_branch_statement(self) -> None: assert len(self._scope().branch_stmts) > 0 self._scope().branch_stmts.append( BranchStatement(self._scope().branch_stmts[-1].branches[-1]) ) def next_branch(self) -> None: assert len(self._scope().branch_stmts) > 1 self._scope().branch_stmts[-1].next_branch() def end_branch_statement(self) -> None: assert len(self._scope().branch_stmts) > 1 result = self._scope().branch_stmts.pop().done() self._scope().branch_stmts[-1].record_nested_branch(result) def skip_branch(self) -> None: # Only skip branch if we're outside of "root" branch statement. if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip: self._scope().branch_stmts[-1].skip_branch() def record_definition(self, name: str) -> None: assert len(self.scopes) > 0 assert len(self.scopes[-1].branch_stmts) > 0 self._scope().branch_stmts[-1].record_definition(name) def delete_var(self, name: str) -> None: assert len(self.scopes) > 0 assert len(self.scopes[-1].branch_stmts) > 0 self._scope().branch_stmts[-1].delete_var(name) def record_undefined_ref(self, o: NameExpr) -> None: """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`.""" assert len(self.scopes) > 0 self._scope().record_undefined_ref(o) def pop_undefined_ref(self, name: str) -> set[NameExpr]: """If name has previously been reported as undefined, the NameExpr that was called will be returned.""" assert len(self.scopes) > 0 return self._scope().pop_undefined_ref(name) def is_possibly_undefined(self, name: str) -> bool: assert len(self._scope().branch_stmts) > 0 # A variable is undefined if it's in a set of `may_be_defined` but not in `must_be_defined`. return self._scope().branch_stmts[-1].is_possibly_undefined(name) def is_defined_in_different_branch(self, name: str) -> bool: """This will return true if a variable is defined in a branch that's not the current branch.""" assert len(self._scope().branch_stmts) > 0 stmt = self._scope().branch_stmts[-1] if not stmt.is_undefined(name): return False for stmt in self._scope().branch_stmts: if stmt.is_defined_in_a_branch(name): return True return False def is_undefined(self, name: str) -> bool: assert len(self._scope().branch_stmts) > 0 return self._scope().branch_stmts[-1].is_undefined(name) class Loop: def __init__(self) -> None: self.has_break = False class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor): """Detects the following cases: - A variable that's defined only part of the time. - If a variable is used before definition An example of a partial definition: if foo(): x = 1 print(x) # Error: "x" may be undefined. Example of a used before definition: x = y y: int = 2 Note that this code does not detect variables not defined in any of the branches -- that is handled by the semantic analyzer. """ def __init__( self, msg: MessageBuilder, type_map: dict[Expression, Type], options: Options, names: SymbolTable, ) -> None: self.msg = msg self.type_map = type_map self.options = options self.builtins = SymbolTable() builtins_mod = names.get("__builtins__", None) if builtins_mod: assert isinstance(builtins_mod.node, MypyFile) self.builtins = builtins_mod.node.names self.loops: list[Loop] = [] self.try_depth = 0 self.tracker = DefinedVariableTracker() for name in implicit_module_attrs: self.tracker.record_definition(name) def var_used_before_def(self, name: str, context: Context) -> None: if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF): self.msg.var_used_before_def(name, context) def variable_may_be_undefined(self, name: str, context: Context) -> None: if self.msg.errors.is_error_code_enabled(errorcodes.POSSIBLY_UNDEFINED): self.msg.variable_may_be_undefined(name, context) def process_definition(self, name: str) -> None: # Was this name previously used? If yes, it's a used-before-definition error. if not self.tracker.in_scope(ScopeType.Class): refs = self.tracker.pop_undefined_ref(name) for ref in refs: if self.loops: self.variable_may_be_undefined(name, ref) else: self.var_used_before_def(name, ref) else: # Errors in class scopes are caught by the semantic analyzer. pass self.tracker.record_definition(name) def visit_global_decl(self, o: GlobalDecl) -> None: for name in o.names: self.process_definition(name) super().visit_global_decl(o) def visit_nonlocal_decl(self, o: NonlocalDecl) -> None: for name in o.names: self.process_definition(name) super().visit_nonlocal_decl(o) def process_lvalue(self, lvalue: Lvalue | None) -> None: if isinstance(lvalue, NameExpr): self.process_definition(lvalue.name) elif isinstance(lvalue, StarExpr): self.process_lvalue(lvalue.expr) elif isinstance(lvalue, (ListExpr, TupleExpr)): for item in lvalue.items: self.process_lvalue(item) def visit_assignment_stmt(self, o: AssignmentStmt) -> None: for lvalue in o.lvalues: self.process_lvalue(lvalue) super().visit_assignment_stmt(o) def visit_assignment_expr(self, o: AssignmentExpr) -> None: o.value.accept(self) self.process_lvalue(o.target) def visit_if_stmt(self, o: IfStmt) -> None: for e in o.expr: e.accept(self) self.tracker.start_branch_statement() for b in o.body: if b.is_unreachable: continue b.accept(self) self.tracker.next_branch() if o.else_body: if not o.else_body.is_unreachable: o.else_body.accept(self) else: self.tracker.skip_branch() self.tracker.end_branch_statement() def visit_match_stmt(self, o: MatchStmt) -> None: o.subject.accept(self) self.tracker.start_branch_statement() for i in range(len(o.patterns)): pattern = o.patterns[i] pattern.accept(self) guard = o.guards[i] if guard is not None: guard.accept(self) if not o.bodies[i].is_unreachable: o.bodies[i].accept(self) else: self.tracker.skip_branch() is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE if not is_catchall: self.tracker.next_branch() self.tracker.end_branch_statement() def visit_func_def(self, o: FuncDef) -> None: self.process_definition(o.name) super().visit_func_def(o) def visit_func(self, o: FuncItem) -> None: if o.is_dynamic() and not self.options.check_untyped_defs: return args = o.arguments or [] # Process initializers (defaults) outside the function scope. for arg in args: if arg.initializer is not None: arg.initializer.accept(self) self.tracker.enter_scope(ScopeType.Func) for arg in args: self.process_definition(arg.variable.name) super().visit_var(arg.variable) o.body.accept(self) self.tracker.exit_scope() def visit_generator_expr(self, o: GeneratorExpr) -> None: self.tracker.enter_scope(ScopeType.Generator) for idx in o.indices: self.process_lvalue(idx) super().visit_generator_expr(o) self.tracker.exit_scope() def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: self.tracker.enter_scope(ScopeType.Generator) for idx in o.indices: self.process_lvalue(idx) super().visit_dictionary_comprehension(o) self.tracker.exit_scope() def visit_for_stmt(self, o: ForStmt) -> None: o.expr.accept(self) self.process_lvalue(o.index) o.index.accept(self) self.tracker.start_branch_statement() loop = Loop() self.loops.append(loop) o.body.accept(self) self.tracker.next_branch() self.tracker.end_branch_statement() if o.else_body is not None: # If the loop has a `break` inside, `else` is executed conditionally. # If the loop doesn't have a `break` either the function will return or # execute the `else`. has_break = loop.has_break if has_break: self.tracker.start_branch_statement() self.tracker.next_branch() o.else_body.accept(self) if has_break: self.tracker.end_branch_statement() self.loops.pop() def visit_return_stmt(self, o: ReturnStmt) -> None: super().visit_return_stmt(o) self.tracker.skip_branch() def visit_lambda_expr(self, o: LambdaExpr) -> None: self.tracker.enter_scope(ScopeType.Func) super().visit_lambda_expr(o) self.tracker.exit_scope() def visit_assert_stmt(self, o: AssertStmt) -> None: super().visit_assert_stmt(o) if checker.is_false_literal(o.expr): self.tracker.skip_branch() def visit_raise_stmt(self, o: RaiseStmt) -> None: super().visit_raise_stmt(o) self.tracker.skip_branch() def visit_continue_stmt(self, o: ContinueStmt) -> None: super().visit_continue_stmt(o) self.tracker.skip_branch() def visit_break_stmt(self, o: BreakStmt) -> None: super().visit_break_stmt(o) if self.loops: self.loops[-1].has_break = True self.tracker.skip_branch() def visit_expression_stmt(self, o: ExpressionStmt) -> None: if isinstance(self.type_map.get(o.expr, None), UninhabitedType): self.tracker.skip_branch() super().visit_expression_stmt(o) def visit_try_stmt(self, o: TryStmt) -> None: """ Note that finding undefined vars in `finally` requires different handling from the rest of the code. In particular, we want to disallow skipping branches due to jump statements in except/else clauses for finally but not for other cases. Imagine a case like: def f() -> int: try: x = 1 except: # This jump statement needs to be handled differently depending on whether or # not we're trying to process `finally` or not. return 0 finally: # `x` may be undefined here. pass # `x` is always defined here. return x """ self.try_depth += 1 if o.finally_body is not None: # In order to find undefined vars in `finally`, we need to # process try/except with branch skipping disabled. However, for the rest of the code # after finally, we need to process try/except with branch skipping enabled. # Therefore, we need to process try/finally twice. # Because processing is not idempotent, we should make a copy of the tracker. old_tracker = self.tracker.copy() self.tracker.disable_branch_skip = True self.process_try_stmt(o) self.tracker = old_tracker self.process_try_stmt(o) self.try_depth -= 1 def process_try_stmt(self, o: TryStmt) -> None: """ Processes try statement decomposing it into the following: if ...: body else_body elif ...: except 1 elif ...: except 2 else: except n finally """ self.tracker.start_branch_statement() o.body.accept(self) if o.else_body is not None: o.else_body.accept(self) if len(o.handlers) > 0: assert len(o.handlers) == len(o.vars) == len(o.types) for i in range(len(o.handlers)): self.tracker.next_branch() exc_type = o.types[i] if exc_type is not None: exc_type.accept(self) var = o.vars[i] if var is not None: self.process_definition(var.name) var.accept(self) o.handlers[i].accept(self) if var is not None: self.tracker.delete_var(var.name) self.tracker.end_branch_statement() if o.finally_body is not None: o.finally_body.accept(self) def visit_while_stmt(self, o: WhileStmt) -> None: o.expr.accept(self) self.tracker.start_branch_statement() loop = Loop() self.loops.append(loop) o.body.accept(self) has_break = loop.has_break if not checker.is_true_literal(o.expr): # If this is a loop like `while True`, we can consider the body to be # a single branch statement (we're guaranteed that the body is executed at least once). # If not, call next_branch() to make all variables defined there conditional. self.tracker.next_branch() self.tracker.end_branch_statement() if o.else_body is not None: # If the loop has a `break` inside, `else` is executed conditionally. # If the loop doesn't have a `break` either the function will return or # execute the `else`. if has_break: self.tracker.start_branch_statement() self.tracker.next_branch() if o.else_body: o.else_body.accept(self) if has_break: self.tracker.end_branch_statement() self.loops.pop() def visit_as_pattern(self, o: AsPattern) -> None: if o.name is not None: self.process_lvalue(o.name) super().visit_as_pattern(o) def visit_starred_pattern(self, o: StarredPattern) -> None: if o.capture is not None: self.process_lvalue(o.capture) super().visit_starred_pattern(o) def visit_name_expr(self, o: NameExpr) -> None: if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global): return if self.tracker.is_possibly_undefined(o.name): # A variable is only defined in some branches. self.variable_may_be_undefined(o.name, o) # We don't want to report the error on the same variable multiple times. self.tracker.record_definition(o.name) elif self.tracker.is_defined_in_different_branch(o.name): # A variable is defined in one branch but used in a different branch. if self.loops or self.try_depth > 0: # If we're in a loop or in a try, we can't be sure that this variable # is undefined. Report it as "may be undefined". self.variable_may_be_undefined(o.name, o) else: self.var_used_before_def(o.name, o) elif self.tracker.is_undefined(o.name): # A variable is undefined. It could be due to two things: # 1. A variable is just totally undefined # 2. The variable is defined later in the code. # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should # be caught by this visitor. Save the ref for later, so that if we see a definition, # we know it's a used-before-definition scenario. self.tracker.record_undefined_ref(o) super().visit_name_expr(o) def visit_with_stmt(self, o: WithStmt) -> None: for expr, idx in zip(o.expr, o.target): expr.accept(self) self.process_lvalue(idx) o.body.accept(self) def visit_class_def(self, o: ClassDef) -> None: self.process_definition(o.name) self.tracker.enter_scope(ScopeType.Class) super().visit_class_def(o) self.tracker.exit_scope() def visit_import(self, o: Import) -> None: for mod, alias in o.ids: if alias is not None: self.tracker.record_definition(alias) else: # When you do `import x.y`, only `x` becomes defined. names = mod.split(".") if names: # `names` should always be nonempty, but we don't want mypy # to crash on invalid code. self.tracker.record_definition(names[0]) super().visit_import(o) def visit_import_from(self, o: ImportFrom) -> None: for mod, alias in o.names: name = alias if name is None: name = mod self.tracker.record_definition(name) super().visit_import_from(o)