2148 lines
81 KiB
Python
2148 lines
81 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import copy
|
||
|
import re
|
||
|
import sys
|
||
|
import warnings
|
||
|
from typing import Any, Callable, Final, List, Optional, Sequence, TypeVar, Union, cast
|
||
|
from typing_extensions import Literal, overload
|
||
|
|
||
|
from mypy import defaults, errorcodes as codes, message_registry
|
||
|
from mypy.errors import Errors
|
||
|
from mypy.message_registry import ErrorMessage
|
||
|
from mypy.nodes import (
|
||
|
ARG_NAMED,
|
||
|
ARG_NAMED_OPT,
|
||
|
ARG_OPT,
|
||
|
ARG_POS,
|
||
|
ARG_STAR,
|
||
|
ARG_STAR2,
|
||
|
ArgKind,
|
||
|
Argument,
|
||
|
AssertStmt,
|
||
|
AssignmentExpr,
|
||
|
AssignmentStmt,
|
||
|
AwaitExpr,
|
||
|
Block,
|
||
|
BreakStmt,
|
||
|
BytesExpr,
|
||
|
CallExpr,
|
||
|
ClassDef,
|
||
|
ComparisonExpr,
|
||
|
ComplexExpr,
|
||
|
ConditionalExpr,
|
||
|
ContinueStmt,
|
||
|
Decorator,
|
||
|
DelStmt,
|
||
|
DictExpr,
|
||
|
DictionaryComprehension,
|
||
|
EllipsisExpr,
|
||
|
Expression,
|
||
|
ExpressionStmt,
|
||
|
FakeInfo,
|
||
|
FloatExpr,
|
||
|
ForStmt,
|
||
|
FuncDef,
|
||
|
GeneratorExpr,
|
||
|
GlobalDecl,
|
||
|
IfStmt,
|
||
|
Import,
|
||
|
ImportAll,
|
||
|
ImportBase,
|
||
|
ImportFrom,
|
||
|
IndexExpr,
|
||
|
IntExpr,
|
||
|
LambdaExpr,
|
||
|
ListComprehension,
|
||
|
ListExpr,
|
||
|
MatchStmt,
|
||
|
MemberExpr,
|
||
|
MypyFile,
|
||
|
NameExpr,
|
||
|
Node,
|
||
|
NonlocalDecl,
|
||
|
OperatorAssignmentStmt,
|
||
|
OpExpr,
|
||
|
OverloadedFuncDef,
|
||
|
OverloadPart,
|
||
|
PassStmt,
|
||
|
RaiseStmt,
|
||
|
RefExpr,
|
||
|
ReturnStmt,
|
||
|
SetComprehension,
|
||
|
SetExpr,
|
||
|
SliceExpr,
|
||
|
StarExpr,
|
||
|
Statement,
|
||
|
StrExpr,
|
||
|
SuperExpr,
|
||
|
TempNode,
|
||
|
TryStmt,
|
||
|
TupleExpr,
|
||
|
UnaryExpr,
|
||
|
Var,
|
||
|
WhileStmt,
|
||
|
WithStmt,
|
||
|
YieldExpr,
|
||
|
YieldFromExpr,
|
||
|
check_arg_names,
|
||
|
)
|
||
|
from mypy.options import Options
|
||
|
from mypy.patterns import (
|
||
|
AsPattern,
|
||
|
ClassPattern,
|
||
|
MappingPattern,
|
||
|
OrPattern,
|
||
|
SequencePattern,
|
||
|
SingletonPattern,
|
||
|
StarredPattern,
|
||
|
ValuePattern,
|
||
|
)
|
||
|
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
|
||
|
from mypy.sharedparse import argument_elide_name, special_function_elide_names
|
||
|
from mypy.traverser import TraverserVisitor
|
||
|
from mypy.types import (
|
||
|
AnyType,
|
||
|
CallableArgument,
|
||
|
CallableType,
|
||
|
EllipsisType,
|
||
|
Instance,
|
||
|
ProperType,
|
||
|
RawExpressionType,
|
||
|
TupleType,
|
||
|
Type,
|
||
|
TypeList,
|
||
|
TypeOfAny,
|
||
|
UnboundType,
|
||
|
UnionType,
|
||
|
UnpackType,
|
||
|
)
|
||
|
from mypy.util import bytes_to_human_readable_repr, unnamed_function
|
||
|
|
||
|
# Pull the running interpreter's minor version into a Final constant to keep
# mypyc quiet about the default-argument warning (it is used as the default
# feature_version of ast3_parse).
PY_MINOR_VERSION: Final = sys.version_info.minor
|
||
|
|
||
|
import ast as ast3
|
||
|
|
||
|
# TODO: Index, ExtSlice are deprecated in 3.9.
|
||
|
from ast import AST, Attribute, Call, FunctionType, Index, Name, Starred, UnaryOp, USub
|
||
|
|
||
|
|
||
|
def ast3_parse(
    source: str | bytes, filename: str, mode: str, feature_version: int = PY_MINOR_VERSION
) -> AST:
    """Parse *source* with the standard-library ``ast`` module, keeping type comments.

    Args:
        source: program text to parse.
        filename: name used in error messages.
        mode: ``ast.parse`` mode ("exec", "eval", "func_type", ...).
        feature_version: Python 3 minor version whose grammar is accepted;
            defaults to the running interpreter's minor version.
    """
    # type_comments=True is what makes "# type:" comments show up in the tree.
    return ast3.parse(
        source,
        filename=filename,
        mode=mode,
        type_comments=True,
        feature_version=feature_version,
    )
|
||
|
|
||
|
|
||
|
# Short aliases for AST node classes used by the converters below.
NamedExpr = ast3.NamedExpr
Constant = ast3.Constant

# AST classes that only exist on newer interpreters. On older versions each
# name is bound to Any instead, so the name is always defined at module level.
if sys.version_info >= (3, 12):
    ast_TypeAlias = ast3.TypeAlias
else:
    ast_TypeAlias = Any

if sys.version_info >= (3, 10):
    # Pattern-matching node classes (match statement).
    Match = ast3.Match
    MatchValue = ast3.MatchValue
    MatchSingleton = ast3.MatchSingleton
    MatchSequence = ast3.MatchSequence
    MatchStar = ast3.MatchStar
    MatchMapping = ast3.MatchMapping
    MatchClass = ast3.MatchClass
    MatchAs = ast3.MatchAs
    MatchOr = ast3.MatchOr
    # Any AST node that carries position info (used by set_line); patterns are
    # only included where ast3.pattern exists.
    AstNode = Union[ast3.expr, ast3.stmt, ast3.pattern, ast3.ExceptHandler]
else:
    Match = Any
    MatchValue = Any
    MatchSingleton = Any
    MatchSequence = Any
    MatchStar = Any
    MatchMapping = Any
    MatchClass = Any
    MatchAs = Any
    MatchOr = Any
    AstNode = Union[ast3.expr, ast3.stmt, ast3.ExceptHandler]
# try/except* node class.
if sys.version_info >= (3, 11):
    TryStar = ast3.TryStar
else:
    TryStar = Any
|
||
|
|
||
|
# Type variable for helpers that take and return the same mypy Node subtype
# (set_line returns the node it was given).
N = TypeVar("N", bound=Node)

# There is no way to create reasonable fallbacks at this stage,
# they must be patched later.
MISSING_FALLBACK: Final = FakeInfo("fallback can't be filled out until semanal")
# Placeholder Instance built on MISSING_FALLBACK; presumably used where a
# fallback type is required before semantic analysis -- TODO confirm at use sites.
_dummy_fallback: Final = Instance(MISSING_FALLBACK, [], -1)

# Matches text whose first '#' starts a "# type: ignore" marker; group 1
# captures everything after "ignore" (an optional "[code, ...]" tag and any
# trailing text), which parse_type_ignore_tag() then interprets.
TYPE_IGNORE_PATTERN: Final = re.compile(r"[^#]*#\s*type:\s*ignore\s*(.*)")
|
||
|
|
||
|
|
||
|
def parse(
    source: str | bytes,
    fnam: str,
    module: str | None,
    errors: Errors | None = None,
    options: Options | None = None,
) -> MypyFile:
    """Parse a source file into a MypyFile, without any semantic analysis.

    Args:
        source: file contents.
        fnam: path of the file; a ".pyi" suffix marks it as a stub.
        module: module id passed to errors.set_file().
        errors: diagnostics sink. When omitted, a fresh Errors object is
            created and errors.raise_error() is called if anything was
            reported, instead of returning.
        options: mypy options; a default Options() is used when omitted.

    Returns:
        The parse tree. On a SyntaxError an empty MypyFile is returned after
        the error has been reported as a blocker.
    """
    ignore_errors = (options is not None and options.ignore_errors) or (
        errors is not None and fnam in errors.ignored_files
    )
    # When errors are ignored anyway, many function bodies can be dropped to
    # speed up type checking -- unless the caller asked to preserve ASTs.
    strip_function_bodies = ignore_errors and (options is None or not options.preserve_asts)

    if options is None:
        options = Options()
    raise_on_error = errors is None
    if errors is None:
        errors = Errors(options)
    errors.set_file(fnam, module, options=options)

    is_stub_file = fnam.endswith(".pyi")
    if is_stub_file:
        # Stubs are parsed with at least the default Python 3 minor version,
        # or the target minor version if that is newer.
        feature_version = defaults.PYTHON3_VERSION[1]
        if options.python_version[0] == 3:
            feature_version = max(feature_version, options.python_version[1])
    else:
        assert options.python_version[0] >= 3
        feature_version = options.python_version[1]

    try:
        # Silence DeprecationWarnings (e.g. about \u escapes) raised while parsing.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            ast_module = ast3_parse(source, fnam, "exec", feature_version=feature_version)

        converter = ASTConverter(
            options=options,
            is_stub=is_stub_file,
            errors=errors,
            ignore_errors=ignore_errors,
            strip_function_bodies=strip_function_bodies,
        )
        tree = converter.visit(ast_module)
        tree.path = fnam
        tree.is_stub = is_stub_file
    except SyntaxError as e:
        # alias to please mypyc
        is_py38_or_earlier = sys.version_info < (3, 9)
        if is_py38_or_earlier and e.filename == "<fstring>":
            # Up to Python 3.8, f-string syntax errors carry a line number
            # relative to the f-string, which would be misleading as a file line.
            e.lineno = None
        message = e.msg
        if feature_version > sys.version_info.minor and message.startswith("invalid syntax"):
            # The target grammar is newer than the interpreter running mypy.
            python_version_str = f"{options.python_version[0]}.{options.python_version[1]}"
            message += f"; you likely need to run mypy using Python {python_version_str} or newer"
        errors.report(
            -1 if e.lineno is None else e.lineno,
            e.offset,
            message,
            blocker=True,
            code=codes.SYNTAX,
        )
        tree = MypyFile([], [], False, {})

    if raise_on_error and errors.is_errors():
        errors.raise_error()

    assert isinstance(tree, MypyFile)
    return tree
|
||
|
|
||
|
|
||
|
def parse_type_ignore_tag(tag: str | None) -> list[str] | None:
    """Interpret the text following "# type: ignore".

    Returns:
        * [] when there is no tag (blank, None, or only a trailing comment),
          meaning every error is ignored;
        * the stripped, comma-separated codes from a "[code, ...]" tag;
        * None when the text is present but is not a valid tag.
    """
    if not tag:
        return []
    stripped = tag.strip()
    if not stripped or stripped.startswith("#"):
        # No tag -- ignore all errors.
        return []
    match = re.match(r"\s*\[([^]#]*)\]\s*(#.*)?$", tag)
    if match is None:
        # Invalid "# type: ignore" comment.
        return None
    return [part.strip() for part in match.group(1).split(",")]
|
||
|
|
||
|
|
||
|
def parse_type_comment(
    type_comment: str, line: int, column: int, errors: Errors | None
) -> tuple[list[str] | None, ProperType | None]:
    """Parse the type part of a type comment plus any trailing "# type: ignore".

    Args:
        type_comment: comment text after "# type:".
        line, column: position used for reported errors.
        errors: diagnostics sink; when None, problems are raised as
            SyntaxError instead of reported.

    Returns:
        (ignored codes or None, converted type or None). Both are None when the
        comment fails to parse and *errors* was given.
    """
    try:
        typ = ast3_parse(type_comment, "<type_comment>", "eval")
    except SyntaxError:
        if errors is None:
            raise
        stripped_type = type_comment.split("#", 2)[0].strip()
        err_msg = message_registry.TYPE_COMMENT_SYNTAX_ERROR_VALUE.format(stripped_type)
        errors.report(line, column, err_msg.value, blocker=True, code=err_msg.code)
        return None, None

    ignored: list[str] | None = None
    extra_ignore = TYPE_IGNORE_PATTERN.match(type_comment)
    if extra_ignore:
        tag: str | None = extra_ignore.group(1)
        ignored = parse_type_ignore_tag(tag)
        if ignored is None:
            # The "# type: ignore[...]" tag itself is malformed.
            if errors is None:
                raise SyntaxError
            errors.report(
                line, column, message_registry.INVALID_TYPE_IGNORE.value, code=codes.SYNTAX
            )

    assert isinstance(typ, ast3.Expression)
    converter = TypeConverter(errors, line=line, override_column=column, is_evaluated=False)
    return ignored, converter.visit(typ.body)
|
||
|
|
||
|
|
||
|
def parse_type_string(
    expr_string: str, expr_fallback_name: str, line: int, column: int
) -> ProperType:
    """Parse a type written inside a string literal.

    For example, given the type `Foo["blah"]`, this parses the string "blah".

    Returns the parsed UnboundType (tagged with the original string and
    fallback name) if it did not already carry an original string, or a
    parsed UnionType. Anything else, including unparseable text, becomes a
    RawExpressionType wrapping *expr_string*.
    """
    try:
        _, node = parse_type_comment(expr_string.strip(), line=line, column=column, errors=None)
    except (SyntaxError, ValueError):
        # Note: the parser raises ValueError instead of SyntaxError if the
        # string happens to contain things like \x00.
        return RawExpressionType(expr_string, expr_fallback_name, line, column)

    if isinstance(node, UnboundType) and node.original_str_expr is None:
        node.original_str_expr = expr_string
        node.original_str_fallback = expr_fallback_name
        return node
    if isinstance(node, UnionType):
        return node
    return RawExpressionType(expr_string, expr_fallback_name, line, column)
|
||
|
|
||
|
|
||
|
def is_no_type_check_decorator(expr: ast3.expr) -> bool:
    """Return True if decorator *expr* is ``no_type_check`` or ``typing.no_type_check``."""
    if isinstance(expr, Name):
        return expr.id == "no_type_check"
    if isinstance(expr, Attribute) and isinstance(expr.value, Name):
        return expr.value.id == "typing" and expr.attr == "no_type_check"
    return False
|
||
|
|
||
|
|
||
|
class ASTConverter:
|
||
|
def __init__(
    self,
    options: Options,
    is_stub: bool,
    errors: Errors,
    *,
    ignore_errors: bool,
    strip_function_bodies: bool,
) -> None:
    """Create a converter for one file.

    Args:
        options: mypy options in effect for the file.
        is_stub: whether the file is a stub (parse() passes fnam ending in ".pyi").
        errors: sink for diagnostics reported during conversion.
        ignore_errors: whether errors for this file are being ignored.
        strip_function_bodies: whether function bodies may be dropped
            (see translate_stmt_list).
    """
    # Configuration supplied by the caller.
    self.options = options
    self.is_stub = is_stub
    self.errors = errors
    self.ignore_errors = ignore_errors
    self.strip_function_bodies = strip_function_bodies

    # Enclosing scopes during conversion: 'C' for class, 'D' for function
    # signature, 'F' for function, 'L' for lambda.
    self.class_and_function_stack: list[Literal["C", "D", "F", "L"]] = []
    # Import nodes handed to MypyFile by visit_Module.
    self.imports: list[ImportBase] = []
    # "# type: ignore" error codes keyed by line number.
    self.type_ignores: dict[int, list[str]] = {}
    # Cache of visit_X methods keyed by type of visited object.
    self.visitor_cache: dict[type, Callable[[AST | None], Any]] = {}
|
||
|
|
||
|
def note(self, msg: str, line: int, column: int) -> None:
    """Report *msg* at (line, column) as a note carrying the syntax error code."""
    self.errors.report(line, column, msg, code=codes.SYNTAX, severity="note")
|
||
|
|
||
|
def fail(self, msg: ErrorMessage, line: int, column: int, blocker: bool = True) -> None:
    """Report error *msg* at (line, column).

    Blocking errors are always reported. Non-blocking ones are dropped when
    options.ignore_errors is set.
    """
    if not blocker and self.options.ignore_errors:
        return
    self.errors.report(line, column, msg.value, blocker=blocker, code=msg.code)
|
||
|
|
||
|
def fail_merge_overload(self, node: IfStmt) -> None:
    """Report a non-blocking FAILED_TO_MERGE_OVERLOADS error at *node*'s position."""
    self.fail(message_registry.FAILED_TO_MERGE_OVERLOADS, node.line, node.column, blocker=False)
|
||
|
|
||
|
def visit(self, node: AST | None) -> Any:
    """Dispatch *node* to the matching ``visit_<ClassName>`` method.

    Returns None for a None node. Method lookups are cached per node type in
    self.visitor_cache.
    """
    if node is None:
        return None
    node_type = type(node)
    handler = self.visitor_cache.get(node_type)
    if handler is None:
        handler = self.visitor_cache[node_type] = getattr(
            self, "visit_" + node.__class__.__name__
        )
    return handler(node)
|
||
|
|
||
|
def set_line(self, node: N, n: AstNode) -> N:
    """Copy source position from AST node *n* onto mypy node *node* and return it.

    End positions become None when *n* has no end_lineno/end_col_offset.
    """
    node.line, node.column = n.lineno, n.col_offset
    node.end_line = getattr(n, "end_lineno", None)
    node.end_column = getattr(n, "end_col_offset", None)
    return node
|
||
|
|
||
|
def translate_opt_expr_list(self, l: Sequence[AST | None]) -> list[Expression | None]:
    """Visit each entry of *l* in order; None entries are passed to visit() like any other."""
    return [self.visit(item) for item in l]
|
||
|
|
||
|
def translate_expr_list(self, l: Sequence[AST]) -> list[Expression]:
    """Translate AST expressions that are known not to be None.

    Delegates to translate_opt_expr_list; the cast only narrows the element
    type for the type checker.
    """
    translated = self.translate_opt_expr_list(l)
    return cast(List[Expression], translated)
|
||
|
|
||
|
def get_lineno(self, node: ast3.expr | ast3.stmt) -> int:
    """Return the line on which *node* begins.

    For decorated functions and classes this is the first decorator's line,
    not the line of the ``def``/``class`` keyword.
    """
    decoratable = (ast3.AsyncFunctionDef, ast3.ClassDef, ast3.FunctionDef)
    if isinstance(node, decoratable) and node.decorator_list:
        return node.decorator_list[0].lineno
    return node.lineno
|
||
|
|
||
|
def translate_stmt_list(
    self,
    stmts: Sequence[ast3.stmt],
    *,
    ismodule: bool = False,
    can_strip: bool = False,
    is_coroutine: bool = False,
) -> list[Statement]:
    """Convert a sequence of AST statements into mypy statements.

    Args:
        stmts: statements to convert, in source order.
        ismodule: True for a module's top-level body. A "# type: ignore" on a
            line before the first statement then wraps the whole module in a
            single Block marked unreachable.
        can_strip: whether this body may be dropped when
            self.strip_function_bodies is enabled.
        is_coroutine: whether the enclosing function is async. Yields then keep
            the body from being stripped.

    Returns:
        The converted statements, or [] when the body was stripped.
    """
    # A "# type: ignore" comment before the first statement of a module
    # ignores the whole module.
    if ismodule and stmts and self.type_ignores:
        first_ignore_line = min(self.type_ignores)
        if first_ignore_line < self.get_lineno(stmts[0]):
            ignored_codes = self.type_ignores[first_ignore_line]
            if ignored_codes:
                # Error codes on a module-wide ignore get a non-blocking error.
                self.fail(
                    message_registry.TYPE_IGNORE_WITH_ERRCODE_ON_MODULE.format(
                        ", ".join(ignored_codes)
                    ),
                    line=first_ignore_line,
                    column=0,
                    blocker=False,
                )
            self.errors.used_ignored_lines[self.errors.file][first_ignore_line].append(
                codes.FILE.code
            )
            block = Block(self.fix_function_overloads(self.translate_stmt_list(stmts)))
            self.set_block_lines(block, stmts)
            mark_block_unreachable(block)
            return [block]

    stack = self.class_and_function_stack
    # Fast path: the body of a non-async function whose only enclosing scope
    # entry is that function can be dropped without converting it at all.
    if (
        can_strip
        and self.strip_function_bodies
        and len(stack) == 1
        and stack[0] == "F"
        and not is_coroutine
    ):
        return []

    res: list[Statement] = [self.visit(stmt) for stmt in stmts]

    # Slow path: the body had to be converted to decide whether it can go.
    if not (can_strip and self.strip_function_bodies):
        return res

    if stack[-2:] == ["C", "F"]:
        # Method bodies that is_possible_trivial_body() accepts are kept.
        if is_possible_trivial_body(res):
            return res
        # We only strip method bodies if they don't assign to an attribute, as
        # this may define an attribute which has an externally visible effect.
        attr_visitor = FindAttributeAssign()
        for s in res:
            s.accept(attr_visitor)
            if attr_visitor.found:
                return res

    if stack[-1] == "F" and is_coroutine:
        # Yields inside an async function affect the return type and should not
        # be stripped.
        yield_visitor = FindYield()
        for s in res:
            s.accept(yield_visitor)
            if yield_visitor.found:
                return res

    return []
|
||
|
|
||
|
def translate_type_comment(
    self, n: ast3.stmt | ast3.arg, type_comment: str | None
) -> ProperType | None:
    """Parse the type comment attached to *n*, if any.

    A "# type: ignore[...]" trailing the comment is recorded in
    self.type_ignores under n's line number. Returns the parsed type, or None
    when there is no comment.
    """
    if type_comment is None:
        return None
    lineno = n.lineno
    extra_ignore, typ = parse_type_comment(type_comment, lineno, n.col_offset, self.errors)
    if extra_ignore is not None:
        self.type_ignores[lineno] = extra_ignore
    return typ
|
||
|
|
||
|
# Maps each binary-operator AST class to its source spelling. Consumed by
# from_operator(), which raises RuntimeError for classes missing here.
op_map: Final[dict[type[AST], str]] = {
    ast3.Add: "+",
    ast3.Sub: "-",
    ast3.Mult: "*",
    ast3.MatMult: "@",
    ast3.Div: "/",
    ast3.Mod: "%",
    ast3.Pow: "**",
    ast3.LShift: "<<",
    ast3.RShift: ">>",
    ast3.BitOr: "|",
    ast3.BitXor: "^",
    ast3.BitAnd: "&",
    ast3.FloorDiv: "//",
}
|
||
|
|
||
|
def from_operator(self, op: ast3.operator) -> str:
    """Return the source spelling of binary operator *op* (e.g. "+" for ast.Add).

    Raises:
        RuntimeError: if type(op) has no entry in ASTConverter.op_map.
    """
    op_name = ASTConverter.op_map.get(type(op))
    if op_name is not None:
        return op_name
    raise RuntimeError("Unknown operator " + str(type(op)))
|
||
|
|
||
|
# Maps each comparison-operator AST class to its source spelling. Consumed by
# from_comp_operator(), which raises RuntimeError for classes missing here.
comp_op_map: Final[dict[type[AST], str]] = {
    ast3.Gt: ">",
    ast3.Lt: "<",
    ast3.Eq: "==",
    ast3.GtE: ">=",
    ast3.LtE: "<=",
    ast3.NotEq: "!=",
    ast3.Is: "is",
    ast3.IsNot: "is not",
    ast3.In: "in",
    ast3.NotIn: "not in",
}
|
||
|
|
||
|
def from_comp_operator(self, op: ast3.cmpop) -> str:
    """Return the source spelling of comparison operator *op* (e.g. "is not").

    Raises:
        RuntimeError: if type(op) has no entry in ASTConverter.comp_op_map.
    """
    op_name = ASTConverter.comp_op_map.get(type(op))
    if op_name is not None:
        return op_name
    raise RuntimeError("Unknown comparison operator " + str(type(op)))
|
||
|
|
||
|
def set_block_lines(self, b: Block, stmts: Sequence[ast3.stmt]) -> None:
    """Give block *b* the source span from the first to the last of *stmts*.

    If the block's first converted statement is a Decorator or
    OverloadedFuncDef, the block's start line/column are taken from it instead.
    *stmts* must be non-empty.
    """
    head, tail = stmts[0], stmts[-1]
    b.line, b.column = head.lineno, head.col_offset
    b.end_line = getattr(tail, "end_lineno", None)
    b.end_column = getattr(tail, "end_col_offset", None)
    if b.body:
        leading = b.body[0]
        if isinstance(leading, (Decorator, OverloadedFuncDef)):
            # Decorated function lines are different between Python versions;
            # copy the normalization done for them to the block's first line.
            b.line = leading.line
            b.column = leading.column
||
|
|
||
|
def as_block(self, stmts: list[ast3.stmt]) -> Block | None:
    """Convert *stmts* into a Block with overloads merged, or return None if empty."""
    if not stmts:
        return None
    block = Block(self.fix_function_overloads(self.translate_stmt_list(stmts)))
    self.set_block_lines(block, stmts)
    return block
|
||
|
|
||
|
def as_required_block(
    self, stmts: list[ast3.stmt], *, can_strip: bool = False, is_coroutine: bool = False
) -> Block:
    """Convert a non-empty *stmts* into a Block with overloads merged.

    can_strip and is_coroutine are forwarded to translate_stmt_list.
    """
    assert stmts  # must be non-empty
    translated = self.translate_stmt_list(stmts, can_strip=can_strip, is_coroutine=is_coroutine)
    block = Block(self.fix_function_overloads(translated))
    self.set_block_lines(block, stmts)
    return block
|
||
|
|
||
|
def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
    """Group same-named function definitions in *stmts* into overloads.

    Statements are scanned in order. A group is opened by a Decorator whose
    name unnamed_function() does not flag, or by an IfStmt that holds only
    overload items of one name. Later Decorator/FuncDef statements with the
    group's name join it. So do the definitions from an IfStmt of the same name
    (see _check_ifstmt_for_overloads), when the branch that will execute can be
    determined.

    When a group ends, a one-item group is emitted as-is and a larger group
    becomes one OverloadedFuncDef. IfStmts collected while building the group
    are emitted just before it with their contents stripped, so their
    conditions can still be checked.

    Returns a new list. The IfStmts that are re-emitted are mutated in place by
    _strip_contents_from_if_stmt.
    """
    ret: list[Statement] = []
    # Items of the overload group being built, and the group's name.
    current_overload: list[OverloadPart] = []
    current_overload_name: str | None = None
    # Set once an undecorated FuncDef joins the group. After that, IfStmts are
    # not inspected for overloads until a new group starts.
    seen_unconditional_func_def = False
    # IfStmt that opened the current group, and the definition taken from its
    # executable branch. Both stay pending until a later item joins the group.
    last_if_stmt: IfStmt | None = None
    last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
    # Name of the group that was current when last_if_stmt was flushed to ret.
    last_if_stmt_overload_name: str | None = None
    # IfStmt whose condition could not be decided. It is reported via
    # fail_merge_overload if a later definition joins its group.
    last_if_unknown_truth_value: IfStmt | None = None
    # IfStmts whose overload contents were merged; emitted later, stripped.
    skipped_if_stmts: list[IfStmt] = []
    for stmt in stmts:
        if_overload_name: str | None = None
        if_block_with_overload: Block | None = None
        if_unknown_truth_value: IfStmt | None = None
        if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
            # Check IfStmt block to determine if function overloads can be merged
            if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
            if if_overload_name is not None:
                (
                    if_block_with_overload,
                    if_unknown_truth_value,
                ) = self._get_executable_if_block_with_overloads(stmt)

        if (
            current_overload_name is not None
            and isinstance(stmt, (Decorator, FuncDef))
            and stmt.name == current_overload_name
        ):
            # A definition continuing the current group: first settle any
            # pending IfStmt, then add this definition.
            if last_if_stmt is not None:
                skipped_if_stmts.append(last_if_stmt)
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if last_if_unknown_truth_value:
                self.fail_merge_overload(last_if_unknown_truth_value)
                last_if_unknown_truth_value = None
            current_overload.append(stmt)
            if isinstance(stmt, FuncDef):
                seen_unconditional_func_def = True
        elif (
            current_overload_name is not None
            and isinstance(stmt, IfStmt)
            and if_overload_name == current_overload_name
        ):
            # IfStmt only contains stmts relevant to current_overload.
            # Check if stmts are reachable and add them to current_overload,
            # otherwise skip IfStmt to allow subsequent overload
            # or function definitions.
            skipped_if_stmts.append(stmt)
            if if_block_with_overload is None:
                if if_unknown_truth_value is not None:
                    self.fail_merge_overload(if_unknown_truth_value)
                continue
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
                # Leading statements are stripped IfStmts; the last one holds
                # the already-merged overload items.
                skipped_if_stmts.extend(cast(List[IfStmt], if_block_with_overload.body[:-1]))
                current_overload.extend(if_block_with_overload.body[-1].items)
            else:
                current_overload.append(
                    cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
                )
        else:
            # Anything else ends the current group: flush it to ret.
            if last_if_stmt is not None:
                ret.append(last_if_stmt)
                last_if_stmt_overload_name = current_overload_name
                last_if_stmt, last_if_overload = None, None
                last_if_unknown_truth_value = None

            if current_overload and current_overload_name == last_if_stmt_overload_name:
                # Remove last stmt (IfStmt) from ret if the overload names matched
                # Only happens if no executable block had been found in IfStmt
                popped = ret.pop()
                assert isinstance(popped, IfStmt)
                skipped_if_stmts.append(popped)
            if current_overload and skipped_if_stmts:
                # Add bare IfStmt (without overloads) to ret
                # Required for mypy to be able to still check conditions
                for if_stmt in skipped_if_stmts:
                    self._strip_contents_from_if_stmt(if_stmt)
                    ret.append(if_stmt)
                skipped_if_stmts = []
            if len(current_overload) == 1:
                ret.append(current_overload[0])
            elif len(current_overload) > 1:
                ret.append(OverloadedFuncDef(current_overload))

            # If we have multiple decorated functions named "_" next to each, we want to treat
            # them as a series of regular FuncDefs instead of one OverloadedFuncDef because
            # most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
            # related, but multiple underscore functions next to each other aren't necessarily
            # related
            seen_unconditional_func_def = False
            if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
                # Start a new group with this decorated definition.
                current_overload = [stmt]
                current_overload_name = stmt.name
            elif isinstance(stmt, IfStmt) and if_overload_name is not None:
                # Start a new group from an IfStmt holding overloads; its
                # contents stay pending until a later item joins the group.
                current_overload = []
                current_overload_name = if_overload_name
                last_if_stmt = stmt
                last_if_stmt_overload_name = None
                if if_block_with_overload is not None:
                    skipped_if_stmts.extend(
                        cast(List[IfStmt], if_block_with_overload.body[:-1])
                    )
                    last_if_overload = cast(
                        Union[Decorator, FuncDef, OverloadedFuncDef],
                        if_block_with_overload.body[-1],
                    )
                last_if_unknown_truth_value = if_unknown_truth_value
            else:
                current_overload = []
                current_overload_name = None
                ret.append(stmt)

    # Flush whatever group is still open at the end of the statement list.
    if current_overload and skipped_if_stmts:
        # Add bare IfStmt (without overloads) to ret
        # Required for mypy to be able to still check conditions
        for if_stmt in skipped_if_stmts:
            self._strip_contents_from_if_stmt(if_stmt)
            ret.append(if_stmt)
    if len(current_overload) == 1:
        ret.append(current_overload[0])
    elif len(current_overload) > 1:
        ret.append(OverloadedFuncDef(current_overload))
    elif last_if_overload is not None:
        ret.append(last_if_overload)
    elif last_if_stmt is not None:
        ret.append(last_if_stmt)
    return ret
|
||
|
|
||
|
def _check_ifstmt_for_overloads(
    self, stmt: IfStmt, current_overload_name: str | None = None
) -> str | None:
    """Return the overload name if *stmt* holds only overload items of one name.

    The if body must be one of the following:
      * a single Decorator or OverloadedFuncDef;
      * a plain FuncDef, when an overload group is already open;
      * stripped IfStmts followed by an OverloadedFuncDef, since multiple
        overloads have already been merged into one.

    An else branch, if present, must hold exactly one statement. That statement
    is either a definition with the same name, or a nested IfStmt (elif) that
    recursively passes this check with the same name. Returns None otherwise.
    """
    if_body = stmt.body[0].body
    holds_single_item = len(if_body) == 1 and (
        isinstance(if_body[0], (Decorator, OverloadedFuncDef))
        or (current_overload_name is not None and isinstance(if_body[0], FuncDef))
    )
    holds_merged_overload = (
        len(if_body) > 1
        and isinstance(if_body[-1], OverloadedFuncDef)
        and all(self._is_stripped_if_stmt(if_stmt) for if_stmt in if_body[:-1])
    )
    if not (holds_single_item or holds_merged_overload):
        return None

    overload_name = cast(Union[Decorator, FuncDef, OverloadedFuncDef], if_body[-1]).name
    else_block = stmt.else_body
    if else_block is None:
        return overload_name
    if len(else_block.body) != 1:
        return None

    only_stmt = else_block.body[0]
    if (
        isinstance(only_stmt, (Decorator, FuncDef, OverloadedFuncDef))
        and only_stmt.name == overload_name
    ):
        return overload_name
    # For elif: else_body contains an IfStmt itself -> do a recursive check.
    if (
        isinstance(only_stmt, IfStmt)
        and self._check_ifstmt_for_overloads(only_stmt, current_overload_name) == overload_name
    ):
        return overload_name
    return None
|
||
|
|
||
|
def _get_executable_if_block_with_overloads(
    self, stmt: IfStmt
) -> tuple[Block | None, IfStmt | None]:
    """Decide which branch of an overload-holding IfStmt will run.

    Reachability inference is run on *stmt* first. Its blocks' is_unreachable
    flags are then inspected, giving one of these results:
      * (block, None) when exactly one branch is known to run. That is the if
        body, the else body, or, for an elif chain, the result of recursing
        into the nested IfStmt;
      * (None, None) when the condition is always false and there is no else;
      * (None, stmt) when the truth value cannot be determined.
    """
    infer_reachability_of_if_statement(stmt, self.options)
    # Read the blocks only after inference, which updates their flags.
    if_block = stmt.body[0]
    else_block = stmt.else_body
    if else_block is None:
        if if_block.is_unreachable is True:
            # always False condition with no else
            return None, None
        # The truth value is unknown, thus not conclusive
        return None, stmt
    if if_block.is_unreachable is False and else_block.is_unreachable is False:
        # The truth value is unknown, thus not conclusive
        return None, stmt
    if else_block.is_unreachable is True:
        # else_body will be set unreachable if condition is always True
        return if_block, None
    if if_block.is_unreachable is True:
        # body will be set unreachable if condition is always False;
        # else_body can contain an IfStmt itself (for elif) -> recurse.
        nested = else_block.body[0]
        if isinstance(nested, IfStmt):
            return self._get_executable_if_block_with_overloads(nested)
        return else_block, None
    return None, stmt
|
||
|
|
||
|
def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
    """Empty the bodies of *stmt* in place, keeping the IfStmt and its conditions.

    Needed to still be able to check the conditions after the contents have
    been merged with the surrounding function overloads. An else branch is
    emptied only when it holds exactly one statement. If that statement is an
    IfStmt (an elif), the nested IfStmt is stripped recursively instead.
    """
    if len(stmt.body) == 1:
        stmt.body[0].body = []
    else_block = stmt.else_body
    if not else_block or len(else_block.body) != 1:
        return
    nested = else_block.body[0]
    if isinstance(nested, IfStmt):
        self._strip_contents_from_if_stmt(nested)
    else:
        else_block.body = []
|
||
|
|
||
|
def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
    """Return True if *stmt* looks like an IfStmt emptied by _strip_contents_from_if_stmt.

    That means a single if-body block with no statements, plus an else branch
    that is missing or empty, or whose first statement is itself a stripped
    IfStmt (the elif case).
    """
    if not isinstance(stmt, IfStmt):
        return False
    if len(stmt.body) != 1 or len(stmt.body[0].body) != 0:
        # Body not empty
        return False
    else_block = stmt.else_body
    if not else_block or len(else_block.body) == 0:
        # No or empty else_body
        return True
    # For elif, IfStmt are stored recursively in else_body
    return self._is_stripped_if_stmt(else_block.body[0])
|
||
|
|
||
|
def translate_module_id(self, id: str) -> str:
    """Map a module name as written in source to mypy's internal module id.

    The module configured as options.custom_typing_module is treated as
    "typing"; every other id is returned unchanged.
    """
    return "typing" if id == self.options.custom_typing_module else id
|
||
|
|
||
|
def visit_Module(self, mod: ast3.Module) -> MypyFile:
    """Convert a module AST into a MypyFile.

    The module's "# type: ignore" comments are collected first, because
    translate_stmt_list consults them. Valid tags are stored in
    self.type_ignores by line. Invalid ones produce a non-blocking
    INVALID_TYPE_IGNORE error and are not stored.
    """
    self.type_ignores = {}
    for ignore in mod.type_ignores:
        codes_on_line = parse_type_ignore_tag(ignore.tag)
        if codes_on_line is None:
            self.fail(message_registry.INVALID_TYPE_IGNORE, ignore.lineno, -1, blocker=False)
            continue
        self.type_ignores[ignore.lineno] = codes_on_line
    body = self.fix_function_overloads(self.translate_stmt_list(mod.body, ismodule=True))
    return MypyFile(body, self.imports, False, self.type_ignores)
|
||
|
|
||
|
# --- stmt ---
|
||
|
# FunctionDef(identifier name, arguments args,
|
||
|
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
|
||
|
# arguments = (arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults,
|
||
|
# arg? kwarg, expr* defaults)
|
||
|
def visit_FunctionDef(self, n: ast3.FunctionDef) -> FuncDef | Decorator:
    """Convert a plain ``def``; the shared logic lives in do_func_def."""
    return self.do_func_def(n, is_coroutine=False)
|
||
|
|
||
|
# AsyncFunctionDef(identifier name, arguments args,
|
||
|
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
|
||
|
def visit_AsyncFunctionDef(self, n: ast3.AsyncFunctionDef) -> FuncDef | Decorator:
    """Convert an ``async def`` by delegating to do_func_def as a coroutine."""
    converted = self.do_func_def(n, is_coroutine=True)
    return converted
|
||
|
|
||
|
def do_func_def(
    self, n: ast3.FunctionDef | ast3.AsyncFunctionDef, is_coroutine: bool = False
) -> FuncDef | Decorator:
    """Helper shared between visit_FunctionDef and visit_AsyncFunctionDef.

    Builds a FuncDef (wrapped in a Decorator when the function has
    decorators) from a ``def`` / ``async def`` node.

    The signature comes from one of three sources, in priority order:
      * nothing, when a decorator satisfies ``is_no_type_check_decorator``
        (all argument types and the return type are left as None);
      * a function type comment (``# type: (...) -> R``), parsed in
        ``func_type`` mode; a syntax error there is reported and every type
        becomes ``AnyType(TypeOfAny.from_error)``;
      * inline annotations otherwise.

    A CallableType is attached only if at least one argument or return type
    is present and the number of types matches the number of arguments;
    mismatches are reported as (non-blocking for count mismatches) errors.
    """
    # "D" marks that we are processing a def's signature; it is checked below
    # (["C", "D"]) to decide whether an implicit self type must be inserted.
    self.class_and_function_stack.append("D")
    no_type_check = bool(
        n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list)
    )

    lineno = n.lineno
    args = self.transform_args(n.args, lineno, no_type_check=no_type_check)
    if special_function_elide_names(n.name):
        # For these functions every argument is treated as positional-only.
        for arg in args:
            arg.pos_only = True

    arg_kinds = [arg.kind for arg in args]
    # Positional-only arguments have no name visible to callers.
    arg_names = [None if arg.pos_only else arg.variable.name for arg in args]

    arg_types: list[Type | None] = []
    if no_type_check:
        arg_types = [None] * len(args)
        return_type = None
    elif n.type_comment is not None:
        try:
            func_type_ast = ast3_parse(n.type_comment, "<func_type>", "func_type")
            assert isinstance(func_type_ast, FunctionType)
            # for ellipsis arg
            # A "(...) -> R" comment keeps the inline argument annotations and
            # only supplies the return type.
            if (
                len(func_type_ast.argtypes) == 1
                and isinstance(func_type_ast.argtypes[0], Constant)
                and func_type_ast.argtypes[0].value is Ellipsis
            ):
                if n.returns:
                    # PEP 484 disallows both type annotations and type comments
                    self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, lineno, n.col_offset)
                arg_types = [
                    a.type_annotation
                    if a.type_annotation is not None
                    else AnyType(TypeOfAny.unannotated)
                    for a in args
                ]
            else:
                # PEP 484 disallows both type annotations and type comments
                if n.returns or any(a.type_annotation is not None for a in args):
                    self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, lineno, n.col_offset)
                translated_args: list[Type] = TypeConverter(
                    self.errors, line=lineno, override_column=n.col_offset
                ).translate_expr_list(func_type_ast.argtypes)
                # Use a cast to work around `list` invariance
                arg_types = cast(List[Optional[Type]], translated_args)
            return_type = TypeConverter(self.errors, line=lineno).visit(func_type_ast.returns)

            # add implicit self type
            # A method directly inside a class may omit `self` from its type
            # comment; pad the front with a special-form Any in that case.
            in_method_scope = self.class_and_function_stack[-2:] == ["C", "D"]
            if in_method_scope and len(arg_types) < len(args):
                arg_types.insert(0, AnyType(TypeOfAny.special_form))
        except SyntaxError:
            # Report only the part of the comment before any trailing "#".
            stripped_type = n.type_comment.split("#", 2)[0].strip()
            err_msg = message_registry.TYPE_COMMENT_SYNTAX_ERROR_VALUE.format(stripped_type)
            self.fail(err_msg, lineno, n.col_offset)
            if n.type_comment and n.type_comment[0] not in ["(", "#"]:
                self.note(
                    "Suggestion: wrap argument types in parentheses", lineno, n.col_offset
                )
            arg_types = [AnyType(TypeOfAny.from_error)] * len(args)
            return_type = AnyType(TypeOfAny.from_error)
    else:
        if sys.version_info >= (3, 12) and n.type_params:
            # PEP 695 type parameter lists are parsed but not translated here.
            self.fail(
                ErrorMessage("PEP 695 generics are not yet supported", code=codes.VALID_TYPE),
                n.type_params[0].lineno,
                n.type_params[0].col_offset,
                blocker=False,
            )

        arg_types = [a.type_annotation for a in args]
        return_type = TypeConverter(
            self.errors, line=n.returns.lineno if n.returns else lineno
        ).visit(n.returns)

    # Apply the implicit-Optional rule for arguments whose default is None.
    for arg, arg_type in zip(args, arg_types):
        self.set_type_optional(arg_type, arg.initializer)

    func_type = None
    if any(arg_types) or return_type:
        if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types):
            self.fail(message_registry.ELLIPSIS_WITH_OTHER_TYPEARGS, lineno, n.col_offset)
        elif len(arg_types) > len(arg_kinds):
            self.fail(
                message_registry.TYPE_SIGNATURE_TOO_MANY_ARGS,
                lineno,
                n.col_offset,
                blocker=False,
            )
        elif len(arg_types) < len(arg_kinds):
            self.fail(
                message_registry.TYPE_SIGNATURE_TOO_FEW_ARGS,
                lineno,
                n.col_offset,
                blocker=False,
            )
        else:
            # Missing argument/return types become "unannotated" Any.
            func_type = CallableType(
                [a if a is not None else AnyType(TypeOfAny.unannotated) for a in arg_types],
                arg_kinds,
                arg_names,
                return_type if return_type is not None else AnyType(TypeOfAny.unannotated),
                _dummy_fallback,
            )

    # End position is always the same.
    end_line = getattr(n, "end_lineno", None)
    end_column = getattr(n, "end_col_offset", None)

    # Signature done: switch the scope marker from "D" to "F" for the body.
    self.class_and_function_stack.pop()
    self.class_and_function_stack.append("F")
    body = self.as_required_block(n.body, can_strip=True, is_coroutine=is_coroutine)
    func_def = FuncDef(n.name, args, body, func_type)
    if isinstance(func_def.type, CallableType):
        # semanal.py does some in-place modifications we want to avoid
        func_def.unanalyzed_type = func_def.type.copy_modified()
    if is_coroutine:
        func_def.is_coroutine = True
    if func_type is not None:
        func_type.definition = func_def
        func_type.line = lineno

    if n.decorator_list:
        # Set deco_line to the old pre-3.8 lineno, in order to keep
        # existing "# type: ignore" comments working:
        deco_line = n.decorator_list[0].lineno

        var = Var(func_def.name)
        var.is_ready = False
        var.set_line(lineno)

        func_def.is_decorated = True
        func_def.deco_line = deco_line
        func_def.set_line(lineno, n.col_offset, end_line, end_column)

        deco = Decorator(func_def, self.translate_expr_list(n.decorator_list), var)
        first = n.decorator_list[0]
        # The Decorator node spans from the first decorator to the end of the def.
        deco.set_line(first.lineno, first.col_offset, end_line, end_column)
        retval: FuncDef | Decorator = deco
    else:
        # FuncDef overrides set_line -- can't use self.set_line
        func_def.set_line(lineno, n.col_offset, end_line, end_column)
        retval = func_def
    if self.options.include_docstrings:
        func_def.docstring = ast3.get_docstring(n, clean=False)
    # Leave the function scope ("F") pushed above.
    self.class_and_function_stack.pop()
    return retval
|
||
|
|
||
|
def set_type_optional(self, type: Type | None, initializer: Expression | None) -> None:
    """Flag an unbound argument type as optional when its default is ``None``.

    Only active when ``options.implicit_optional`` is enabled. The flag is
    written only to UnboundType instances; it is set to True for a literal
    ``None`` default and to False otherwise.
    """
    if self.options.implicit_optional and isinstance(type, UnboundType):
        defaults_to_none = isinstance(initializer, NameExpr) and initializer.name == "None"
        type.optional = defaults_to_none
|
||
|
|
||
|
def transform_args(
    self, args: ast3.arguments, line: int, no_type_check: bool = False
) -> list[Argument]:
    """Convert an ``ast.arguments`` node into a list of mypy Arguments.

    The result follows declaration order: positional parameters (the
    positional-only ones first), ``*args``, keyword-only parameters, then
    ``**kwargs``. Duplicate parameter names are reported through
    ``check_arg_names``. ``line`` is part of the signature but is not read
    here.
    """
    converted: list[Argument] = []
    sources: list[ast3.arg] = []

    posonly = getattr(args, "posonlyargs", cast(List[ast3.arg], []))
    positional = posonly + args.args
    defaults = args.defaults
    # Defaults belong to the trailing len(defaults) positional parameters.
    required_count = len(positional) - len(defaults)
    posonly_count = len(posonly)

    for index, node in enumerate(positional[:required_count]):
        converted.append(
            self.make_argument(node, None, ARG_POS, no_type_check, index < posonly_count)
        )
        sources.append(node)

    for offset, (node, default) in enumerate(zip(positional[required_count:], defaults)):
        is_posonly = required_count + offset < posonly_count
        converted.append(
            self.make_argument(node, default, ARG_OPT, no_type_check, is_posonly)
        )
        sources.append(node)

    star = args.vararg
    if star is not None:
        converted.append(self.make_argument(star, None, ARG_STAR, no_type_check))
        sources.append(star)

    # kw_defaults holds None for keyword-only parameters without a default.
    for node, kw_default in zip(args.kwonlyargs, args.kw_defaults):
        kind = ARG_NAMED_OPT if kw_default is not None else ARG_NAMED
        converted.append(self.make_argument(node, kw_default, kind, no_type_check))
        sources.append(node)

    double_star = args.kwarg
    if double_star is not None:
        converted.append(self.make_argument(double_star, None, ARG_STAR2, no_type_check))
        sources.append(double_star)

    check_arg_names([a.variable.name for a in converted], sources, self.fail_arg)

    return converted
|
||
|
|
||
|
def make_argument(
    self,
    arg: ast3.arg,
    default: ast3.expr | None,
    kind: ArgKind,
    no_type_check: bool,
    pos_only: bool = False,
) -> Argument:
    """Build one mypy Argument from an ``ast.arg`` node.

    The type comes from the inline annotation when present, otherwise from
    the per-argument type comment; having both is reported as an error (the
    annotation wins). With ``no_type_check`` the type is left as None.
    Names for which ``argument_elide_name`` is true are forced positional-only.
    """
    arg_type = None
    if not no_type_check:
        annotation, type_comment = arg.annotation, arg.type_comment
        if annotation is not None and type_comment is not None:
            self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, arg.lineno, arg.col_offset)
        if annotation is None:
            arg_type = self.translate_type_comment(arg, type_comment)
        else:
            arg_type = TypeConverter(self.errors, line=arg.lineno).visit(annotation)

    if argument_elide_name(arg.arg):
        pos_only = True

    result = Argument(Var(arg.arg), arg_type, self.visit(default), kind, pos_only)
    end_line = getattr(arg, "end_lineno", None)
    end_column = getattr(arg, "end_col_offset", None)
    result.set_line(arg.lineno, arg.col_offset, end_line, end_column)
    return result
|
||
|
|
||
|
def fail_arg(self, msg: str, arg: ast3.arg) -> None:
    """Report ``msg`` at the position of the argument node ``arg``."""
    error = ErrorMessage(msg)
    self.fail(error, arg.lineno, arg.col_offset)
|
||
|
|
||
|
# ClassDef(identifier name,
#  expr* bases,
#  keyword* keywords,
#  stmt* body,
#  expr* decorator_list)
def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
    """Convert a ``class`` statement into a ClassDef.

    Named keywords in the class header are kept (``**kw`` entries, which have
    no name, are dropped); a ``metaclass=`` keyword is also stored as the
    metaclass. PEP 695 type parameter lists are reported as unsupported.
    """
    self.class_and_function_stack.append("C")
    keywords = [(kw.arg, self.visit(kw.value)) for kw in n.keywords if kw.arg]

    if sys.version_info >= (3, 12) and n.type_params:
        first_param = n.type_params[0]
        self.fail(
            ErrorMessage("PEP 695 generics are not yet supported", code=codes.VALID_TYPE),
            first_param.lineno,
            first_param.col_offset,
            blocker=False,
        )

    body = self.as_required_block(n.body)
    bases = self.translate_expr_list(n.bases)
    cdef = ClassDef(
        n.name, body, None, bases, metaclass=dict(keywords).get("metaclass"), keywords=keywords
    )
    cdef.decorators = self.translate_expr_list(n.decorator_list)
    # Set lines to match the old mypy 0.700 lines, in order to keep
    # existing "# type: ignore" comments working:
    cdef.line = n.lineno
    cdef.deco_line = n.decorator_list[0].lineno if n.decorator_list else None

    if self.options.include_docstrings:
        cdef.docstring = ast3.get_docstring(n, clean=False)
    cdef.column = n.col_offset
    cdef.end_line = getattr(n, "end_lineno", None)
    cdef.end_column = getattr(n, "end_col_offset", None)
    self.class_and_function_stack.pop()
    return cdef
|
||
|
|
||
|
# Return(expr? value)
def visit_Return(self, n: ast3.Return) -> ReturnStmt:
    """Convert ``return [value]``."""
    return self.set_line(ReturnStmt(self.visit(n.value)), n)
|
||
|
|
||
|
# Delete(expr* targets)
def visit_Delete(self, n: ast3.Delete) -> DelStmt:
    """Convert ``del``; several targets are bundled into one TupleExpr."""
    if len(n.targets) <= 1:
        return self.set_line(DelStmt(self.visit(n.targets[0])), n)
    bundle = TupleExpr(self.translate_expr_list(n.targets))
    bundle.set_line(n.lineno)
    return self.set_line(DelStmt(bundle), n)
|
||
|
|
||
|
# Assign(expr* targets, expr? value, string? type_comment, expr? annotation)
def visit_Assign(self, n: ast3.Assign) -> AssignmentStmt:
    """Convert ``a = b = value  # type: T`` (old-style, comment-typed) assignment."""
    targets = self.translate_expr_list(n.targets)
    value = self.visit(n.value)
    declared = self.translate_type_comment(n, n.type_comment)
    stmt = AssignmentStmt(targets, value, type=declared, new_syntax=False)
    return self.set_line(stmt, n)
|
||
|
|
||
|
# AnnAssign(expr target, expr annotation, expr? value, int simple)
def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt:
    """Convert an annotated assignment ``x: T [= value]``.

    A declaration without a value gets a TempNode placeholder (``no_rhs``)
    positioned at the statement.
    """
    line = n.lineno
    if n.value is not None:
        rvalue: Expression = self.visit(n.value)
    else:  # always allow 'x: int'
        rvalue = TempNode(AnyType(TypeOfAny.special_form), no_rhs=True)
        rvalue.line = line
        rvalue.column = n.col_offset
    annotated = TypeConverter(self.errors, line=line).visit(n.annotation)
    assert annotated is not None
    annotated.column = n.annotation.col_offset
    target = self.visit(n.target)
    return self.set_line(AssignmentStmt([target], rvalue, type=annotated, new_syntax=True), n)
|
||
|
|
||
|
# AugAssign(expr target, operator op, expr value)
def visit_AugAssign(self, n: ast3.AugAssign) -> OperatorAssignmentStmt:
    """Convert an augmented assignment such as ``x += value``."""
    operator = self.from_operator(n.op)
    target = self.visit(n.target)
    value = self.visit(n.value)
    return self.set_line(OperatorAssignmentStmt(operator, target, value), n)
|
||
|
|
||
|
# For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
def visit_For(self, n: ast3.For) -> ForStmt:
    """Convert a ``for`` loop, including its ``else`` block and type comment."""
    comment_type = self.translate_type_comment(n, n.type_comment)
    target = self.visit(n.target)
    iterable = self.visit(n.iter)
    body = self.as_required_block(n.body)
    else_body = self.as_block(n.orelse)
    return self.set_line(ForStmt(target, iterable, body, else_body, comment_type), n)
|
||
|
|
||
|
# AsyncFor(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
def visit_AsyncFor(self, n: ast3.AsyncFor) -> ForStmt:
    """Convert ``async for``: identical to ``for`` (same fields) plus ``is_async``."""
    stmt = self.visit_For(cast(ast3.For, n))
    stmt.is_async = True
    return stmt
|
||
|
|
||
|
# While(expr test, stmt* body, stmt* orelse)
def visit_While(self, n: ast3.While) -> WhileStmt:
    """Convert a ``while`` loop with its optional ``else`` block."""
    condition = self.visit(n.test)
    body = self.as_required_block(n.body)
    else_body = self.as_block(n.orelse)
    return self.set_line(WhileStmt(condition, body, else_body), n)
|
||
|
|
||
|
# If(expr test, stmt* body, stmt* orelse)
def visit_If(self, n: ast3.If) -> IfStmt:
    """Convert an ``if``; an ``elif`` arrives as a nested If in ``orelse``."""
    condition = self.visit(n.test)
    then_block = self.as_required_block(n.body)
    else_block = self.as_block(n.orelse)
    return self.set_line(IfStmt([condition], [then_block], else_block), n)
|
||
|
|
||
|
# With(withitem* items, stmt* body, string? type_comment)
def visit_With(self, n: ast3.With) -> WithStmt:
    """Convert a ``with`` statement.

    Context expressions and their ``as`` targets become two parallel lists;
    every item contributes to both (its target is ``self.visit`` of the
    possibly-None ``optional_vars``).
    """
    comment_type = self.translate_type_comment(n, n.type_comment)
    managers = [self.visit(item.context_expr) for item in n.items]
    targets = [self.visit(item.optional_vars) for item in n.items]
    body = self.as_required_block(n.body)
    return self.set_line(WithStmt(managers, targets, body, comment_type), n)
|
||
|
|
||
|
# AsyncWith(withitem* items, stmt* body, string? type_comment)
def visit_AsyncWith(self, n: ast3.AsyncWith) -> WithStmt:
    """Convert ``async with``: identical to ``with`` (same fields) plus ``is_async``."""
    stmt = self.visit_With(cast(ast3.With, n))
    stmt.is_async = True
    return stmt
|
||
|
|
||
|
# Raise(expr? exc, expr? cause)
def visit_Raise(self, n: ast3.Raise) -> RaiseStmt:
    """Convert ``raise [exc [from cause]]``."""
    exc, cause = self.visit(n.exc), self.visit(n.cause)
    return self.set_line(RaiseStmt(exc, cause), n)
|
||
|
|
||
|
# Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody)
def visit_Try(self, n: ast3.Try) -> TryStmt:
    """Convert ``try/except/else/finally``.

    Each handler contributes to three parallel lists: the bound name (a
    positioned NameExpr, or None for a bare ``except X:``), the exception
    type expression, and the handler body.
    """
    names = [
        None if h.name is None else self.set_line(NameExpr(h.name), h) for h in n.handlers
    ]
    exc_types = [self.visit(h.type) for h in n.handlers]
    handler_bodies = [self.as_required_block(h.body) for h in n.handlers]

    stmt = TryStmt(
        self.as_required_block(n.body),
        names,
        exc_types,
        handler_bodies,
        self.as_block(n.orelse),
        self.as_block(n.finalbody),
    )
    return self.set_line(stmt, n)
|
||
|
|
||
|
def visit_TryStar(self, n: TryStar) -> TryStmt:
    """Convert ``try/except*``: same fields as ``try``, flagged with ``is_star``."""
    stmt = self.visit_Try(cast(ast3.Try, n))
    stmt.is_star = True
    return stmt
|
||
|
|
||
|
# Assert(expr test, expr? msg)
def visit_Assert(self, n: ast3.Assert) -> AssertStmt:
    """Convert ``assert test[, msg]``."""
    condition, message = self.visit(n.test), self.visit(n.msg)
    return self.set_line(AssertStmt(condition, message), n)
|
||
|
|
||
|
# Import(alias* names)
def visit_Import(self, n: ast3.Import) -> Import:
    """Convert ``import a, b as c`` and record it in ``self.imports``.

    When ``translate_module_id`` renames a module that was imported without
    ``as``, the original spelling becomes an implicit ``as`` name so the
    binding in the source still resolves.
    """
    entries: list[tuple[str, str | None]] = []
    for alias in n.names:
        module_id = self.translate_module_id(alias.name)
        bound_as = alias.asname
        renamed = module_id != alias.name
        if renamed and bound_as is None:
            bound_as = alias.name
        entries.append((module_id, bound_as))
    node = Import(entries)
    self.imports.append(node)
    return self.set_line(node, n)
|
||
|
|
||
|
# ImportFrom(identifier? module, alias* names, int? level)
def visit_ImportFrom(self, n: ast3.ImportFrom) -> ImportBase:
    """Convert ``from m import ...`` and record it in ``self.imports``.

    ``from m import *`` becomes ImportAll; everything else becomes
    ImportFrom with its (name, asname) pairs. A missing module (pure
    relative import such as ``from . import x``) is stored as "".
    """
    assert n.level is not None
    if len(n.names) == 1 and n.names[0].name == "*":
        # NOTE(review): unlike the ImportFrom branch, the module id is not
        # passed through translate_module_id here — confirm this is intended.
        node: ImportBase = ImportAll(n.module if n.module is not None else "", n.level)
    else:
        module = "" if n.module is None else self.translate_module_id(n.module)
        pairs = [(a.name, a.asname) for a in n.names]
        node = ImportFrom(module, n.level, pairs)
    self.imports.append(node)
    return self.set_line(node, n)
|
||
|
|
||
|
# Global(identifier* names)
def visit_Global(self, n: ast3.Global) -> GlobalDecl:
    """Convert a ``global`` declaration."""
    return self.set_line(GlobalDecl(n.names), n)
|
||
|
|
||
|
# Nonlocal(identifier* names)
def visit_Nonlocal(self, n: ast3.Nonlocal) -> NonlocalDecl:
    """Convert a ``nonlocal`` declaration."""
    return self.set_line(NonlocalDecl(n.names), n)
|
||
|
|
||
|
# Expr(expr value)
def visit_Expr(self, n: ast3.Expr) -> ExpressionStmt:
    """Convert an expression used as a statement."""
    return self.set_line(ExpressionStmt(self.visit(n.value)), n)
|
||
|
|
||
|
# Pass
def visit_Pass(self, n: ast3.Pass) -> PassStmt:
    """Convert ``pass``."""
    return self.set_line(PassStmt(), n)
|
||
|
|
||
|
# Break
def visit_Break(self, n: ast3.Break) -> BreakStmt:
    """Convert ``break``."""
    return self.set_line(BreakStmt(), n)
|
||
|
|
||
|
# Continue
def visit_Continue(self, n: ast3.Continue) -> ContinueStmt:
    """Convert ``continue``."""
    return self.set_line(ContinueStmt(), n)
|
||
|
|
||
|
# --- expr ---

def visit_NamedExpr(self, n: NamedExpr) -> AssignmentExpr:
    """Convert an assignment expression ``target := value``."""
    target = self.visit(n.target)
    value = self.visit(n.value)
    return self.set_line(AssignmentExpr(target, value), n)
|
||
|
|
||
|
# BoolOp(boolop op, expr* values)
def visit_BoolOp(self, n: ast3.BoolOp) -> OpExpr:
    """Convert ``a and b and ...`` / ``a or b or ...`` into nested OpExprs.

    mypy translates ``1 and 2 and 3`` as ``1 and (2 and 3)``; the nesting is
    built by ``group``.

    Raises:
        RuntimeError: if the operator is neither ``and`` nor ``or``. The
            message names the offending operator class.
    """
    # group() needs at least two operands.
    assert len(n.values) >= 2
    op_node = n.op
    if isinstance(op_node, ast3.And):
        op = "and"
    elif isinstance(op_node, ast3.Or):
        op = "or"
    else:
        # Report the operator's type; type(n) would always say "BoolOp".
        raise RuntimeError("unknown BoolOp " + str(type(op_node)))

    return self.group(op, self.translate_expr_list(n.values), n)
|
||
|
|
||
|
def group(self, op: str, vals: list[Expression], n: ast3.expr) -> OpExpr:
    """Fold ``vals`` into a right-nested chain of ``op`` expressions.

    ``[a, b, c]`` becomes ``OpExpr(op, a, OpExpr(op, b, c))``. Every node in
    the chain gets the position of ``n``.

    The chain is built iteratively from the innermost pair outwards. It
    produces the same nodes, in the same construction order, as the former
    recursive version. It does not re-slice the list at every level (which
    was quadratic) and does not use one stack frame per operand, so very
    long ``and``/``or`` chains cannot exhaust the recursion limit.

    Raises:
        IndexError: if fewer than two operands are given.
    """
    result = self.set_line(OpExpr(op, vals[-2], vals[-1]), n)
    for left in reversed(vals[:-2]):
        result = self.set_line(OpExpr(op, left, result), n)
    return result
|
||
|
|
||
|
# BinOp(expr left, operator op, expr right)
def visit_BinOp(self, n: ast3.BinOp) -> OpExpr:
    """Convert a binary operation; unknown operators raise RuntimeError."""
    symbol = self.from_operator(n.op)
    if symbol is None:
        raise RuntimeError("cannot translate BinOp " + str(type(n.op)))
    left = self.visit(n.left)
    right = self.visit(n.right)
    return self.set_line(OpExpr(symbol, left, right), n)
|
||
|
|
||
|
# UnaryOp(unaryop op, expr operand)
def visit_UnaryOp(self, n: ast3.UnaryOp) -> UnaryExpr:
    """Convert ``~x``, ``not x``, ``+x`` or ``-x``; anything else raises RuntimeError."""
    op = None
    for ast_op, symbol in (
        (ast3.Invert, "~"),
        (ast3.Not, "not"),
        (ast3.UAdd, "+"),
        (ast3.USub, "-"),
    ):
        if isinstance(n.op, ast_op):
            op = symbol
            break

    if op is None:
        raise RuntimeError("cannot translate UnaryOp " + str(type(n.op)))

    return self.set_line(UnaryExpr(op, self.visit(n.operand)), n)
|
||
|
|
||
|
# Lambda(arguments args, expr body)
def visit_Lambda(self, n: ast3.Lambda) -> LambdaExpr:
    """Convert ``lambda args: expr`` into a LambdaExpr.

    The body expression is wrapped in a synthetic ``return`` statement,
    positioned at the expression, so it can be translated as a block.
    """
    ret = ast3.Return(n.body)
    ret.lineno = n.body.lineno
    ret.col_offset = n.body.col_offset

    self.class_and_function_stack.append("L")
    params = self.transform_args(n.args, n.lineno)
    block = self.as_required_block([ret])
    lam = LambdaExpr(params, block)
    self.class_and_function_stack.pop()
    lam.set_line(n.lineno, n.col_offset)  # Overrides set_line -- can't use self.set_line
    return lam
|
||
|
|
||
|
# IfExp(expr test, expr body, expr orelse)
def visit_IfExp(self, n: ast3.IfExp) -> ConditionalExpr:
    """Convert ``body if test else orelse``."""
    condition = self.visit(n.test)
    if_true = self.visit(n.body)
    if_false = self.visit(n.orelse)
    return self.set_line(ConditionalExpr(condition, if_true, if_false), n)
|
||
|
|
||
|
# Dict(expr* keys, expr* values)
def visit_Dict(self, n: ast3.Dict) -> DictExpr:
    """Convert a dict display into a DictExpr of (key, value) pairs.

    Keys go through translate_opt_expr_list because a key may be None (per
    the ast docs, that marks a ``**mapping`` entry).
    """
    keys = self.translate_opt_expr_list(n.keys)
    values = self.translate_expr_list(n.values)
    return self.set_line(DictExpr(list(zip(keys, values))), n)
|
||
|
|
||
|
# Set(expr* elts)
def visit_Set(self, n: ast3.Set) -> SetExpr:
    """Convert a set display ``{a, b, ...}``."""
    items = self.translate_expr_list(n.elts)
    return self.set_line(SetExpr(items), n)
|
||
|
|
||
|
# ListComp(expr elt, comprehension* generators)
def visit_ListComp(self, n: ast3.ListComp) -> ListComprehension:
    """Convert a list comprehension by reusing the generator-expression converter."""
    generator = self.visit_GeneratorExp(cast(ast3.GeneratorExp, n))
    return self.set_line(ListComprehension(generator), n)
|
||
|
|
||
|
# SetComp(expr elt, comprehension* generators)
def visit_SetComp(self, n: ast3.SetComp) -> SetComprehension:
    """Convert a set comprehension by reusing the generator-expression converter."""
    generator = self.visit_GeneratorExp(cast(ast3.GeneratorExp, n))
    return self.set_line(SetComprehension(generator), n)
|
||
|
|
||
|
# DictComp(expr key, expr value, comprehension* generators)
def visit_DictComp(self, n: ast3.DictComp) -> DictionaryComprehension:
    """Convert ``{k: v for ... in ... if ...}``.

    Each ``for`` clause contributes one entry to the parallel target, iterable,
    condition-list and is-async lists.
    """
    clauses = n.generators
    targets = [self.visit(clause.target) for clause in clauses]
    iterables = [self.visit(clause.iter) for clause in clauses]
    conditions = [self.translate_expr_list(clause.ifs) for clause in clauses]
    async_flags = [bool(clause.is_async) for clause in clauses]
    key = self.visit(n.key)
    value = self.visit(n.value)
    comp = DictionaryComprehension(key, value, targets, iterables, conditions, async_flags)
    return self.set_line(comp, n)
|
||
|
|
||
|
# GeneratorExp(expr elt, comprehension* generators)
def visit_GeneratorExp(self, n: ast3.GeneratorExp) -> GeneratorExpr:
    """Convert ``(elt for ... in ... if ...)``.

    Each ``for`` clause contributes one entry to the parallel target, iterable,
    condition-list and is-async lists.
    """
    clauses = n.generators
    targets = [self.visit(clause.target) for clause in clauses]
    iterables = [self.visit(clause.iter) for clause in clauses]
    conditions = [self.translate_expr_list(clause.ifs) for clause in clauses]
    async_flags = [bool(clause.is_async) for clause in clauses]
    element = self.visit(n.elt)
    return self.set_line(GeneratorExpr(element, targets, iterables, conditions, async_flags), n)
|
||
|
|
||
|
# Await(expr value)
def visit_Await(self, n: ast3.Await) -> AwaitExpr:
    """Convert ``await value``."""
    return self.set_line(AwaitExpr(self.visit(n.value)), n)
|
||
|
|
||
|
# Yield(expr? value)
def visit_Yield(self, n: ast3.Yield) -> YieldExpr:
    """Convert ``yield [value]``."""
    value = self.visit(n.value)
    return self.set_line(YieldExpr(value), n)
|
||
|
|
||
|
# YieldFrom(expr value)
def visit_YieldFrom(self, n: ast3.YieldFrom) -> YieldFromExpr:
    """Convert ``yield from value``."""
    value = self.visit(n.value)
    return self.set_line(YieldFromExpr(value), n)
|
||
|
|
||
|
# Compare(expr left, cmpop* ops, expr* comparators)
def visit_Compare(self, n: ast3.Compare) -> ComparisonExpr:
    """Convert a (possibly chained) comparison such as ``a < b <= c``."""
    ops = [self.from_comp_operator(o) for o in n.ops]
    exprs = self.translate_expr_list([n.left, *n.comparators])
    return self.set_line(ComparisonExpr(ops, exprs), n)
|
||
|
|
||
|
# Call(expr func, expr* args, keyword* keywords)
# keyword = (identifier? arg, expr value)
def visit_Call(self, n: Call) -> CallExpr:
    """Convert a call ``func(positional..., name=value..., **mapping)``.

    Positional arguments come first, with ``*x`` unwrapped and marked
    ARG_STAR. They are followed by keyword arguments; a keyword without a
    name (``**x``) is marked ARG_STAR2. Positional slots have no name.
    """
    args = n.args
    keywords = n.keywords
    keyword_names = [k.arg for k in keywords]
    # Starred is detected the same way (isinstance) for both the expression
    # and the kind, so the two lists can never disagree.
    arg_exprs = self.translate_expr_list(
        [a.value if isinstance(a, Starred) else a for a in args] + [k.value for k in keywords]
    )
    arg_kinds = [ARG_STAR if isinstance(a, Starred) else ARG_POS for a in args] + [
        ARG_STAR2 if name is None else ARG_NAMED for name in keyword_names
    ]
    e = CallExpr(
        self.visit(n.func),
        arg_exprs,
        arg_kinds,
        cast("List[Optional[str]]", [None] * len(args)) + keyword_names,
    )
    return self.set_line(e, n)
|
||
|
|
||
|
# Constant(object value) -- a constant, in Python 3.8.
def visit_Constant(self, n: Constant) -> Any:
    """Convert a literal constant into the matching mypy expression node.

    None, True and False become NameExprs; strings, bytes and numbers become
    their literal expression types; ``...`` becomes EllipsisExpr. Any other
    constant type raises RuntimeError.
    """
    val = n.value
    if val is None or isinstance(val, bool):
        # Checked before int because bool is an int subclass; str(None) == "None".
        e: Any = NameExpr(str(val))
    elif isinstance(val, str):
        e = StrExpr(val)
    elif isinstance(val, bytes):
        e = BytesExpr(bytes_to_human_readable_repr(val))
    elif isinstance(val, int):
        e = IntExpr(val)
    elif isinstance(val, float):
        e = FloatExpr(val)
    elif isinstance(val, complex):
        e = ComplexExpr(val)
    elif val is Ellipsis:
        e = EllipsisExpr()
    else:
        raise RuntimeError("Constant not implemented for " + str(type(val)))
    return self.set_line(e, n)
|
||
|
|
||
|
# JoinedStr(expr* values)
def visit_JoinedStr(self, n: ast3.JoinedStr) -> Expression:
    """Convert an f-string into ``''.join([...])`` over its translated parts.

    A single part is returned directly, without a join call. When there are
    several parts and the last is an empty string literal, that literal is
    dropped.
    """
    empty = StrExpr("")
    empty.set_line(n.lineno, n.col_offset)
    pieces = ListExpr(self.translate_expr_list(n.values))
    pieces.set_line(empty)

    if len(pieces.items) == 1:
        return self.set_line(pieces.items[0], n)

    if len(pieces.items) > 1:
        tail = pieces.items[-1]
        if isinstance(tail, StrExpr) and tail.value == "":
            # 3.12 can add an empty literal at the end. Delete it for consistency
            # between Python versions.
            del pieces.items[-1]

    joiner = MemberExpr(empty, "join")
    joiner.set_line(empty)
    call = CallExpr(joiner, [pieces], [ARG_POS], [None])
    return self.set_line(call, n)
|
||
|
|
||
|
# FormattedValue(expr value)
def visit_FormattedValue(self, n: ast3.FormattedValue) -> Expression:
    """Convert one ``{value!c:spec}`` f-string field into a ``.format`` call.

    The field becomes ``'{!c:{}}'.format(value, spec)``: the conversion
    (if any) is baked into the template and the format spec is passed as the
    second argument (an empty string when absent), so mypyc can see both.
    """
    value_expr = self.visit(n.value)
    value_expr.set_line(n.lineno, n.col_offset)
    conversion = "!" + chr(n.conversion) if n.conversion >= 0 else ""
    template = StrExpr("{" + conversion + ":{}}")
    spec_expr = StrExpr("") if n.format_spec is None else self.visit(n.format_spec)
    template.set_line(n.lineno, n.col_offset)
    formatter = MemberExpr(template, "format")
    formatter.set_line(template)
    call = CallExpr(formatter, [value_expr, spec_expr], [ARG_POS, ARG_POS], [None, None])
    return self.set_line(call, n)
|
||
|
|
||
|
# Attribute(expr value, identifier attr, expr_context ctx)
def visit_Attribute(self, n: Attribute) -> MemberExpr | SuperExpr:
    """Convert ``obj.attr``; ``super(...).attr`` becomes a SuperExpr."""
    member = MemberExpr(self.visit(n.value), n.attr)
    receiver = member.expr
    is_super_call = (
        isinstance(receiver, CallExpr)
        and isinstance(receiver.callee, NameExpr)
        and receiver.callee.name == "super"
    )
    result: MemberExpr | SuperExpr = (
        SuperExpr(member.name, receiver) if is_super_call else member
    )
    return self.set_line(result, n)
|
||
|
|
||
|
# Subscript(expr value, slice slice, expr_context ctx)
def visit_Subscript(self, n: ast3.Subscript) -> IndexExpr:
    """Convert ``value[index]`` into an IndexExpr.

    Slice-style indexes inherit the IndexExpr's line and column, because
    they do not get a position of their own (see comment below).
    """
    expr = IndexExpr(self.visit(n.value), self.visit(n.slice))
    self.set_line(expr, n)
    # alias to please mypyc
    is_py38_or_earlier = sys.version_info < (3, 9)
    index_is_slice = isinstance(n.slice, ast3.Slice)
    if not index_is_slice and is_py38_or_earlier:
        index_is_slice = isinstance(n.slice, ast3.ExtSlice)
    if index_is_slice:
        # Before Python 3.9, Slice has no line/column in the raw ast. To avoid incompatibility
        # visit_Slice doesn't set_line, even in Python 3.9 on.
        # ExtSlice also has no line/column info. In Python 3.9 on, line/column is set for
        # e.index when visiting n.slice.
        expr.index.line = expr.line
        expr.index.column = expr.column
    return expr
|
||
|
|
||
|
# Starred(expr value, expr_context ctx)
def visit_Starred(self, n: Starred) -> StarExpr:
    """Convert ``*value``."""
    inner = self.visit(n.value)
    return self.set_line(StarExpr(inner), n)
|
||
|
|
||
|
# Name(identifier id, expr_context ctx)
def visit_Name(self, n: Name) -> NameExpr:
    """Convert a bare identifier."""
    return self.set_line(NameExpr(n.id), n)
|
||
|
|
||
|
# List(expr* elts, expr_context ctx)
def visit_List(self, n: ast3.List) -> ListExpr | TupleExpr:
    """Convert a list display; as an assignment target it becomes a TupleExpr."""
    items: list[Expression] = [self.visit(elt) for elt in n.elts]
    # [x, y] = z and (x, y) = z means exactly the same thing
    is_store = isinstance(n.ctx, ast3.Store)
    result: ListExpr | TupleExpr = TupleExpr(items) if is_store else ListExpr(items)
    return self.set_line(result, n)
|
||
|
|
||
|
# Tuple(expr* elts, expr_context ctx)
def visit_Tuple(self, n: ast3.Tuple) -> TupleExpr:
    """Convert a tuple display."""
    items = self.translate_expr_list(n.elts)
    return self.set_line(TupleExpr(items), n)
|
||
|
|
||
|
# --- slice ---

# Slice(expr? lower, expr? upper, expr? step)
def visit_Slice(self, n: ast3.Slice) -> SliceExpr:
    """Convert ``lower:upper:step``.

    No position is set here; visit_Subscript copies the subscript's position
    onto the resulting index.
    """
    lower, upper, step = self.visit(n.lower), self.visit(n.upper), self.visit(n.step)
    return SliceExpr(lower, upper, step)
|
||
|
|
||
|
# ExtSlice(slice* dims)
def visit_ExtSlice(self, n: ast3.ExtSlice) -> TupleExpr:
    """Convert a multi-dimensional slice into a TupleExpr of its dimensions."""
    dims = cast(Any, n).dims  # cast for mypyc's benefit on Python 3.9
    return TupleExpr(self.translate_expr_list(dims))
|
||
|
|
||
|
# Index(expr value)
def visit_Index(self, n: Index) -> Node:
    """Unwrap an ``Index`` node and return its converted inner expression."""
    wrapped = cast(Any, n).value  # cast for mypyc's benefit on Python 3.9
    result = self.visit(wrapped)
    assert isinstance(result, Node)
    return result
|
||
|
|
||
|
# Match(expr subject, match_case* cases)  # python 3.10 and later
def visit_Match(self, n: Match) -> MatchStmt:
    """Convert a ``match`` statement into parallel pattern/guard/body lists."""
    subject = self.visit(n.subject)
    cases = n.cases
    patterns = [self.visit(case.pattern) for case in cases]
    guards = [self.visit(case.guard) for case in cases]
    bodies = [self.as_required_block(case.body) for case in cases]
    return self.set_line(MatchStmt(subject, patterns, guards, bodies), n)
|
||
|
|
||
|
def visit_MatchValue(self, n: MatchValue) -> ValuePattern:
    """Convert a value pattern (a dotted name or literal compared with ==)."""
    value = self.visit(n.value)
    return self.set_line(ValuePattern(value), n)
|
||
|
|
||
|
def visit_MatchSingleton(self, n: MatchSingleton) -> SingletonPattern:
    """Convert a ``None``/``True``/``False`` pattern; the raw value is kept as-is."""
    return self.set_line(SingletonPattern(n.value), n)
|
||
|
|
||
|
def visit_MatchSequence(self, n: MatchSequence) -> SequencePattern:
    """Convert a sequence pattern; at most one starred sub-pattern is expected."""
    subpatterns = [self.visit(p) for p in n.patterns]
    star_count = sum(1 for p in subpatterns if isinstance(p, StarredPattern))
    assert star_count < 2
    return self.set_line(SequencePattern(subpatterns), n)
|
||
|
|
||
|
def visit_MatchStar(self, n: MatchStar) -> StarredPattern:
    """Convert a starred sub-pattern; a missing name yields a capture of None."""
    capture = None if n.name is None else self.set_line(NameExpr(n.name), n)
    return self.set_line(StarredPattern(capture), n)
|
||
|
|
||
|
def visit_MatchMapping(self, n: MatchMapping) -> MappingPattern:
    """Convert a mapping pattern (``{k: p, ..., **rest}``) into a MappingPattern.

    Keys and value sub-patterns are converted in order. The optional ``**rest``
    capture becomes a NameExpr with the pattern's position, as in visit_MatchStar
    and visit_MatchAs, so that diagnostics about it point at a real line.
    """
    keys = [self.visit(k) for k in n.keys]
    values = [self.visit(v) for v in n.patterns]

    if n.rest is None:
        rest = None
    else:
        rest = self.set_line(NameExpr(n.rest), n)

    node = MappingPattern(keys, values, rest)
    return self.set_line(node, n)
|
||
|
|
||
|
def visit_MatchClass(self, n: MatchClass) -> ClassPattern:
    """Convert a class pattern ``Cls(p1, ..., kw=p)`` into a ClassPattern.

    The class expression must convert to a reference expression. Positional
    sub-patterns are visited before keyword sub-patterns.
    """
    class_ref = self.visit(n.cls)
    assert isinstance(class_ref, RefExpr)
    return self.set_line(
        ClassPattern(
            class_ref,
            [self.visit(p) for p in n.patterns],
            n.kwd_attrs,
            [self.visit(p) for p in n.kwd_patterns],
        ),
        n,
    )
|
||
|
|
||
|
# MatchAs(expr pattern, identifier name)
def visit_MatchAs(self, n: MatchAs) -> AsPattern:
    """Convert an as-pattern; the bound name (if any) gets the pattern's position."""
    name = None if n.name is None else self.set_line(NameExpr(n.name), n)
    return self.set_line(AsPattern(self.visit(n.pattern), name), n)
|
||
|
|
||
|
# MatchOr(expr* pattern)
def visit_MatchOr(self, n: MatchOr) -> OrPattern:
    """Convert an or-pattern by converting each alternative in order."""
    alternatives = [self.visit(alt) for alt in n.patterns]
    return self.set_line(OrPattern(alternatives), n)
|
||
|
|
||
|
def visit_TypeAlias(self, n: ast_TypeAlias) -> AssignmentStmt:
    """Report ``type X = ...`` as unsupported and lower it to ``X = ...``."""
    unsupported = ErrorMessage(
        "PEP 695 type aliases are not yet supported", code=codes.VALID_TYPE
    )
    self.fail(unsupported, n.lineno, n.col_offset, blocker=False)
    # Fall back to an ordinary assignment so later passes still see the alias.
    target = NameExpr(n.name.id)
    stmt = AssignmentStmt([target], self.visit(n.value))
    return self.set_line(stmt, n)
|
||
|
|
||
|
|
||
|
class TypeConverter:
    """Translate AST expressions that appear in type positions into mypy types.

    Dispatch happens in ``visit``: a node of class ``Foo`` is handled by
    ``visit_Foo``. Any node without such a method becomes an "invalid type"
    placeholder (see ``invalid_type``), which semantic analysis reports later.
    """

    def __init__(
        self,
        errors: Errors | None,
        line: int = -1,
        override_column: int = -1,
        is_evaluated: bool = True,
    ) -> None:
        # Error sink; fail() and note() report nothing when this is None.
        self.errors = errors
        # Line number attached to every type this converter produces.
        self.line = line
        # When non-negative, replaces AST column numbers (see convert_column).
        self.override_column = override_column
        # AST nodes currently being visited, innermost last (see parent()).
        self.node_stack: list[AST] = []
        # Passed through to UnionType for ``X | Y`` unions (see visit_BinOp).
        self.is_evaluated = is_evaluated
        # True only while the elements of a list are being translated
        # (see visit_List).
        self.allow_unpack = False

    def convert_column(self, column: int) -> int:
        """Apply column override if defined; otherwise return column.

        Column numbers are sometimes incorrect in the AST and the column
        override can be used to work around that.
        """
        if self.override_column < 0:
            return column
        else:
            return self.override_column

    def invalid_type(self, node: AST, note: str | None = None) -> RawExpressionType:
        """Constructs a type representing some expression that normally forms an invalid type.
        For example, if we see a type hint that says "3 + 4", we would transform that
        expression into a RawExpressionType.

        The semantic analysis layer will report an "Invalid type" error when it
        encounters this type, along with the given note if one is provided.

        See RawExpressionType's docstring for more details on how it's used.
        """
        return RawExpressionType(
            None, "typing.Any", line=self.line, column=getattr(node, "col_offset", -1), note=note
        )

    @overload
    def visit(self, node: ast3.expr) -> ProperType:
        ...

    @overload
    def visit(self, node: AST | None) -> ProperType | None:
        ...

    def visit(self, node: AST | None) -> ProperType | None:
        """Modified visit -- keep track of the stack of nodes.

        Dispatches to ``visit_<NodeClass>`` when such a method exists and
        otherwise returns an invalid type. Returns None only for a None node.
        """
        if node is None:
            return None
        self.node_stack.append(node)
        try:
            method = "visit_" + node.__class__.__name__
            visitor = getattr(self, method, None)
            if visitor is not None:
                typ = visitor(node)
                assert isinstance(typ, ProperType)
                return typ
            else:
                return self.invalid_type(node)
        finally:
            self.node_stack.pop()

    def parent(self) -> AST | None:
        """Return the AST node above the one we are processing"""
        if len(self.node_stack) < 2:
            return None
        return self.node_stack[-2]

    def fail(self, msg: ErrorMessage, line: int, column: int) -> None:
        """Report a blocking error, if an error sink was provided."""
        if self.errors:
            self.errors.report(line, column, msg.value, blocker=True, code=msg.code)

    def note(self, msg: str, line: int, column: int) -> None:
        """Report a note (with the SYNTAX error code), if an error sink was provided."""
        if self.errors:
            self.errors.report(line, column, msg, severity="note", code=codes.SYNTAX)

    def translate_expr_list(self, l: Sequence[ast3.expr]) -> list[Type]:
        """Convert each expression in *l*, preserving order."""
        return [self.visit(e) for e in l]

    def visit_Call(self, e: Call) -> Type:
        """Convert an argument constructor call such as ``Arg(int, 'x')``.

        Such calls are only accepted as direct elements of a list (e.g. a
        Callable argument list). Anywhere else they are an invalid type, with a
        note suggesting ``X[...]`` when the callee is a dotted name.
        """
        # Parse the arg constructor
        f = e.func
        constructor = stringify_name(f)

        if not isinstance(self.parent(), ast3.List):
            note = None
            if constructor:
                note = "Suggestion: use {0}[...] instead of {0}(...)".format(constructor)
            return self.invalid_type(e, note=note)
        if not constructor:
            self.fail(message_registry.ARG_CONSTRUCTOR_NAME_EXPECTED, e.lineno, e.col_offset)

        name: str | None = None
        default_type = AnyType(TypeOfAny.special_form)
        typ: Type = default_type
        # Positional form: first argument is the type, second is the name.
        for i, arg in enumerate(e.args):
            if i == 0:
                converted = self.visit(arg)
                assert converted is not None
                typ = converted
            elif i == 1:
                name = self._extract_argument_name(arg)
            else:
                self.fail(message_registry.ARG_CONSTRUCTOR_TOO_MANY_ARGS, f.lineno, f.col_offset)
        # Keyword form: name=... and type=...; duplicates of a positional are errors.
        for k in e.keywords:
            value = k.value
            if k.arg == "name":
                if name is not None:
                    self.fail(
                        message_registry.MULTIPLE_VALUES_FOR_NAME_KWARG.format(constructor),
                        f.lineno,
                        f.col_offset,
                    )
                name = self._extract_argument_name(value)
            elif k.arg == "type":
                # Identity check: typ is still the sentinel unless already set.
                if typ is not default_type:
                    self.fail(
                        message_registry.MULTIPLE_VALUES_FOR_TYPE_KWARG.format(constructor),
                        f.lineno,
                        f.col_offset,
                    )
                converted = self.visit(value)
                assert converted is not None
                typ = converted
            else:
                self.fail(
                    message_registry.ARG_CONSTRUCTOR_UNEXPECTED_ARG.format(k.arg),
                    value.lineno,
                    value.col_offset,
                )
        return CallableArgument(typ, name, constructor, e.lineno, e.col_offset)

    def translate_argument_list(self, l: Sequence[ast3.expr]) -> TypeList:
        """Convert each expression in *l* and wrap the results in a TypeList."""
        return TypeList([self.visit(e) for e in l], line=self.line)

    def _extract_argument_name(self, n: ast3.expr) -> str | None:
        """Return the stripped string literal *n*, or None for a None literal.

        Anything else is reported as an error and also yields None.
        """
        if isinstance(n, Constant) and isinstance(n.value, str):
            return n.value.strip()
        elif isinstance(n, Constant) and n.value is None:
            return None
        self.fail(
            message_registry.ARG_NAME_EXPECTED_STRING_LITERAL.format(type(n).__name__),
            self.line,
            0,
        )
        return None

    def visit_Name(self, n: Name) -> Type:
        return UnboundType(n.id, line=self.line, column=self.convert_column(n.col_offset))

    def visit_BinOp(self, n: ast3.BinOp) -> Type:
        """Convert ``X | Y`` into a PEP 604 UnionType; other operators are invalid."""
        if not isinstance(n.op, ast3.BitOr):
            return self.invalid_type(n)

        left = self.visit(n.left)
        right = self.visit(n.right)
        return UnionType(
            [left, right],
            line=self.line,
            column=self.convert_column(n.col_offset),
            is_evaluated=self.is_evaluated,
            uses_pep604_syntax=True,
        )

    def visit_Constant(self, n: Constant) -> Type:
        val = n.value
        if val is None:
            # None is a type.
            return UnboundType("None", line=self.line)
        if isinstance(val, str):
            # Parse forward reference.
            return parse_type_string(val, "builtins.str", self.line, n.col_offset)
        if val is Ellipsis:
            # '...' is valid in some types.
            return EllipsisType(line=self.line)
        if isinstance(val, bool):
            # Special case for True/False.
            return RawExpressionType(val, "builtins.bool", line=self.line)
        if isinstance(val, (int, float, complex)):
            return self.numeric_type(val, n)
        if isinstance(val, bytes):
            contents = bytes_to_human_readable_repr(val)
            return RawExpressionType(contents, "builtins.bytes", self.line, column=n.col_offset)
        # Everything else is invalid.
        return self.invalid_type(n)

    # UnaryOp(op, operand)
    def visit_UnaryOp(self, n: UnaryOp) -> Type:
        # We support specifically Literal[-4] and nothing else.
        # For example, Literal[+4] or Literal[~6] is not supported.
        typ = self.visit(n.operand)
        if (
            isinstance(typ, RawExpressionType)
            and isinstance(n.op, USub)
            # Use type() rather than isinstance(): bool is a subclass of int,
            # but negated booleans such as Literal[-True] must be rejected.
            and type(typ.literal_value) is int  # noqa: E721
        ):
            typ.literal_value *= -1
            return typ
        return self.invalid_type(n)

    def numeric_type(self, value: object, n: AST) -> Type:
        # The node's field has the type complex, but complex isn't *really*
        # a parent of int and float, and this causes isinstance below
        # to think that the complex branch is always picked. Avoid
        # this by throwing away the type.
        if isinstance(value, int):
            numeric_value: int | None = value
            type_name = "builtins.int"
        else:
            # Other kinds of numbers (floats, complex) are not valid parameters for
            # RawExpressionType so we just pass in 'None' for now. We'll report the
            # appropriate error at a later stage.
            numeric_value = None
            type_name = f"builtins.{type(value).__name__}"
        return RawExpressionType(
            numeric_value, type_name, line=self.line, column=getattr(n, "col_offset", -1)
        )

    def visit_Index(self, n: ast3.Index) -> Type:
        # cast for mypyc's benefit on Python 3.9
        value = self.visit(cast(Any, n).value)
        assert isinstance(value, Type)
        return value

    def visit_Slice(self, n: ast3.Slice) -> Type:
        return self.invalid_type(n, note="did you mean to use ',' instead of ':' ?")

    @staticmethod
    def _slice_col_offset(s: ast3.Slice, default: int) -> int:
        """Return a column for a slice node that carries no position (Python 3.8).

        Any of the bounds may be omitted (e.g. ``X[:int]``), so the first one
        present is used; *default* is returned when all of them are missing.
        """
        for bound in (s.lower, s.upper, s.step):
            if bound is not None:
                return bound.col_offset
        return default

    # Subscript(expr value, slice slice, expr_context ctx) # Python 3.8 and before
    # Subscript(expr value, expr slice, expr_context ctx) # Python 3.9 and later
    def visit_Subscript(self, n: ast3.Subscript) -> Type:
        sliceval: Any
        if sys.version_info >= (3, 9):  # Really 3.9a5 or later
            sliceval = n.slice
        # Python 3.8 or earlier use a different AST structure for subscripts
        elif isinstance(n.slice, ast3.Index):
            sliceval = n.slice.value
        elif isinstance(n.slice, ast3.Slice):
            sliceval = copy.deepcopy(n.slice)  # so we don't mutate passed AST
            if getattr(sliceval, "col_offset", None) is None:
                # Fix column information so that we get Python 3.9+ message order
                sliceval.col_offset = self._slice_col_offset(sliceval, n.col_offset)
        else:
            assert isinstance(n.slice, ast3.ExtSlice)
            dims = copy.deepcopy(n.slice.dims)
            for s in dims:
                if getattr(s, "col_offset", None) is None:
                    if isinstance(s, ast3.Index):
                        s.col_offset = s.value.col_offset
                    elif isinstance(s, ast3.Slice):
                        s.col_offset = self._slice_col_offset(s, n.col_offset)
            sliceval = ast3.Tuple(dims, n.ctx)

        empty_tuple_index = False
        if isinstance(sliceval, ast3.Tuple):
            params = self.translate_expr_list(sliceval.elts)
            if len(sliceval.elts) == 0:
                empty_tuple_index = True
        else:
            params = [self.visit(sliceval)]

        value = self.visit(n.value)
        if isinstance(value, UnboundType) and not value.args:
            return UnboundType(
                value.name,
                params,
                line=self.line,
                column=value.column,
                empty_tuple_index=empty_tuple_index,
            )
        else:
            return self.invalid_type(n)

    def visit_Tuple(self, n: ast3.Tuple) -> Type:
        return TupleType(
            self.translate_expr_list(n.elts),
            _dummy_fallback,
            implicit=True,
            line=self.line,
            column=self.convert_column(n.col_offset),
        )

    # Attribute(expr value, identifier attr, expr_context ctx)
    def visit_Attribute(self, n: Attribute) -> Type:
        before_dot = self.visit(n.value)

        if isinstance(before_dot, UnboundType) and not before_dot.args:
            # Carry the column like visit_Name does, so errors on dotted
            # names point at the right place.
            return UnboundType(
                f"{before_dot.name}.{n.attr}",
                line=self.line,
                column=self.convert_column(n.col_offset),
            )
        else:
            return self.invalid_type(n)

    # Used for Callable[[X, *Ys, Z], R]
    def visit_Starred(self, n: ast3.Starred) -> Type:
        return UnpackType(self.visit(n.value))

    # List(expr* elts, expr_context ctx)
    def visit_List(self, n: ast3.List) -> Type:
        assert isinstance(n.ctx, ast3.Load)
        old_allow_unpack = self.allow_unpack
        # We specifically only allow starred expressions in a list to avoid
        # confusing errors for top-level unpacks (e.g. in base classes).
        self.allow_unpack = True
        try:
            return self.translate_argument_list(n.elts)
        finally:
            # Restore the flag even if translation raises, so the converter is
            # not left permanently allowing unpacks.
            self.allow_unpack = old_allow_unpack
|
||
|
|
||
|
|
||
|
def stringify_name(n: AST) -> str | None:
    """Return the dotted name spelled by a Name/Attribute chain, e.g. ``a.b.c``.

    Returns None when the chain is not rooted in a plain name, for example
    ``f().x`` or ``1``.
    """
    parts: list[str] = []
    node = n
    # Walk down the attribute chain, collecting attribute names innermost-last.
    while isinstance(node, Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, Name):
        return None  # Can't do it.
    parts.append(node.id)
    return ".".join(reversed(parts))
|
||
|
|
||
|
|
||
|
class FindAttributeAssign(TraverserVisitor):
    """Check if an AST contains attribute assignments (e.g. self.x = 0)."""

    def __init__(self) -> None:
        # True while an assignment/with/for target is being visited.
        self.lvalue = False
        # Becomes True once a member expression is seen in target position.
        self.found = False

    def _visit_targets(self, targets: Sequence[Expression | None]) -> None:
        """Visit each non-None target with the lvalue flag raised."""
        self.lvalue = True
        for target in targets:
            if target is not None:
                target.accept(self)
        self.lvalue = False

    def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
        self._visit_targets(s.lvalues)

    def visit_with_stmt(self, s: WithStmt) -> None:
        # Only the ``as`` targets and the body are inspected.
        self._visit_targets(s.target)
        s.body.accept(self)

    def visit_for_stmt(self, s: ForStmt) -> None:
        # Only the loop index, the body and the else-body are inspected.
        self._visit_targets([s.index])
        s.body.accept(self)
        if s.else_body:
            s.else_body.accept(self)

    def visit_expression_stmt(self, s: ExpressionStmt) -> None:
        """No need to look inside these."""

    def visit_call_expr(self, e: CallExpr) -> None:
        """No need to look inside these."""

    def visit_index_expr(self, e: IndexExpr) -> None:
        """No need to look inside these."""

    def visit_member_expr(self, e: MemberExpr) -> None:
        # A member expression in target position is an attribute assignment.
        self.found = self.found or self.lvalue
|
||
|
|
||
|
|
||
|
class FindYield(TraverserVisitor):
    """Check if an AST contains yields or yield froms."""

    def __init__(self) -> None:
        # Set to True on the first ``yield`` / ``yield from`` encountered.
        self.found = False

    def _record_yield(self) -> None:
        """Remember that a yield of some kind was seen."""
        self.found = True

    def visit_yield_expr(self, e: YieldExpr) -> None:
        self._record_yield()

    def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
        self._record_yield()
|
||
|
|
||
|
|
||
|
def is_possible_trivial_body(s: list[Statement]) -> bool:
    """Could the statements form a "trivial" function body, such as 'pass'?

    This mimics mypy.semanal.is_trivial_body, but this runs before
    semantic analysis so some checks must be conservative.

    An empty list is not trivial. A lone docstring is trivial. Otherwise,
    after an optional leading docstring there must be exactly one statement,
    and it must be ``pass``, a ``raise``, or a bare ``...`` expression.
    """
    if not s:
        return False
    has_docstring = isinstance(s[0], ExpressionStmt) and isinstance(s[0].expr, StrExpr)
    remaining = s[1:] if has_docstring else s
    if not remaining:
        return True
    if len(remaining) > 1:
        return False
    stmt = remaining[0]
    if isinstance(stmt, (PassStmt, RaiseStmt)):
        return True
    return isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
|