# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

"""Classes representing different types of constraints on inference values."""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import sys
|
||
|
from abc import ABC, abstractmethod
|
||
|
from collections.abc import Iterator
|
||
|
from typing import TYPE_CHECKING, Union
|
||
|
|
||
|
from astroid import nodes, util
|
||
|
from astroid.typing import InferenceResult
|
||
|
|
||
|
if sys.version_info >= (3, 11):
|
||
|
from typing import Self
|
||
|
else:
|
||
|
from typing_extensions import Self
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from astroid import bases
|
||
|
|
||
|
# Node kinds that a constraint can be looked up for: plain names and
# attribute accesses, in both their load and assignment forms.
_NameNodes = Union[
    nodes.AssignAttr,
    nodes.Attribute,
    nodes.AssignName,
    nodes.Name,
]
|
||
|
|
||
|
|
||
|
class Constraint(ABC):
    """Abstract base for one constraint placed on a variable."""

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.node = node
        """Node the constraint is attached to."""
        self.negate = negate
        """Whether the constraint is negated, e.g. "is not" rather than "is"."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node from expr when expr fits this pattern.

        Implementations return None when expr does not fit the pattern.
        Passing negate=True requests the negated form of the constraint.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the given inferred value satisfies this constraint."""
|
||
|
|
||
|
|
||
|
class NoneConstraint(Constraint):
    """Constraint derived from an "is None" or "is not None" comparison."""

    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a None-constraint on node if expr is ``node is [not] None``.

        Only a comparison with exactly one operator, whose left operand
        matches node and whose right operand is the None constant, is
        recognised; anything else yields None.

        The resulting polarity combines the operator with negate: "is" keeps
        the requested polarity, "is not" flips it.
        """
        # Chained comparisons (a is b is None) and non-comparisons never match.
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        op, right = expr.ops[0]
        if op not in {"is", "is not"}:
            return None

        if not (_matches(expr.left, node) and _matches(right, cls.CONST_NONE)):
            return None

        return cls(node=node, negate=negate if op == "is" else not negate)

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the given inferred value satisfies this constraint."""
        # An uninferable value cannot be ruled out, so it is accepted.
        if isinstance(inferred, util.UninferableBase):
            return True

        # Satisfied when exactly one of "negated" and "value is None" holds.
        is_none = _matches(inferred, self.CONST_NONE)
        return self.negate ^ is_none
|
||
|
|
||
|
|
||
|
def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
    """Collect the constraints that apply to expr inside frame.

    The result maps each ``if`` node that produced at least one constraint
    to the set of constraints it produced.

    The analysis is static: it walks from expr up through its parents until
    frame (or the top of the tree) is reached. Only ``if`` conditions are
    considered: the test applies as written when expr sits in the body, and
    inverted when expr sits in the ``else`` part.
    """
    found_by_if: dict[nodes.If, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr

    while child is not None and child is not frame:
        parent = child.parent
        if isinstance(parent, nodes.If):
            branch, _ = parent.locate_child(child)
            # Any other branch (e.g. the test itself) contributes nothing.
            if branch in ("body", "orelse"):
                found = set(
                    _match_constraint(
                        expr, parent.test, invert=(branch == "orelse")
                    )
                )
                if found:
                    found_by_if[parent] = found
        child = parent

    return found_by_if
|
||
|
|
||
|
|
||
|
ALL_CONSTRAINT_CLASSES = frozenset({NoneConstraint})
"""Every constraint class that is tried when matching an expression."""
|
||
|
|
||
|
|
||
|
def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether two nodes are structurally the same reference or constant.

    Two names match on their name, two attributes match on their attribute
    name and (recursively) on the expression they are accessed from, and two
    constants match on their value. Every other combination does not match.
    """
    if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
        same = node1.name == node2.name
    elif isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
        same = node1.attrname == node2.attrname and _matches(
            node1.expr, node2.expr
        )
    elif isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
        same = node1.value == node2.value
    else:
        same = False

    return same
|
||
|
|
||
|
|
||
|
def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Yield one constraint on node per constraint class that expr matches.

    Each class in ALL_CONSTRAINT_CLASSES is asked to match expr; invert is
    forwarded as the negate argument of its match() method.
    """
    for constraint_cls in ALL_CONSTRAINT_CLASSES:
        matched = constraint_cls.match(node, expr, invert)
        if not matched:
            continue
        yield matched
|