"""Type inference constraint solving"""

from __future__ import annotations

from collections import defaultdict
from typing import Iterable, Sequence, Tuple
from typing_extensions import TypeAlias as _TypeAlias

from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_all_type_vars
from mypy.types import (
    AnyType,
    Instance,
    NoneType,
    ParamSpecType,
    ProperType,
    TupleType,
    Type,
    TypeOfAny,
    TypeVarId,
    TypeVarLikeType,
    TypeVarTupleType,
    TypeVarType,
    UninhabitedType,
    UnionType,
    UnpackType,
    get_proper_type,
)
from mypy.typestate import type_state

# For each type variable, the set of its constant and non-linear bounds
# (see transitive_closure() for how linear bounds are kept separately).
Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]"
# Linear constraints between type variables: a pair (l, u) records l <: u.
Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]"
# For each type variable, its solution, or None if it could not be solved.
Solutions: _TypeAlias = "dict[TypeVarId, Type | None]"


def solve_constraints(
    original_vars: Sequence[TypeVarLikeType],
    constraints: list[Constraint],
    strict: bool = True,
    allow_polymorphic: bool = False,
) -> tuple[list[Type | None], list[TypeVarLikeType]]:
    """Solve type constraints.

    Return the best type(s) for type variables; each type can be None if the value of
    the variable could not be solved. The result list is parallel to original_vars.

    If a variable has no constraints, if strict=True then arbitrarily
    pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType.
    If allow_polymorphic=True, then use the full algorithm that can potentially return
    free type variables in solutions (these require special care when applying). Otherwise,
    use a simplified algorithm that just solves each type variable individually if possible.

    The second item of the returned tuple is the list of free type variables chosen by
    the polymorphic algorithm; it is always empty when allow_polymorphic=False.
    """
    vars = [tv.id for tv in original_vars]
    if not vars:
        return [], []

    originals = {tv.id: tv for tv in original_vars}
    extra_vars: list[TypeVarId] = []
    # Every variable id recorded so far. A set keeps membership tests O(1) and makes
    # sure an extra variable is recorded only once, even if it occurs repeatedly in
    # the extra_tvars of a single constraint.
    known = set(vars)
    # Get additional type variables from generic actuals.
    for c in constraints:
        for v in c.extra_tvars:
            if v.id not in known:
                known.add(v.id)
                extra_vars.append(v.id)
        originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})
    all_vars = vars + extra_vars

    solutions: Solutions = {}
    free_vars: list[TypeVarLikeType] = []
    if allow_polymorphic:
        if constraints:
            solutions, free_vars = solve_with_dependent(all_vars, constraints, vars, originals)
    else:
        # Collect a list of constraints for each type variable.
        cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in all_vars}
        for con in constraints:
            if con.type_var in cmap:
                cmap[con.type_var].append(con)

        for tv, cs in cmap.items():
            if not cs:
                continue
            lowers = [c.target for c in cs if c.op == SUPERTYPE_OF]
            uppers = [c.target for c in cs if c.op == SUBTYPE_OF]
            solution = solve_one(lowers, uppers)

            # Do not leak type variables in non-polymorphic solutions. By construction
            # extra_vars is disjoint from vars, so it is exactly the set of variables
            # that must not appear in a solution.
            if solution is None or not get_vars(solution, extra_vars):
                solutions[tv] = solution

    res: list[Type | None] = []
    for v in vars:
        if v in solutions:
            res.append(solutions[v])
        else:
            # No constraints for type variable -- 'UninhabitedType' is the most specific type.
            candidate: Type
            if strict:
                candidate = UninhabitedType()
                candidate.ambiguous = True
            else:
                candidate = AnyType(TypeOfAny.special_form)
            res.append(candidate)
    return res, free_vars


def solve_with_dependent(
    vars: list[TypeVarId],
    constraints: list[Constraint],
    original_vars: list[TypeVarId],
    originals: dict[TypeVarId, TypeVarLikeType],
) -> tuple[Solutions, list[TypeVarLikeType]]:
    """Solve set of constraints that may depend on each other, like T <: List[S].

    The whole algorithm consists of five steps:
      * Propagate via linear constraints and use secondary constraints to get transitive closure
      * Find dependencies between type variables, group them in SCCs, and sort topologically
      * Check that all SCC are intrinsically linear, we can't solve (express) T <: List[T]
      * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
      * Solve constraints iteratively starting from leaves, updating bounds after each step.

    Return the solutions found and the free type variables that were chosen. If any SCC
    is not linear, give up and return no solutions and no free variables.
    """
    graph, lowers, uppers = transitive_closure(vars, constraints)

    dmap = compute_dependencies(vars, graph, lowers, uppers)
    sccs = list(strongly_connected_components(set(vars), dmap))
    if not all(check_linear(scc, lowers, uppers) for scc in sccs):
        return {}, []
    raw_batches = list(topsort(prepare_sccs(sccs, dmap)))
    if not raw_batches:
        # Nothing to solve (no variables were given); avoid indexing an empty list below.
        return {}, []

    # Chosen free variables, in the order they were picked, keyed by their id.
    free_solutions: dict[TypeVarId, TypeVarLikeType] = {}
    for scc in raw_batches[0]:
        # If there are no bounds on this SCC, then the only meaningful solution we can
        # express, is that each variable is equal to a new free variable. For example,
        # if we have T <: S, S <: U, we deduce: T = S = U = <free>.
        if all(not lowers[tv] and not uppers[tv] for tv in scc):
            best_free = choose_free([originals[tv] for tv in scc], original_vars)
            if best_free:
                free_solutions[best_free.id] = best_free

                # Update lowers/uppers with free vars, so these can now be used
                # as valid solutions.
                for l, u in graph:
                    if l in free_solutions:
                        lowers[u].add(free_solutions[l])
                    if u in free_solutions:
                        uppers[l].add(free_solutions[u])

    # Flatten the SCCs that are independent, we can solve them together,
    # since we don't need to update any targets in between.
    batches = [[tv for scc in batch for tv in scc] for batch in raw_batches]

    solutions: Solutions = {}
    for flat_batch in batches:
        solutions.update(solve_iteratively(flat_batch, graph, lowers, uppers))
    return solutions, list(free_solutions.values())


def solve_iteratively(
    batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> Solutions:
    """Solve transitive closure sequentially, updating upper/lower bounds after each step.

    Transitive closure is represented as a linear graph plus lower/upper bounds for each
    type variable, see transitive_closure() docstring for details.

    We solve for type variables that appear in `batch`. If a bound is not constant (i.e. it
    looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...]
    after solving the batch.

    Importantly, after solving each variable in a batch, we move it from linear graph to
    upper/lower bounds, this way we can guarantee consistency of solutions (see comment below
    for an example when this is important).

    Note that graph, lowers and uppers are all updated in place.
    """
    solutions: Solutions = {}
    s_batch = set(batch)
    while s_batch:
        # Pick the first (by raw id, for determinism) variable that has any bound.
        for tv in sorted(s_batch, key=lambda x: x.raw_id):
            if lowers[tv] or uppers[tv]:
                solvable_tv = tv
                break
        else:
            # Remaining variables have no bounds at all; leave them unsolved.
            break
        # Solve each solvable type variable separately.
        s_batch.remove(solvable_tv)
        result = solve_one(lowers[solvable_tv], uppers[solvable_tv])
        solutions[solvable_tv] = result
        if result is None:
            # TODO: support backtracking lower/upper bound choices and order within SCCs.
            # (will require switching this function from iterative to recursive).
            continue

        # Update the (transitive) bounds from graph if there is a solution.
        # This is needed to guarantee solutions will never contradict the initial
        # constraints. For example, consider {T <: S, T <: A, S :> B} with A :> B.
        # If we would not update the uppers/lowers from graph, we would infer T = A, S = B
        # which is not correct.
        for l, u in graph.copy():
            if l == u:
                continue
            if l == solvable_tv:
                lowers[u].add(result)
                graph.remove((l, u))
            if u == solvable_tv:
                uppers[l].add(result)
                graph.remove((l, u))

    # We can update uppers/lowers only once after solving the whole SCC,
    # since uppers/lowers can't depend on type variables in the SCC
    # (and we would reject such SCC as non-linear and therefore not solvable).
    subs = {tv: s for (tv, s) in solutions.items() if s is not None}
    for tv in lowers:
        lowers[tv] = {expand_type(lt, subs) for lt in lowers[tv]}
    for tv in uppers:
        uppers[tv] = {expand_type(ut, subs) for ut in uppers[tv]}
    return solutions


def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
    """Solve constraints by using meets of upper bounds, and joins of lower bounds.

    Return Any if either combined bound is Any, None if there are no bounds at all or
    the combined lower bound is not a subtype of the combined upper bound, and
    otherwise the combined lower bound (or the upper bound if there are no lower bounds).
    """
    bottom: Type | None = None
    top: Type | None = None

    # Process each bound separately, and calculate the lower and upper
    # bounds based on constraints. Note that we assume that the constraint
    # targets do not have constraint references.
    for target in lowers:
        if bottom is None:
            bottom = target
        elif type_state.infer_unions:
            # This deviates from the general mypy semantics because
            # recursive types are union-heavy in 95% of cases.
            bottom = UnionType.make_union([bottom, target])
        else:
            bottom = join_types(bottom, target)

    for target in uppers:
        if top is None:
            top = target
        else:
            top = meet_types(top, target)

    p_top = get_proper_type(top)
    p_bottom = get_proper_type(bottom)
    if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
        # Use the *proper* type as the source: top/bottom themselves may be unexpanded
        # aliases whose proper type is Any, and AnyType expects an AnyType source.
        source_any = p_top if isinstance(p_top, AnyType) else p_bottom
        assert isinstance(source_any, AnyType)
        return AnyType(TypeOfAny.from_another_any, source_any=source_any)
    if bottom is None:
        # Only upper bounds (top), or no constraints for type variable at all (None).
        return top
    if top is None:
        return bottom
    if is_subtype(bottom, top):
        return bottom
    return None


def choose_free(
    scc: list[TypeVarLikeType], original_vars: list[TypeVarId]
) -> TypeVarLikeType | None:
    """Choose the best solution for an SCC containing only type variables.

    This is needed to preserve e.g. the upper bound in a situation like this:
        def dec(f: Callable[[T], S]) -> Callable[[T], S]: ...

        @dec
        def test(x: U) -> U: ...

    where U <: A.

    Return None when no single free variable can faithfully represent the SCC.
    """

    if len(scc) == 1:
        # Fast path, choice is trivial.
        return scc[0]

    common_upper_bound = meet_type_list([t.upper_bound for t in scc])
    common_upper_bound_p = get_proper_type(common_upper_bound)
    # We include None for when strict-optional is disabled.
    if isinstance(common_upper_bound_p, (UninhabitedType, NoneType)):
        # This will cause to infer <nothing>, which is better than a free TypeVar
        # that has an upper bound <nothing>.
        return None

    values: list[Type] = []
    for tv in scc:
        if isinstance(tv, TypeVarType) and tv.values:
            if values:
                # It is too tricky to support multiple TypeVars with values
                # within the same SCC.
                return None
            values = tv.values.copy()

    if values and not is_trivial_bound(common_upper_bound_p):
        # If there are both values and upper bound present, we give up,
        # since type variables having both are not supported.
        return None

    # For convenience with current type application machinery, we use a stable
    # choice that prefers the original type variables (not polymorphic ones) in SCC.
    # min() returns the first minimal element, same as sorted(...)[0].
    best = min(scc, key=lambda x: (x.id not in original_vars, x.id.raw_id))
    if isinstance(best, TypeVarType):
        return best.copy_modified(values=values, upper_bound=common_upper_bound)
    if is_trivial_bound(common_upper_bound_p):
        # TODO: support more cases for ParamSpecs/TypeVarTuples
        return best
    return None


def is_trivial_bound(tp: ProperType) -> bool:
    """Return True if tp is builtins.object, i.e. an upper bound that constrains nothing."""
    return isinstance(tp, Instance) and tp.type.fullname == "builtins.object"


def find_linear(c: Constraint) -> Tuple[bool, TypeVarId | None]:
    """Find out if this constraint represents a linear relationship, return target id if yes."""
    if isinstance(c.origin_type_var, TypeVarType):
        if isinstance(c.target, TypeVarType):
            return True, c.target.id
    if isinstance(c.origin_type_var, ParamSpecType):
        if isinstance(c.target, ParamSpecType) and not c.target.prefix.arg_types:
            return True, c.target.id
    if isinstance(c.origin_type_var, TypeVarTupleType):
        target = get_proper_type(c.target)
        if isinstance(target, TupleType) and len(target.items) == 1:
            item = target.items[0]
            if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
                return True, item.type.id
    return False, None


def transitive_closure(
    tvars: list[TypeVarId], constraints: list[Constraint]
) -> tuple[Graph, Bounds, Bounds]:
    """Find transitive closure for given constraints on type variables.

    Transitive closure gives maximal set of lower/upper bounds for each type variable,
    such that we cannot deduce any further bounds by chaining other existing bounds.

    The transitive closure is represented by:
      * A set of lower and upper bounds for each type variable, where only constant and
        non-linear terms are included in the bounds.
      * A graph of linear constraints between type variables (represented as a set of pairs)
    Such separation simplifies reasoning, and allows an efficient and simple incremental
    transitive closure algorithm that we use here.

    For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive
    closure is given by:
      * {} <: T <: {int}
      * {} <: S <: {int}
      * {} <: U <: {int}
      * {T <: S, S <: U, T <: U}

    Return (graph, lowers, uppers). The graph also contains the reflexive pair (tv, tv)
    for every variable in tvars.
    """
    uppers: Bounds = defaultdict(set)
    lowers: Bounds = defaultdict(set)
    graph: Graph = {(tv, tv) for tv in tvars}

    remaining = set(constraints)
    while remaining:
        c = remaining.pop()
        # Note that ParamSpec constraint P <: Q may be considered linear only if Q has no prefix,
        # for cases like P <: Concatenate[T, Q] we should consider this non-linear and put {P} and
        # {T, Q} into separate SCCs. Similarly, Ts <: Tuple[*Us] considered linear, while
        # Ts <: Tuple[*Us, U] is non-linear.
        is_linear, target_id = find_linear(c)
        if is_linear and target_id in tvars:
            assert target_id is not None
            if c.op == SUBTYPE_OF:
                lower, upper = c.type_var, target_id
            else:
                lower, upper = target_id, c.type_var
            if (lower, upper) in graph:
                continue
            # Connect everything below `lower` to everything above `upper`.
            graph |= {
                (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph
            }
            for u in tvars:
                if (upper, u) in graph:
                    lowers[u] |= lowers[lower]
            for l in tvars:
                if (l, lower) in graph:
                    uppers[l] |= uppers[upper]
            for lt in lowers[lower]:
                for ut in uppers[upper]:
                    # TODO: what if secondary constraints result in inference
                    # against polymorphic actual (also in below branches)?
                    remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF))
                    remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF))
        elif c.op == SUBTYPE_OF:
            if c.target in uppers[c.type_var]:
                continue
            for l in tvars:
                if (l, c.type_var) in graph:
                    uppers[l].add(c.target)
            for lt in lowers[c.type_var]:
                remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF))
                remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF))
        else:
            assert c.op == SUPERTYPE_OF
            if c.target in lowers[c.type_var]:
                continue
            for u in tvars:
                if (c.type_var, u) in graph:
                    lowers[u].add(c.target)
            for ut in uppers[c.type_var]:
                remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF))
                remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF))
    return graph, lowers, uppers


def compute_dependencies(
    tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> dict[TypeVarId, list[TypeVarId]]:
    """Compute dependencies between type variables induced by constraints.

    If we have a constraint like T <: List[S], we say that T depends on S, since
    we will need to solve for S first before we can solve for T. Variables linked
    by a linear constraint in graph (in either direction) depend on each other.
    """
    tvar_set = set(tvars)
    res: dict[TypeVarId, list[TypeVarId]] = {}
    for tv in tvars:
        deps: set[TypeVarId] = set()
        for lt in lowers[tv]:
            deps |= get_vars(lt, tvar_set)
        for ut in uppers[tv]:
            deps |= get_vars(ut, tvar_set)
        for other in tvars:
            if other == tv:
                continue
            if (tv, other) in graph or (other, tv) in graph:
                deps.add(other)
        res[tv] = list(deps)
    return res


def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
    """Check there are only linear constraints between type variables in SCC.

    Linear are constraints like T <: S (while T <: F[S] are non-linear).
    """
    for tv in scc:
        if any(get_vars(lt, scc) for lt in lowers[tv]):
            return False
        if any(get_vars(ut, scc) for ut in uppers[tv]):
            return False
    return True


def get_vars(target: Type, vars: Iterable[TypeVarId]) -> set[TypeVarId]:
    """Find type variables for which we are solving in a target type.

    Return the ids of type variables occurring in target that are also in vars.
    """
    return {tv.id for tv in get_all_type_vars(target)}.intersection(vars)