"""Fix up various things after deserialization.

Deserialized trees still contain string references instead of node objects:
``cross_ref`` on symbol table nodes, ``type_ref`` on ``Instance`` and
``TypeAliasType``, and ``_mro_refs`` on ``TypeInfo``. The visitors in this
module resolve those references against the loaded ``modules`` in place.
"""

from __future__ import annotations

from typing import Any, Final

from mypy.lookup import lookup_fully_qualified
from mypy.nodes import (
    Block,
    ClassDef,
    Decorator,
    FuncDef,
    MypyFile,
    OverloadedFuncDef,
    ParamSpecExpr,
    SymbolTable,
    TypeAlias,
    TypeInfo,
    TypeVarExpr,
    TypeVarTupleExpr,
    Var,
)
from mypy.types import (
    NOT_READY,
    AnyType,
    CallableType,
    Instance,
    LiteralType,
    Overloaded,
    Parameters,
    ParamSpecType,
    TupleType,
    TypeAliasType,
    TypedDictType,
    TypeOfAny,
    TypeType,
    TypeVarTupleType,
    TypeVarType,
    TypeVisitor,
    UnboundType,
    UnionType,
    UnpackType,
)
from mypy.visitor import NodeVisitor


# N.B: we do an allow_missing fixup when fixing up a fine-grained
# incremental cache load (since there may be cross-refs into deleted
# modules)
def fixup_module(tree: MypyFile, modules: dict[str, MypyFile], allow_missing: bool) -> None:
    """Resolve all deserialization cross-references reachable from *tree*, in place.

    Args:
        tree: the deserialized module whose symbol table is fixed up.
        modules: loaded modules, keyed by the module names that cross-references use.
        allow_missing: if true, references that cannot be resolved are replaced
            with placeholder nodes (see ``missing_info`` / ``missing_alias``)
            instead of failing an assertion.
    """
    node_fixer = NodeFixer(modules, allow_missing)
    node_fixer.visit_symbol_table(tree.names, tree.fullname)


# TODO: Fix up .info when deserializing, i.e. much earlier.
class NodeFixer(NodeVisitor[None]):
    """Walk deserialized symbol nodes, resolving cross-references and their types.

    While visiting a class, the ``info`` of contained functions, overloads,
    decorators and variables is set to that class's ``TypeInfo``.
    """

    # The TypeInfo of the class whose body is currently being visited, if any.
    current_info: TypeInfo | None = None

    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        self.modules = modules
        self.allow_missing = allow_missing
        self.type_fixer = TypeFixer(self.modules, allow_missing)

    # NOTE: This method isn't (yet) part of the NodeVisitor API.
    def visit_type_info(self, info: TypeInfo) -> None:
        """Fix up a class: its definition, members, bases, special types and MRO."""
        save_info = self.current_info
        try:
            self.current_info = info
            if info.defn:
                info.defn.accept(self)
            if info.names:
                self.visit_symbol_table(info.names, info.fullname)
            if info.bases:
                for base in info.bases:
                    base.accept(self.type_fixer)
            if info._promote:
                for p in info._promote:
                    p.accept(self.type_fixer)
            if info.tuple_type:
                info.tuple_type.accept(self.type_fixer)
                info.update_tuple_type(info.tuple_type)
                if info.special_alias:
                    info.special_alias.alias_tvars = list(info.defn.type_vars)
            if info.typeddict_type:
                info.typeddict_type.accept(self.type_fixer)
                info.update_typeddict_type(info.typeddict_type)
                if info.special_alias:
                    info.special_alias.alias_tvars = list(info.defn.type_vars)
            if info.declared_metaclass:
                info.declared_metaclass.accept(self.type_fixer)
            if info.metaclass_type:
                info.metaclass_type.accept(self.type_fixer)
            if info.alt_promote:
                info.alt_promote.accept(self.type_fixer)
                instance = Instance(info, [])
                # Hack: We may also need to add a backwards promotion (from int to native int),
                # since it might not be serialized.
                if instance not in info.alt_promote.type._promote:
                    info.alt_promote.type._promote.append(instance)
            if info._mro_refs:
                # The MRO is serialized as a list of fully qualified names.
                info.mro = [
                    lookup_fully_qualified_typeinfo(
                        self.modules, name, allow_missing=self.allow_missing
                    )
                    for name in info._mro_refs
                ]
                info._mro_refs = None
        finally:
            # Restore the enclosing class context (classes may be nested).
            self.current_info = save_info

    # NOTE: This method *definitely* isn't part of the NodeVisitor API.
    def visit_symbol_table(self, symtab: SymbolTable, table_fullname: str) -> None:
        """Resolve cross-references in *symtab* and fix up the nodes it owns.

        Entries with a ``cross_ref`` are pointed at the referenced node; all
        other entries are visited recursively. *table_fullname* is only used
        in assertion messages.
        """
        # Copy the items because we may mutate symtab.
        for key, value in list(symtab.items()):
            cross_ref = value.cross_ref
            if cross_ref is not None:  # Fix up cross-reference.
                value.cross_ref = None
                if cross_ref in self.modules:
                    value.node = self.modules[cross_ref]
                else:
                    stnode = lookup_fully_qualified(
                        cross_ref, self.modules, raise_on_missing=not self.allow_missing
                    )
                    if stnode is not None:
                        assert stnode.node is not None, (table_fullname + "." + key, cross_ref)
                        value.node = stnode.node
                    elif not self.allow_missing:
                        assert False, f"Could not find cross-ref {cross_ref}"
                    else:
                        # We have a missing crossref in allow missing mode, need to put something
                        value.node = missing_info(self.modules)
            else:
                if isinstance(value.node, TypeInfo):
                    # TypeInfo has no accept().  TODO: Add it?
                    self.visit_type_info(value.node)
                elif value.node is not None:
                    value.node.accept(self)
                else:
                    assert False, f"Unexpected empty node {key!r}: {value}"

    def visit_func_def(self, func: FuncDef) -> None:
        if self.current_info is not None:
            func.info = self.current_info
        if func.type is not None:
            func.type.accept(self.type_fixer)

    def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
        if self.current_info is not None:
            o.info = self.current_info
        if o.type:
            o.type.accept(self.type_fixer)
        for item in o.items:
            item.accept(self)
        if o.impl:
            o.impl.accept(self)

    def visit_decorator(self, d: Decorator) -> None:
        if self.current_info is not None:
            d.var.info = self.current_info
        if d.func:
            d.func.accept(self)
        if d.var:
            d.var.accept(self)
        for node in d.decorators:
            node.accept(self)

    def visit_class_def(self, c: ClassDef) -> None:
        """Fix up the types inside the class's type variables.

        Every kind of type variable (TypeVar, ParamSpec, TypeVarTuple) can carry
        an upper bound and a default, so all of them are handed to the type
        fixer, not just ``TypeVarType``.
        """
        for v in c.type_vars:
            v.accept(self.type_fixer)

    def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
        for value in tv.values:
            value.accept(self.type_fixer)
        tv.upper_bound.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
        p.upper_bound.accept(self.type_fixer)
        p.default.accept(self.type_fixer)

    def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
        tv.upper_bound.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_var(self, v: Var) -> None:
        if self.current_info is not None:
            v.info = self.current_info
        if v.type is not None:
            v.type.accept(self.type_fixer)

    def visit_type_alias(self, a: TypeAlias) -> None:
        a.target.accept(self.type_fixer)
        for v in a.alias_tvars:
            v.accept(self.type_fixer)


class TypeFixer(TypeVisitor[None]):
    """Resolve string references (``type_ref``) inside deserialized types, in place."""

    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        self.modules = modules
        self.allow_missing = allow_missing

    def visit_instance(self, inst: Instance) -> None:
        # TODO: Combine Instances that are exactly the same?
        type_ref = inst.type_ref
        if type_ref is None:
            return  # We've already been here.
        inst.type_ref = None
        inst.type = lookup_fully_qualified_typeinfo(
            self.modules, type_ref, allow_missing=self.allow_missing
        )
        # TODO: Is this needed or redundant?
        # Also fix up the bases, just in case.
        for base in inst.type.bases:
            if base.type is NOT_READY:
                base.accept(self)
        for a in inst.args:
            a.accept(self)
        if inst.last_known_value is not None:
            inst.last_known_value.accept(self)

    def visit_type_alias_type(self, t: TypeAliasType) -> None:
        type_ref = t.type_ref
        if type_ref is None:
            return  # We've already been here.
        t.type_ref = None
        t.alias = lookup_fully_qualified_alias(
            self.modules, type_ref, allow_missing=self.allow_missing
        )
        for a in t.args:
            a.accept(self)

    def visit_any(self, o: Any) -> None:
        pass  # Nothing to descend into.

    def visit_callable_type(self, ct: CallableType) -> None:
        if ct.fallback:
            ct.fallback.accept(self)
        for argt in ct.arg_types:
            # argt may be None, e.g. for __self in NamedTuple constructors.
            if argt is not None:
                argt.accept(self)
        if ct.ret_type is not None:
            ct.ret_type.accept(self)
        for v in ct.variables:
            v.accept(self)
        for arg in ct.bound_args:
            if arg:
                arg.accept(self)
        if ct.type_guard is not None:
            ct.type_guard.accept(self)

    def visit_overloaded(self, t: Overloaded) -> None:
        for ct in t.items:
            ct.accept(self)

    def visit_erased_type(self, o: Any) -> None:
        # This type should exist only temporarily during type inference
        raise RuntimeError("Shouldn't get here", o)

    def visit_deleted_type(self, o: Any) -> None:
        pass  # Nothing to descend into.

    def visit_none_type(self, o: Any) -> None:
        pass  # Nothing to descend into.

    def visit_uninhabited_type(self, o: Any) -> None:
        pass  # Nothing to descend into.

    def visit_partial_type(self, o: Any) -> None:
        raise RuntimeError("Shouldn't get here", o)

    def visit_tuple_type(self, tt: TupleType) -> None:
        if tt.items:
            for it in tt.items:
                it.accept(self)
        if tt.partial_fallback is not None:
            tt.partial_fallback.accept(self)

    def visit_typeddict_type(self, tdt: TypedDictType) -> None:
        if tdt.items:
            for it in tdt.items.values():
                it.accept(self)
        if tdt.fallback is not None:
            if tdt.fallback.type_ref is not None:
                if (
                    lookup_fully_qualified(
                        tdt.fallback.type_ref,
                        self.modules,
                        raise_on_missing=not self.allow_missing,
                    )
                    is None
                ):
                    # We reject fake TypeInfos for TypedDict fallbacks because
                    # the latter are used in type checking and must be valid.
                    tdt.fallback.type_ref = "typing._TypedDict"
            tdt.fallback.accept(self)

    def visit_literal_type(self, lt: LiteralType) -> None:
        lt.fallback.accept(self)

    def visit_type_var(self, tvt: TypeVarType) -> None:
        if tvt.values:
            for vt in tvt.values:
                vt.accept(self)
        tvt.upper_bound.accept(self)
        tvt.default.accept(self)

    def visit_param_spec(self, p: ParamSpecType) -> None:
        p.upper_bound.accept(self)
        p.default.accept(self)
        # The prefix (a Parameters object) can also hold serialized types.
        p.prefix.accept(self)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
        # The tuple fallback is a serialized Instance and needs its type_ref resolved too.
        t.tuple_fallback.accept(self)
        t.upper_bound.accept(self)
        t.default.accept(self)

    def visit_unpack_type(self, u: UnpackType) -> None:
        u.type.accept(self)

    def visit_parameters(self, p: Parameters) -> None:
        for argt in p.arg_types:
            if argt is not None:
                argt.accept(self)
        for var in p.variables:
            var.accept(self)

    def visit_unbound_type(self, o: UnboundType) -> None:
        for a in o.args:
            a.accept(self)

    def visit_union_type(self, ut: UnionType) -> None:
        if ut.items:
            for it in ut.items:
                it.accept(self)

    def visit_void(self, o: Any) -> None:
        pass  # Nothing to descend into.

    def visit_type_type(self, t: TypeType) -> None:
        t.item.accept(self)


def lookup_fully_qualified_typeinfo(
    modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeInfo:
    """Return the TypeInfo called *name*.

    If the name does not resolve to a TypeInfo, a placeholder from
    ``missing_info`` is returned when *allow_missing* is true. Otherwise the
    lookup is asked to raise on a missing name, and a non-TypeInfo result
    fails an assertion.
    """
    stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
    node = stnode.node if stnode else None
    if isinstance(node, TypeInfo):
        return node
    else:
        # Looks like a missing TypeInfo during an initial daemon load, put something there
        assert (
            allow_missing
        ), "Should never get here in normal mode, got {}:{} instead of TypeInfo".format(
            type(node).__name__, node.fullname if node else ""
        )
        return missing_info(modules)


def lookup_fully_qualified_alias(
    modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeAlias:
    """Return the TypeAlias called *name*.

    If *name* refers to a named tuple or TypedDict class, its special alias is
    created on first use and cached on the TypeInfo. Unresolvable names yield a
    ``missing_alias`` placeholder when *allow_missing* is true, and fail an
    assertion otherwise.
    """
    stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
    node = stnode.node if stnode else None
    if isinstance(node, TypeAlias):
        return node
    elif isinstance(node, TypeInfo):
        if node.special_alias:
            # Already fixed up.
            return node.special_alias
        if node.tuple_type:
            alias = TypeAlias.from_tuple_type(node)
        elif node.typeddict_type:
            alias = TypeAlias.from_typeddict_type(node)
        else:
            assert allow_missing
            return missing_alias()
        node.special_alias = alias
        return alias
    else:
        # Looks like a missing TypeAlias during an initial daemon load, put something there
        assert (
            allow_missing
        ), "Should never get here in normal mode, got {}:{} instead of TypeAlias".format(
            type(node).__name__, node.fullname if node else ""
        )
        return missing_alias()


# Name given to placeholder nodes created in allow_missing mode; "{}" is filled
# with the kind of node ("info" or "alias") so placeholders are recognizable.
_SUGGESTION: Final = "<missing {}: *should* have gone away during fine-grained update>"


def missing_info(modules: dict[str, MypyFile]) -> TypeInfo:
    """Create a placeholder class whose only base is ``builtins.object``."""
    suggestion = _SUGGESTION.format("info")
    dummy_def = ClassDef(suggestion, Block([]))
    dummy_def.fullname = suggestion

    info = TypeInfo(SymbolTable(), dummy_def, "<missing>")
    obj_type = lookup_fully_qualified_typeinfo(modules, "builtins.object", allow_missing=False)
    info.bases = [Instance(obj_type, [])]
    info.mro = [info, obj_type]
    return info


def missing_alias() -> TypeAlias:
    """Create a placeholder type alias whose target is ``Any``."""
    suggestion = _SUGGESTION.format("alias")
    return TypeAlias(AnyType(TypeOfAny.special_form), suggestion, line=-1, column=-1)