# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt

"""Special methods checker and helper function's module."""

from __future__ import annotations

from collections.abc import Callable

import astroid
from astroid import bases, nodes, util
from astroid.context import InferenceContext
from astroid.typing import InferenceResult

from pylint.checkers import BaseChecker
from pylint.checkers.utils import (
    PYMETHODS,
    SPECIAL_METHODS_PARAMS,
    decorated_with,
    is_function_body_ellipsis,
    only_required_for_messages,
    safe_infer,
)
from pylint.lint.pylinter import PyLinter

# Name of the method whose presence makes an inferred object count as an
# iterator for the ``non-iterator-returned`` check.
NEXT_METHOD = "__next__"


def _safe_infer_call_result(
    node: nodes.FunctionDef,
    caller: nodes.FunctionDef,
    context: InferenceContext | None = None,
) -> InferenceResult | None:
    """Safely infer the return value of a function.

    Returns None if inference failed or if there is some ambiguity (more than
    one node has been inferred). Otherwise, returns inferred value.

    :param node: the function whose call result is inferred.
    :param caller: the node passed to ``infer_call_result`` as the caller.
    :param context: optional inference context forwarded to astroid.
    """
    try:
        inferit = node.infer_call_result(caller, context=context)
        value = next(inferit)
    except astroid.InferenceError:
        return None  # inference failed
    except StopIteration:
        return None  # no values inferred
    try:
        next(inferit)
        return None  # there is ambiguity on the inferred node
    except astroid.InferenceError:
        return None  # there is some kind of ambiguity
    except StopIteration:
        return value


class SpecialMethodsChecker(BaseChecker):
    """Checker which verifies that special methods
    are implemented correctly.
    """

    name = "classes"
    msgs = {
        "E0301": (
            "__iter__ returns non-iterator",
            "non-iterator-returned",
            "Used when an __iter__ method returns something which is not an "
            f"iterable (i.e. has no `{NEXT_METHOD}` method)",
            {
                "old_names": [
                    ("W0234", "old-non-iterator-returned-1"),
                    ("E0234", "old-non-iterator-returned-2"),
                ]
            },
        ),
        "E0302": (
            "The special method %r expects %s param(s), %d %s given",
            "unexpected-special-method-signature",
            "Emitted when a special method was defined with an "
            "invalid number of parameters. If it has too few or "
            "too many, it might not work at all.",
            {"old_names": [("E0235", "bad-context-manager")]},
        ),
        "E0303": (
            "__len__ does not return non-negative integer",
            "invalid-length-returned",
            "Used when a __len__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0304": (
            "__bool__ does not return bool",
            "invalid-bool-returned",
            "Used when a __bool__ method returns something which is not a bool",
        ),
        "E0305": (
            "__index__ does not return int",
            "invalid-index-returned",
            "Used when an __index__ method returns something which is not "
            "an integer",
        ),
        "E0306": (
            "__repr__ does not return str",
            "invalid-repr-returned",
            "Used when a __repr__ method returns something which is not a string",
        ),
        "E0307": (
            "__str__ does not return str",
            "invalid-str-returned",
            "Used when a __str__ method returns something which is not a string",
        ),
        "E0308": (
            "__bytes__ does not return bytes",
            "invalid-bytes-returned",
            "Used when a __bytes__ method returns something which is not bytes",
        ),
        "E0309": (
            "__hash__ does not return int",
            "invalid-hash-returned",
            "Used when a __hash__ method returns something which is not an integer",
        ),
        "E0310": (
            "__length_hint__ does not return non-negative integer",
            "invalid-length-hint-returned",
            "Used when a __length_hint__ method returns something which is not a "
            "non-negative integer",
        ),
        "E0311": (
            "__format__ does not return str",
            "invalid-format-returned",
            "Used when a __format__ method returns something which is not a string",
        ),
        "E0312": (
            "__getnewargs__ does not return a tuple",
            "invalid-getnewargs-returned",
            "Used when a __getnewargs__ method returns something which is not "
            "a tuple",
        ),
        "E0313": (
            "__getnewargs_ex__ does not return a tuple containing (tuple, dict)",
            "invalid-getnewargs-ex-returned",
            "Used when a __getnewargs_ex__ method returns something which is not "
            "of the form tuple(tuple, dict)",
        ),
    }

    def __init__(self, linter: PyLinter) -> None:
        super().__init__(linter)
        # Maps a special method name to the check validating its inferred
        # return value; each check receives (method node, inferred value).
        self._protocol_map: dict[
            str, Callable[[nodes.FunctionDef, InferenceResult], None]
        ] = {
            "__iter__": self._check_iter,
            "__len__": self._check_len,
            "__bool__": self._check_bool,
            "__index__": self._check_index,
            "__repr__": self._check_repr,
            "__str__": self._check_str,
            "__bytes__": self._check_bytes,
            "__hash__": self._check_hash,
            "__length_hint__": self._check_length_hint,
            "__format__": self._check_format,
            "__getnewargs__": self._check_getnewargs,
            "__getnewargs_ex__": self._check_getnewargs_ex,
        }

    @only_required_for_messages(
        "unexpected-special-method-signature",
        "non-iterator-returned",
        "invalid-length-returned",
        "invalid-bool-returned",
        "invalid-index-returned",
        "invalid-repr-returned",
        "invalid-str-returned",
        "invalid-bytes-returned",
        "invalid-hash-returned",
        "invalid-length-hint-returned",
        "invalid-format-returned",
        "invalid-getnewargs-returned",
        "invalid-getnewargs-ex-returned",
    )
    def visit_functiondef(self, node: nodes.FunctionDef) -> None:
        """Check the return value and signature of a special method.

        Non-methods are ignored.

        For names registered in ``_protocol_map`` whose body is not just
        ``...``, the single unambiguously inferred return value is passed to
        the matching ``_check_*`` method. For names in ``PYMETHODS``, the
        parameter count is validated.
        """
        if not node.is_method():
            return

        protocol_check = self._protocol_map.get(node.name)
        # Call-result inference is comparatively expensive, so only run it for
        # methods whose return value is actually validated.
        if protocol_check is not None and not is_function_body_ellipsis(node):
            inferred = _safe_infer_call_result(node, node)
            # Only want to check types that we are able to infer
            if inferred:
                protocol_check(node, inferred)

        if node.name in PYMETHODS:
            self._check_unexpected_method_signature(node)

    visit_asyncfunctiondef = visit_functiondef

    def _check_unexpected_method_signature(self, node: nodes.FunctionDef) -> None:
        """Emit ``unexpected-special-method-signature`` on a parameter mismatch.

        The expectation comes from ``SPECIAL_METHODS_PARAMS[node.name]``:

        * ``None``: any number of parameters is accepted, so nothing is checked.
        * a tuple: emit when the number of mandatory parameters is not in it.
        * an int: emit when the parameters cannot accommodate exactly that
          many positional arguments, taking defaults and ``*args`` into account.

        ``self`` is not counted unless the method is a ``staticmethod``. When
        in doubt, a false negative is preferred.
        """
        expected_params = SPECIAL_METHODS_PARAMS[node.name]

        if expected_params is None:
            # This can support a variable number of parameters.
            return

        # Special methods are called with positional arguments, so
        # positional-only parameters (``def __eq__(self, other, /)``) count
        # exactly like regular ones. ``node.args.defaults`` covers both kinds
        # together (the trailing parameters of the combined list).
        positional = [*(node.args.posonlyargs or ()), *(node.args.args or ())]
        if not positional and not node.args.vararg:
            # Method has no parameter, will be caught
            # by no-method-argument.
            return

        if decorated_with(node, ["builtins.staticmethod"]):
            # We expect to not take in consideration self.
            all_args = positional
        else:
            all_args = positional[1:]
        # A default on ``self`` itself must not be counted as an optional
        # parameter, otherwise ``mandatory`` could become negative.
        optional = min(len(node.args.defaults), len(all_args))
        mandatory = len(all_args) - optional
        current_params = mandatory + optional

        emit = False  # If we don't know we choose a false negative
        if isinstance(expected_params, tuple):
            # The expected number of parameters can be any value from this
            # tuple, although the user should implement the method
            # to take all of them in consideration.
            emit = mandatory not in expected_params
            # mypy thinks that expected_params has type tuple[int, int] | int | None
            # But at this point it must be 'tuple[int, int]' because of the type check
            expected_params = f"between {expected_params[0]} or {expected_params[1]}"  # type: ignore[assignment]
        else:
            # If the number of mandatory parameters doesn't
            # suffice, the expected parameters for this
            # function will be deduced from the optional
            # parameters.
            rest = expected_params - mandatory
            if rest == 0:
                emit = False
            elif rest < 0:
                emit = True
            elif rest > 0:
                emit = not ((optional - rest) >= 0 or node.args.vararg)

        if emit:
            verb = "was" if current_params <= 1 else "were"
            self.add_message(
                "unexpected-special-method-signature",
                args=(node.name, expected_params, current_params, verb),
                node=node,
            )

    @staticmethod
    def _is_wrapped_type(node: InferenceResult, type_: str) -> bool:
        """Return True for a non-``Const`` instance whose class is named ``type_``."""
        return (
            isinstance(node, bases.Instance)
            and node.name == type_
            and not isinstance(node, nodes.Const)
        )

    @staticmethod
    def _is_int(node: InferenceResult) -> bool:
        """Return True if ``node`` is an ``int`` instance or an int constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "int"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, int)

    @staticmethod
    def _is_str(node: InferenceResult) -> bool:
        """Return True if ``node`` is a ``str`` instance or a str constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "str"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, str)

    @staticmethod
    def _is_bool(node: InferenceResult) -> bool:
        """Return True if ``node`` is a ``bool`` instance or a bool constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "bool"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, bool)

    @staticmethod
    def _is_bytes(node: InferenceResult) -> bool:
        """Return True if ``node`` is a ``bytes`` instance or a bytes constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "bytes"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, bytes)

    @staticmethod
    def _is_tuple(node: InferenceResult) -> bool:
        """Return True if ``node`` is a ``tuple`` instance or a tuple constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "tuple"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, tuple)

    @staticmethod
    def _is_dict(node: InferenceResult) -> bool:
        """Return True if ``node`` is a ``dict`` instance or a dict constant."""
        if SpecialMethodsChecker._is_wrapped_type(node, "dict"):
            return True

        return isinstance(node, nodes.Const) and isinstance(node.value, dict)

    @staticmethod
    def _is_iterator(node: InferenceResult) -> bool:
        """Return True if ``node`` is considered an iterator.

        Accepted are generators, comprehension scopes, instances whose class
        provides ``__next__``, and classes whose metaclass provides it.
        """
        if isinstance(node, bases.Generator):
            # Generators can be iterated.
            return True
        if isinstance(node, nodes.ComprehensionScope):
            # Comprehensions can be iterated.
            return True

        if isinstance(node, bases.Instance):
            try:
                node.local_attr(NEXT_METHOD)
                return True
            except astroid.NotFoundError:
                pass
        elif isinstance(node, nodes.ClassDef):
            # A class object is an iterator only if its metaclass
            # defines __next__.
            metaclass = node.metaclass()
            if metaclass and isinstance(metaclass, nodes.ClassDef):
                try:
                    metaclass.local_attr(NEXT_METHOD)
                    return True
                except astroid.NotFoundError:
                    pass
        return False

    def _check_iter(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_iterator(inferred):
            self.add_message("non-iterator-returned", node=node)

    def _check_len(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-length-returned", node=node)
        elif isinstance(inferred, nodes.Const) and inferred.value < 0:
            # Only literal values can be proven negative.
            self.add_message("invalid-length-returned", node=node)

    def _check_bool(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_bool(inferred):
            self.add_message("invalid-bool-returned", node=node)

    def _check_index(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-index-returned", node=node)

    def _check_repr(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-repr-returned", node=node)

    def _check_str(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-str-returned", node=node)

    def _check_bytes(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_bytes(inferred):
            self.add_message("invalid-bytes-returned", node=node)

    def _check_hash(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-hash-returned", node=node)

    def _check_length_hint(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        if not self._is_int(inferred):
            self.add_message("invalid-length-hint-returned", node=node)
        elif isinstance(inferred, nodes.Const) and inferred.value < 0:
            # Only literal values can be proven negative.
            self.add_message("invalid-length-hint-returned", node=node)

    def _check_format(self, node: nodes.FunctionDef, inferred: InferenceResult) -> None:
        if not self._is_str(inferred):
            self.add_message("invalid-format-returned", node=node)

    def _check_getnewargs(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-returned", node=node)

    def _check_getnewargs_ex(
        self, node: nodes.FunctionDef, inferred: InferenceResult
    ) -> None:
        """Require a tuple whose elements are, when inferable, (tuple, dict)."""
        if not self._is_tuple(inferred):
            self.add_message("invalid-getnewargs-ex-returned", node=node)
            return

        if not isinstance(inferred, nodes.Tuple):
            # If it's not an astroid.Tuple we can't analyze it further
            return

        found_error = False

        if len(inferred.elts) != 2:
            found_error = True
        else:
            for arg, check in (
                (inferred.elts[0], self._is_tuple),
                (inferred.elts[1], self._is_dict),
            ):
                if isinstance(arg, nodes.Call):
                    arg = safe_infer(arg)

                # Elements that cannot be inferred are given the benefit of
                # the doubt.
                if arg and not isinstance(arg, util.UninferableBase):
                    if not check(arg):
                        found_error = True
                        break

        if found_error:
            self.add_message("invalid-getnewargs-ex-returned", node=node)