239 lines
8.4 KiB
Python
239 lines
8.4 KiB
Python
|
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
|
||
|
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
|
||
|
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
|
||
|
|
||
|
"""A few useful function/method decorators."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import functools
|
||
|
import inspect
|
||
|
import sys
|
||
|
import warnings
|
||
|
from collections.abc import Callable, Generator
|
||
|
from typing import TypeVar
|
||
|
|
||
|
from astroid import util
|
||
|
from astroid.context import InferenceContext
|
||
|
from astroid.exceptions import InferenceError
|
||
|
from astroid.typing import InferenceResult
|
||
|
|
||
|
if sys.version_info >= (3, 10):
|
||
|
from typing import ParamSpec
|
||
|
else:
|
||
|
from typing_extensions import ParamSpec
|
||
|
|
||
|
_R = TypeVar("_R")
|
||
|
_P = ParamSpec("_P")
|
||
|
|
||
|
|
||
|
def path_wrapper(func):
    """Return the given infer function wrapped to handle the inference path.

    Before delegating to ``func``, the wrapper pushes ``node`` onto the
    ``InferenceContext``. When ``push`` reports the node as already looked at
    for that context, nothing is yielded, which stops infinite recursion.
    Results are de-duplicated: an ``Instance`` is compared through the object
    it proxies (``_proxied``), and every other result is compared as-is.
    """

    @functools.wraps(func)
    def wrapped(
        node, context: InferenceContext | None = None, _func=func, **kwargs
    ) -> Generator:
        """Wrapper function handling context."""
        context = InferenceContext() if context is None else context
        if context.push(node):
            # The node was already looked at for this context: yield nothing.
            return

        seen = set()
        for result in _func(node, context, **kwargs):
            # Unproxy only true instances, not const, tuple, dict...
            # The class name is compared as a string (presumably to avoid
            # importing the Instance class into this module).
            key = (
                result._proxied
                if result.__class__.__name__ == "Instance"
                else result
            )
            if key in seen:
                continue
            yield result
            seen.add(key)

    return wrapped
|
||
|
|
||
|
|
||
|
def yes_if_nothing_inferred(
    func: Callable[_P, Generator[InferenceResult, None, None]]
) -> Callable[_P, Generator[InferenceResult, None, None]]:
    """Wrap an inference generator so that it never comes up empty.

    If the generator returned by ``func`` is exhausted without yielding
    anything, the wrapper yields ``util.Uninferable`` once instead. Otherwise
    every value produced by ``func`` is passed through unchanged.

    Because ``inner`` is itself a generator function, ``func`` is only called
    once the returned generator is first advanced. ``functools.wraps`` keeps
    the wrapped function's name, docstring and ``__wrapped__`` attribute.
    """

    @functools.wraps(func)
    def inner(
        *args: _P.args, **kwargs: _P.kwargs
    ) -> Generator[InferenceResult, None, None]:
        generator = func(*args, **kwargs)

        # Only the first ``next`` call decides whether the generator is empty.
        # Keep the ``yield`` outside the ``try`` so that exceptions thrown into
        # this wrapper by its consumer are not mistaken for "nothing inferred".
        try:
            first = next(generator)
        except StopIteration:
            # generator is empty
            yield util.Uninferable
            return

        yield first
        yield from generator

    return inner
|
||
|
|
||
|
|
||
|
def raise_if_nothing_inferred(
    func: Callable[_P, Generator[InferenceResult, None, None]],
) -> Callable[_P, Generator[InferenceResult, None, None]]:
    """Wrap an inference generator so that producing no result is an error.

    Every value produced by ``func`` is passed through unchanged. The errors
    below are raised lazily, when the first value is requested from the
    returned generator.

    Raises:
        InferenceError: if the generator returned by ``func`` is exhausted
            without yielding. If it finished with a return value
            (``StopIteration.args[0]``), that value is unpacked as keyword
            arguments of the error. It is presumably a dict of
            ``InferenceError`` fields, as produced by the inference functions.
        InferenceError: if computing the first value hits a ``RecursionError``.
    """

    @functools.wraps(func)
    def inner(
        *args: _P.args, **kwargs: _P.kwargs
    ) -> Generator[InferenceResult, None, None]:
        generator = func(*args, **kwargs)

        # Only the first ``next`` call is guarded. Keep the ``yield`` outside
        # the ``try`` so that exceptions thrown into this wrapper by its
        # consumer are not reported as inference failures.
        try:
            first = next(generator)
        except StopIteration as error:
            # generator is empty
            if error.args:
                raise InferenceError(**error.args[0]) from error
            raise InferenceError(
                "StopIteration raised without any error information."
            ) from error
        except RecursionError as error:
            raise InferenceError(
                f"RecursionError raised with limit {sys.getrecursionlimit()}."
            ) from error

        yield first
        yield from generator

    return inner
|
||
|
|
||
|
|
||
|
# Expensive decorators, only used to emit DeprecationWarnings.
# Unless ``util.check_warnings_filter()`` reports that DeprecationWarnings
# other than the default one are enabled, fall back to passthrough
# implementations that return the decorated function unchanged.
if util.check_warnings_filter():  # noqa: C901

    def deprecate_default_argument_values(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Decorator which emits a DeprecationWarning if any of the specified
        arguments are None or not passed at all.

        Arguments should be a key-value mapping, with the key being the argument
        to check and the value being a type annotation as string for the value
        of the argument.

        To improve performance, only used when DeprecationWarnings other than
        the default one are enabled.

        :param astroid_version: astroid version named in the warning as the one
            in which the argument becomes required.
        :raises ValueError: when the decorated function is called, if a checked
            argument is not a parameter of that function.
        """
        # Helpful links
        # Decorator for DeprecationWarning: https://stackoverflow.com/a/49802489
        # Typing of stacked decorators: https://stackoverflow.com/a/68290080

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            # Parameter positions depend only on ``func``. Resolve them once
            # here instead of running ``inspect.signature`` on every call.
            positions = {
                name: position
                for position, name in enumerate(inspect.signature(func).parameters)
            }

            @functools.wraps(func)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                """Emit DeprecationWarnings if conditions are met."""
                for arg, type_annotation in arguments.items():
                    index = positions.get(arg)
                    if index is None:
                        raise ValueError(
                            f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
                        ) from None
                    if arg in kwargs:
                        # Passed by keyword: only an explicit None is deprecated.
                        missing = kwargs[arg] is None
                    else:
                        # Not passed by keyword: deprecated if it was not passed
                        # positionally either, or was passed positionally as None.
                        missing = len(args) <= index or args[index] is None
                    if missing:
                        warnings.warn(
                            f"'{arg}' will be a required argument for "
                            f"'{args[0].__class__.__qualname__}.{func.__name__}'"
                            f" in astroid {astroid_version} "
                            f"('{arg}' should be of type: '{type_annotation}')",
                            DeprecationWarning,
                            stacklevel=2,
                        )
                return func(*args, **kwargs)

            return wrapper

        return deco

    def deprecate_arguments(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Decorator which emits a DeprecationWarning if any of the specified
        arguments are passed.

        Arguments should be a key-value mapping, with the key being the argument
        to check and the value being a string that explains what to do instead
        of passing the argument.

        To improve performance, only used when DeprecationWarnings other than
        the default one are enabled.

        :param astroid_version: astroid version named in the warning as the one
            in which the argument is removed.
        :raises ValueError: when the decorated function is called, if a checked
            argument is not a parameter of that function.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            # Resolved once at decoration time; see deprecate_default_argument_values.
            positions = {
                name: position
                for position, name in enumerate(inspect.signature(func).parameters)
            }

            @functools.wraps(func)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                """Emit DeprecationWarnings for every deprecated argument passed."""
                for arg, note in arguments.items():
                    index = positions.get(arg)
                    if index is None:
                        raise ValueError(
                            f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
                        ) from None
                    # Passed either by keyword or positionally.
                    if arg in kwargs or len(args) > index:
                        warnings.warn(
                            f"The argument '{arg}' for "
                            f"'{args[0].__class__.__qualname__}.{func.__name__}' is deprecated "
                            f"and will be removed in astroid {astroid_version} ({note})",
                            DeprecationWarning,
                            stacklevel=2,
                        )
                return func(*args, **kwargs)

            return wrapper

        return deco

else:

    def deprecate_default_argument_values(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Passthrough decorator to improve performance if DeprecationWarnings are
        disabled.

        The returned decorator hands back the decorated function unchanged.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            return func

        return deco

    def deprecate_arguments(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Passthrough decorator to improve performance if DeprecationWarnings are
        disabled.

        The returned decorator hands back the decorated function unchanged.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            return func

        return deco
|