Tipragot
628be439b8
Cela permet de ne pas avoir de problèmes de compatibilité car python est dans le git.
472 lines
18 KiB
Python
472 lines
18 KiB
Python
"""Type inference constraint solving"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from typing import Iterable, Sequence, Tuple
|
|
from typing_extensions import TypeAlias as _TypeAlias
|
|
|
|
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints
|
|
from mypy.expandtype import expand_type
|
|
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
|
|
from mypy.join import join_types
|
|
from mypy.meet import meet_type_list, meet_types
|
|
from mypy.subtypes import is_subtype
|
|
from mypy.typeops import get_all_type_vars
|
|
from mypy.types import (
|
|
AnyType,
|
|
Instance,
|
|
NoneType,
|
|
ParamSpecType,
|
|
ProperType,
|
|
TupleType,
|
|
Type,
|
|
TypeOfAny,
|
|
TypeVarId,
|
|
TypeVarLikeType,
|
|
TypeVarTupleType,
|
|
TypeVarType,
|
|
UninhabitedType,
|
|
UnionType,
|
|
UnpackType,
|
|
get_proper_type,
|
|
)
|
|
from mypy.typestate import type_state
|
|
|
|
# Bound types collected for each type variable, keyed by the variable's id.
Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]"
# Linear constraints between type variables, stored as (lower, upper) pairs of ids
# (see transitive_closure()).
Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]"
# Solution for each type variable id; None means no solution could be found.
Solutions: _TypeAlias = "dict[TypeVarId, Type | None]"
|
|
|
|
|
|
def solve_constraints(
    original_vars: Sequence[TypeVarLikeType],
    constraints: list[Constraint],
    strict: bool = True,
    allow_polymorphic: bool = False,
) -> tuple[list[Type | None], list[TypeVarLikeType]]:
    """Solve type constraints.

    Return the best type(s) for type variables; each type can be None if the value of
    the variable could not be solved.

    If a variable has no constraints, if strict=True then arbitrarily
    pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType.
    If allow_polymorphic=True, then use the full algorithm that can potentially return
    free type variables in solutions (these require special care when applying). Otherwise,
    use a simplified algorithm that just solves each type variable individually if possible.

    The first item of the result has one entry per item of original_vars, in the same
    order. The second item is the list of free type variables reported by the full
    algorithm; it is always empty when the simplified algorithm is used.
    """
    var_ids = [tv.id for tv in original_vars]
    if not var_ids:
        return [], []

    originals = {tv.id: tv for tv in original_vars}
    extra_vars: list[TypeVarId] = []
    # Ids seen so far (original and extra). Keeping them in a set avoids rebuilding and
    # scanning a concatenated list for every candidate, and also guarantees that an id
    # repeated within a single constraint is recorded only once.
    known_ids = set(var_ids)
    # Get additional type variables from generic actuals.
    for c in constraints:
        for v in c.extra_tvars:
            if v.id not in known_ids:
                known_ids.add(v.id)
                extra_vars.append(v.id)
        originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})

    solutions: Solutions = {}
    free_vars: list[TypeVarLikeType] = []
    if allow_polymorphic:
        if constraints:
            solutions, free_vars = solve_with_dependent(
                var_ids + extra_vars, constraints, var_ids, originals
            )
    else:
        # Collect a list of constraints for each type variable. This map is only needed
        # by the simplified algorithm, so it is not built in polymorphic mode.
        cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in var_ids + extra_vars}
        for con in constraints:
            if con.type_var in cmap:
                cmap[con.type_var].append(con)

        for tv, cs in cmap.items():
            if not cs:
                continue
            lowers = [c.target for c in cs if c.op == SUPERTYPE_OF]
            uppers = [c.target for c in cs if c.op == SUBTYPE_OF]
            solution = solve_one(lowers, uppers)

            # Do not leak type variables in non-polymorphic solutions. Note that
            # extra_vars never contains an id of an original variable (such ids are
            # skipped when extra_vars is built above), so no further filtering is needed.
            if solution is None or not get_vars(solution, extra_vars):
                solutions[tv] = solution

    res: list[Type | None] = []
    for v in var_ids:
        if v in solutions:
            res.append(solutions[v])
            continue
        # No constraints for type variable -- 'UninhabitedType' is the most specific type.
        candidate: Type
        if strict:
            candidate = UninhabitedType()
            candidate.ambiguous = True
        else:
            candidate = AnyType(TypeOfAny.special_form)
        res.append(candidate)
    return res, free_vars
|
|
|
|
|
|
def solve_with_dependent(
    vars: list[TypeVarId],
    constraints: list[Constraint],
    original_vars: list[TypeVarId],
    originals: dict[TypeVarId, TypeVarLikeType],
) -> tuple[Solutions, list[TypeVarLikeType]]:
    """Solve a set of constraints that may depend on each other, like T <: List[S].

    The algorithm proceeds in five steps:
      * Propagate via linear constraints and use secondary constraints to get the
        transitive closure.
      * Find dependencies between type variables, group them in SCCs, and sort them
        topologically.
      * Check that every SCC is intrinsically linear; something like T <: List[T] can't
        be solved (expressed).
      * Variables in leaf SCCs that have no constant bounds are free (one is chosen
        per SCC).
      * Solve constraints iteratively starting from the leaves, updating bounds after
        each step.

    Returns the solutions found together with the list of chosen free type variables.
    If some SCC is not linear, returns an empty mapping and an empty list.
    """
    graph, lowers, uppers = transitive_closure(vars, constraints)

    dmap = compute_dependencies(vars, graph, lowers, uppers)
    sccs = list(strongly_connected_components(set(vars), dmap))
    for component in sccs:
        if not check_linear(component, lowers, uppers):
            return {}, []
    raw_batches = list(topsort(prepare_sccs(sccs, dmap)))

    free_vars = []
    free_solutions = {}
    for scc in raw_batches[0]:
        # An SCC with any bound at all is handled by the regular solving below.
        if any(lowers[tv] or uppers[tv] for tv in scc):
            continue
        # With no bounds on this SCC, the only meaningful solution we can express is
        # that each variable equals a new free variable. For example, from
        # T <: S, S <: U we deduce: T = S = U = <free>.
        best_free = choose_free([originals[tv] for tv in scc], original_vars)
        if best_free:
            free_vars.append(best_free.id)
            free_solutions[best_free.id] = best_free

    # Feed the free variables into lowers/uppers, so that they can now be used
    # as valid solutions.
    for l, u in graph:
        if l in free_solutions:
            lowers[u].add(free_solutions[l])
        if u in free_solutions:
            uppers[l].add(free_solutions[u])

    # SCCs within one batch are independent, so flatten each batch and solve it as a
    # whole: no targets need updating in between.
    batches = [[tv for scc in batch for tv in scc] for batch in raw_batches]

    solutions: dict[TypeVarId, Type | None] = {}
    for flat_batch in batches:
        solutions.update(solve_iteratively(flat_batch, graph, lowers, uppers))
    return solutions, [free_solutions[tv] for tv in free_vars]
|
|
|
|
|
|
def solve_iteratively(
    batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> Solutions:
    """Solve the transitive closure step by step, refreshing bounds after every step.

    The transitive closure is given as a linear graph plus lower/upper bounds for each
    type variable (see the transitive_closure() docstring for details).

    Only type variables that appear in `batch` are solved. If a bound is not constant
    (i.e. it looks like T :> F[S, ...]), the solutions found so far are substituted into
    the target F[S, ...] after the batch has been solved.

    Importantly, once a variable of the batch is solved, it is moved from the linear
    graph into the upper/lower bounds; this guarantees consistency of the solutions
    (the comment in the body gives an example where this matters).

    Note that `graph`, `lowers` and `uppers` are all modified in place.
    """
    solutions = {}
    pending = set(batch)
    while pending:
        # Pick the variable with the smallest raw id that has at least one bound.
        solvable_tv = next(
            (tv for tv in sorted(pending, key=lambda x: x.raw_id) if lowers[tv] or uppers[tv]),
            None,
        )
        if solvable_tv is None:
            # Nothing left in the batch has any bound.
            break

        # Solve each solvable type variable separately.
        pending.remove(solvable_tv)
        result = solve_one(lowers[solvable_tv], uppers[solvable_tv])
        solutions[solvable_tv] = result
        if result is None:
            # TODO: support backtracking lower/upper bound choices and order within SCCs.
            # (will require switching this function from iterative to recursive).
            continue

        # Update the (transitive) bounds from graph if there is a solution.
        # This is needed to guarantee solutions will never contradict the initial
        # constraints. For example, consider {T <: S, T <: A, S :> B} with A :> B.
        # If we would not update the uppers/lowers from graph, we would infer T = A, S = B
        # which is not correct.
        for edge in graph.copy():
            l, u = edge
            if l == u:
                continue
            if l == solvable_tv:
                lowers[u].add(result)
                graph.remove(edge)
            elif u == solvable_tv:
                uppers[l].add(result)
                graph.remove(edge)

    # We can update uppers/lowers only once after solving the whole SCC,
    # since uppers/lowers can't depend on type variables in the SCC
    # (and we would reject such SCC as non-linear and therefore not solvable).
    subs = {tv: s for (tv, s) in solutions.items() if s is not None}
    for bounds in (lowers, uppers):
        for tv in bounds:
            bounds[tv] = {expand_type(bound, subs) for bound in bounds[tv]}
    return solutions
|
|
|
|
|
|
def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
    """Solve constraints using meets of upper bounds and joins of lower bounds.

    The lower bounds are combined into a single bottom type (a join, or a union when
    type_state.infer_unions is set), and the upper bounds into a single top type
    (a meet). The result is:
      * an AnyType derived from the bound, if either combined bound is Any;
      * the top type, if there are only upper bounds;
      * the bottom type, if there are only lower bounds, or if it is a subtype of
        the top type;
      * None if there are no bounds at all, or the bounds contradict each other.
    """
    bottom: Type | None = None
    top: Type | None = None

    # Process each bound separately, and calculate the lower and upper
    # bounds based on constraints. Note that we assume that the constraint
    # targets do not have constraint references.
    for target in lowers:
        if bottom is None:
            bottom = target
        elif type_state.infer_unions:
            # This deviates from the general mypy semantics because
            # recursive types are union-heavy in 95% of cases.
            bottom = UnionType.make_union([bottom, target])
        else:
            bottom = join_types(bottom, target)

    for target in uppers:
        top = target if top is None else meet_types(top, target)

    p_top = get_proper_type(top)
    p_bottom = get_proper_type(bottom)
    if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
        # Use the expanded types here: they are the ones that were just checked to be
        # Any, while the unexpanded top/bottom may still be something that only becomes
        # Any after get_proper_type(), which used to trip an assertion.
        source_any = p_top if isinstance(p_top, AnyType) else p_bottom
        assert isinstance(source_any, AnyType)
        return AnyType(TypeOfAny.from_another_any, source_any=source_any)

    if bottom is None:
        # Only upper bounds, or no constraints for type variable at all.
        return top if top else None
    if top is None or is_subtype(bottom, top):
        return bottom
    # The bounds are incompatible with each other.
    return None
|
|
|
|
|
|
def choose_free(
    scc: list[TypeVarLikeType], original_vars: list[TypeVarId]
) -> TypeVarLikeType | None:
    """Pick the best solution for an SCC that contains only type variables.

    This is needed to preserve e.g. the upper bound in a situation like this:
        def dec(f: Callable[[T], S]) -> Callable[[T], S]: ...

        @dec
        def test(x: U) -> U: ...

    where U <: A.

    Returns None when no suitable free variable can be chosen.
    """
    if len(scc) == 1:
        # Fast path, choice is trivial.
        return scc[0]

    common_upper_bound = meet_type_list([t.upper_bound for t in scc])
    common_upper_bound_p = get_proper_type(common_upper_bound)
    # We include None for when strict-optional is disabled.
    if isinstance(common_upper_bound_p, (UninhabitedType, NoneType)):
        # This will cause to infer <nothing>, which is better than a free TypeVar
        # that has an upper bound <nothing>.
        return None

    values: list[Type] = []
    for tv in scc:
        if not (isinstance(tv, TypeVarType) and tv.values):
            continue
        if values:
            # It is too tricky to support multiple TypeVars with values
            # within the same SCC.
            return None
        values = tv.values.copy()

    if values and not is_trivial_bound(common_upper_bound_p):
        # If there are both values and upper bound present, we give up,
        # since type variables having both are not supported.
        return None

    # For convenience with current type application machinery, we use a stable
    # choice that prefers the original type variables (not polymorphic ones) in SCC.
    ranked = sorted(scc, key=lambda x: (x.id not in original_vars, x.id.raw_id))
    best = ranked[0]
    if isinstance(best, TypeVarType):
        return best.copy_modified(values=values, upper_bound=common_upper_bound)
    # TODO: support more cases for ParamSpecs/TypeVarTuples
    return best if is_trivial_bound(common_upper_bound_p) else None
|
|
|
|
|
|
def is_trivial_bound(tp: ProperType) -> bool:
    """Tell whether tp is an Instance whose type has the full name "builtins.object"."""
    if not isinstance(tp, Instance):
        return False
    return tp.type.fullname == "builtins.object"
|
|
|
|
|
|
def find_linear(c: Constraint) -> Tuple[bool, TypeVarId | None]:
    """Check whether a constraint is a linear relationship between type variables.

    Returns (True, <target id>) if it is, and (False, None) otherwise. A constraint
    counts as linear when, depending on the kind of its origin type variable:
      * TypeVarType: the target is a TypeVarType;
      * ParamSpecType: the target is a ParamSpecType whose prefix has no argument types;
      * TypeVarTupleType: the target is a TupleType with a single item, which is an
        UnpackType wrapping a TypeVarTupleType.
    """
    origin = c.origin_type_var
    if isinstance(origin, TypeVarType) and isinstance(c.target, TypeVarType):
        return True, c.target.id
    if isinstance(origin, ParamSpecType):
        spec = c.target
        if isinstance(spec, ParamSpecType) and not spec.prefix.arg_types:
            return True, spec.id
    if isinstance(origin, TypeVarTupleType):
        tuple_target = get_proper_type(c.target)
        if isinstance(tuple_target, TupleType) and len(tuple_target.items) == 1:
            only_item = tuple_target.items[0]
            if isinstance(only_item, UnpackType) and isinstance(
                only_item.type, TypeVarTupleType
            ):
                return True, only_item.type.id
    return False, None
|
|
|
|
|
|
def transitive_closure(
    tvars: list[TypeVarId], constraints: list[Constraint]
) -> tuple[Graph, Bounds, Bounds]:
    """Compute the transitive closure of the given constraints on type variables.

    The transitive closure gives a maximal set of lower/upper bounds for each type
    variable, such that no further bounds can be deduced by chaining existing ones.

    It is represented by:
      * A set of lower and upper bounds for each type variable, where only constant and
        non-linear terms are included in the bounds.
      * A graph of linear constraints between type variables (a set of pairs).
    Such separation simplifies reasoning, and allows the efficient and simple incremental
    transitive closure algorithm used here.

    For example, for initial constraints [T <: S, S <: U, U <: int], the transitive
    closure is given by:
      * {} <: T <: {int}
      * {} <: S <: {int}
      * {} <: U <: {int}
      * {T <: S, S <: U, T <: U}

    Returns a (graph, lowers, uppers) tuple.
    """
    uppers: Bounds = defaultdict(set)
    lowers: Bounds = defaultdict(set)
    graph: Graph = {(tv, tv) for tv in tvars}

    pending = set(constraints)
    while pending:
        con = pending.pop()
        # Note that ParamSpec constraint P <: Q may be considered linear only if Q has no prefix,
        # for cases like P <: Concatenate[T, Q] we should consider this non-linear and put {P} and
        # {T, Q} into separate SCCs. Similarly, Ts <: Tuple[*Us] considered linear, while
        # Ts <: Tuple[*Us, U] is non-linear.
        is_linear, target_id = find_linear(con)
        if is_linear and target_id in tvars:
            assert target_id is not None
            lower, upper = (
                (con.type_var, target_id) if con.op == SUBTYPE_OF else (target_id, con.type_var)
            )
            if (lower, upper) in graph:
                continue
            # Variables already known to be below `lower` / above `upper`. Adding the
            # edges between them does not change either list, so both are reused for
            # the bound propagation that follows.
            below = [l for l in tvars if (l, lower) in graph]
            above = [u for u in tvars if (upper, u) in graph]
            graph.update((l, u) for l in below for u in above)
            for u in above:
                lowers[u] |= lowers[lower]
            for l in below:
                uppers[l] |= uppers[upper]
            for lt in lowers[lower]:
                for ut in uppers[upper]:
                    # TODO: what if secondary constraints result in inference
                    # against polymorphic actual (also in below branches)?
                    pending.update(infer_constraints(lt, ut, SUBTYPE_OF))
                    pending.update(infer_constraints(ut, lt, SUPERTYPE_OF))
        elif con.op == SUBTYPE_OF:
            if con.target in uppers[con.type_var]:
                continue
            # A new upper bound also bounds every variable below this one in the graph.
            for l in tvars:
                if (l, con.type_var) in graph:
                    uppers[l].add(con.target)
            for lt in lowers[con.type_var]:
                pending.update(infer_constraints(lt, con.target, SUBTYPE_OF))
                pending.update(infer_constraints(con.target, lt, SUPERTYPE_OF))
        else:
            assert con.op == SUPERTYPE_OF
            if con.target in lowers[con.type_var]:
                continue
            # A new lower bound also bounds every variable above this one in the graph.
            for u in tvars:
                if (con.type_var, u) in graph:
                    lowers[u].add(con.target)
            for ut in uppers[con.type_var]:
                pending.update(infer_constraints(ut, con.target, SUPERTYPE_OF))
                pending.update(infer_constraints(con.target, ut, SUBTYPE_OF))
    return graph, lowers, uppers
|
|
|
|
|
|
def compute_dependencies(
    tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> dict[TypeVarId, list[TypeVarId]]:
    """Work out which type variables depend on which, as induced by the constraints.

    For a constraint like T <: List[S], we say that T depends on S, since S has to be
    solved before T can be. Variables linked by an edge of the linear graph (in either
    direction) are treated as depending on each other as well.
    """
    dependencies = {}
    for tv in tvars:
        found = set()
        # Variables mentioned inside any bound of this variable.
        for bound in (*lowers[tv], *uppers[tv]):
            found.update(get_vars(bound, tvars))
        # Variables directly linked to this one in the linear graph.
        found.update(
            other
            for other in tvars
            if other != tv and ((tv, other) in graph or (other, tv) in graph)
        )
        dependencies[tv] = list(found)
    return dependencies
|
|
|
|
|
|
def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
    """Verify that type variables in an SCC are related by linear constraints only.

    Linear constraints are those like T <: S (while T <: F[S] are non-linear). The check
    fails as soon as some bound of a variable in the SCC mentions a variable of the SCC.
    """
    members = list(scc)
    for tv in scc:
        for bounds in (lowers, uppers):
            if any(get_vars(bound, members) for bound in bounds[tv]):
                return False
    return True
|
|
|
|
|
|
def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
    """Return the ids from `vars` (the variables being solved for) that occur in target."""
    mentioned = {tv.id for tv in get_all_type_vars(target)}
    wanted = set(vars)
    return mentioned & wanted
|