Tipragot
628be439b8
Cela permet de ne pas avoir de problèmes de compatibilité car python est dans le git.
2033 lines
75 KiB
Python
2033 lines
75 KiB
Python
#!/usr/bin/env python3
|
|
"""Generator of dynamically typed draft stubs for arbitrary modules.
|
|
|
|
The logic of this script can be split in three steps:
|
|
* parsing options and finding sources:
|
|
- use runtime imports be default (to find also C modules)
|
|
- or use mypy's mechanisms, if importing is prohibited
|
|
* (optionally) semantically analysing the sources using mypy (as a single set)
|
|
* emitting the stubs text:
|
|
- for Python modules: from ASTs using StubGenerator
|
|
- for C modules using runtime introspection and (optionally) Sphinx docs
|
|
|
|
During first and third steps some problematic files can be skipped, but any
|
|
blocking error during second step will cause the whole program to stop.
|
|
|
|
Basic usage:
|
|
|
|
$ stubgen foo.py bar.py some_directory
|
|
=> Generate out/foo.pyi, out/bar.pyi, and stubs for some_directory (recursively).
|
|
|
|
$ stubgen -m urllib.parse
|
|
=> Generate out/urllib/parse.pyi.
|
|
|
|
$ stubgen -p urllib
|
|
=> Generate stubs for whole urlib package (recursively).
|
|
|
|
For C modules, you can get more precise function signatures by parsing .rst (Sphinx)
|
|
documentation for extra information. For this, use the --doc-dir option:
|
|
|
|
$ stubgen --doc-dir <DIR>/Python-3.4.2/Doc/library -m curses
|
|
|
|
Note: The generated stubs should be verified manually.
|
|
|
|
TODO:
|
|
- maybe use .rst docs also for Python modules
|
|
- maybe export more imported names if there is no __all__ (this affects ssl.SSLError, for example)
|
|
- a quick and dirty heuristic would be to turn this on if a module has something like
|
|
'from x import y as _y'
|
|
- we don't seem to always detect properties ('closed' in 'io', for example)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import glob
|
|
import keyword
|
|
import os
|
|
import os.path
|
|
import sys
|
|
import traceback
|
|
from collections import defaultdict
|
|
from typing import Final, Iterable, Mapping
|
|
|
|
import mypy.build
|
|
import mypy.mixedtraverser
|
|
import mypy.parse
|
|
import mypy.traverser
|
|
import mypy.util
|
|
from mypy.build import build
|
|
from mypy.errors import CompileError, Errors
|
|
from mypy.find_sources import InvalidSourceList, create_source_list
|
|
from mypy.modulefinder import (
|
|
BuildSource,
|
|
FindModuleCache,
|
|
ModuleNotFoundReason,
|
|
SearchPaths,
|
|
default_lib_path,
|
|
)
|
|
from mypy.moduleinspect import ModuleInspect
|
|
from mypy.nodes import (
|
|
ARG_NAMED,
|
|
ARG_POS,
|
|
ARG_STAR,
|
|
ARG_STAR2,
|
|
IS_ABSTRACT,
|
|
NOT_ABSTRACT,
|
|
AssignmentStmt,
|
|
Block,
|
|
BytesExpr,
|
|
CallExpr,
|
|
ClassDef,
|
|
ComparisonExpr,
|
|
ComplexExpr,
|
|
Decorator,
|
|
DictExpr,
|
|
EllipsisExpr,
|
|
Expression,
|
|
FloatExpr,
|
|
FuncBase,
|
|
FuncDef,
|
|
IfStmt,
|
|
Import,
|
|
ImportAll,
|
|
ImportFrom,
|
|
IndexExpr,
|
|
IntExpr,
|
|
ListExpr,
|
|
MemberExpr,
|
|
MypyFile,
|
|
NameExpr,
|
|
OpExpr,
|
|
OverloadedFuncDef,
|
|
Statement,
|
|
StrExpr,
|
|
TempNode,
|
|
TupleExpr,
|
|
TypeInfo,
|
|
UnaryExpr,
|
|
Var,
|
|
)
|
|
from mypy.options import Options as MypyOptions
|
|
from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures
|
|
from mypy.stubgenc import (
|
|
DocstringSignatureGenerator,
|
|
ExternalSignatureGenerator,
|
|
FallbackSignatureGenerator,
|
|
SignatureGenerator,
|
|
generate_stub_for_c_module,
|
|
)
|
|
from mypy.stubutil import (
|
|
CantImport,
|
|
common_dir_prefix,
|
|
fail_missing,
|
|
find_module_path_and_all_py3,
|
|
generate_guarded,
|
|
remove_misplaced_type_comments,
|
|
report_missing,
|
|
walk_packages,
|
|
)
|
|
from mypy.traverser import (
|
|
all_yield_expressions,
|
|
has_return_statement,
|
|
has_yield_expression,
|
|
has_yield_from_expression,
|
|
)
|
|
from mypy.types import (
|
|
OVERLOAD_NAMES,
|
|
TPDICT_NAMES,
|
|
TYPED_NAMEDTUPLE_NAMES,
|
|
AnyType,
|
|
CallableType,
|
|
Instance,
|
|
NoneType,
|
|
TupleType,
|
|
Type,
|
|
TypeList,
|
|
TypeStrVisitor,
|
|
UnboundType,
|
|
UnionType,
|
|
get_proper_type,
|
|
)
|
|
from mypy.visitor import NodeVisitor
|
|
|
|
TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions")
|
|
|
|
# Common ways of naming package containing vendored modules.
|
|
VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"]
|
|
|
|
# Avoid some file names that are unnecessary or likely to cause trouble (\n for end of path).
|
|
BLACKLIST: Final = [
|
|
"/six.py\n", # Likely vendored six; too dynamic for us to handle
|
|
"/vendored/", # Vendored packages
|
|
"/vendor/", # Vendored packages
|
|
"/_vendor/",
|
|
"/_vendored_packages/",
|
|
]
|
|
|
|
# Special-cased names that are implicitly exported from the stub (from m import y as y).
|
|
EXTRA_EXPORTED: Final = {
|
|
"pyasn1_modules.rfc2437.univ",
|
|
"pyasn1_modules.rfc2459.char",
|
|
"pyasn1_modules.rfc2459.univ",
|
|
}
|
|
|
|
# These names should be omitted from generated stubs.
|
|
IGNORED_DUNDERS: Final = {
|
|
"__all__",
|
|
"__author__",
|
|
"__version__",
|
|
"__about__",
|
|
"__copyright__",
|
|
"__email__",
|
|
"__license__",
|
|
"__summary__",
|
|
"__title__",
|
|
"__uri__",
|
|
"__str__",
|
|
"__repr__",
|
|
"__getstate__",
|
|
"__setstate__",
|
|
"__slots__",
|
|
}
|
|
|
|
# These methods are expected to always return a non-trivial value.
|
|
METHODS_WITH_RETURN_VALUE: Final = {
|
|
"__ne__",
|
|
"__eq__",
|
|
"__lt__",
|
|
"__le__",
|
|
"__gt__",
|
|
"__ge__",
|
|
"__hash__",
|
|
"__iter__",
|
|
}
|
|
|
|
# These magic methods always return the same type.
|
|
KNOWN_MAGIC_METHODS_RETURN_TYPES: Final = {
|
|
"__len__": "int",
|
|
"__length_hint__": "int",
|
|
"__init__": "None",
|
|
"__del__": "None",
|
|
"__bool__": "bool",
|
|
"__bytes__": "bytes",
|
|
"__format__": "str",
|
|
"__contains__": "bool",
|
|
"__complex__": "complex",
|
|
"__int__": "int",
|
|
"__float__": "float",
|
|
"__index__": "int",
|
|
}
|
|
|
|
|
|
class Options:
|
|
"""Represents stubgen options.
|
|
|
|
This class is mutable to simplify testing.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
pyversion: tuple[int, int],
|
|
no_import: bool,
|
|
doc_dir: str,
|
|
search_path: list[str],
|
|
interpreter: str,
|
|
parse_only: bool,
|
|
ignore_errors: bool,
|
|
include_private: bool,
|
|
output_dir: str,
|
|
modules: list[str],
|
|
packages: list[str],
|
|
files: list[str],
|
|
verbose: bool,
|
|
quiet: bool,
|
|
export_less: bool,
|
|
include_docstrings: bool,
|
|
) -> None:
|
|
# See parse_options for descriptions of the flags.
|
|
self.pyversion = pyversion
|
|
self.no_import = no_import
|
|
self.doc_dir = doc_dir
|
|
self.search_path = search_path
|
|
self.interpreter = interpreter
|
|
self.decointerpreter = interpreter
|
|
self.parse_only = parse_only
|
|
self.ignore_errors = ignore_errors
|
|
self.include_private = include_private
|
|
self.output_dir = output_dir
|
|
self.modules = modules
|
|
self.packages = packages
|
|
self.files = files
|
|
self.verbose = verbose
|
|
self.quiet = quiet
|
|
self.export_less = export_less
|
|
self.include_docstrings = include_docstrings
|
|
|
|
|
|
class StubSource:
|
|
"""A single source for stub: can be a Python or C module.
|
|
|
|
A simple extension of BuildSource that also carries the AST and
|
|
the value of __all__ detected at runtime.
|
|
"""
|
|
|
|
def __init__(
|
|
self, module: str, path: str | None = None, runtime_all: list[str] | None = None
|
|
) -> None:
|
|
self.source = BuildSource(path, module, None)
|
|
self.runtime_all = runtime_all
|
|
self.ast: MypyFile | None = None
|
|
|
|
@property
|
|
def module(self) -> str:
|
|
return self.source.module
|
|
|
|
@property
|
|
def path(self) -> str | None:
|
|
return self.source.path
|
|
|
|
|
|
# What was generated previously in the stub file. We keep track of these to generate
|
|
# nicely formatted output (add empty line between non-empty classes, for example).
|
|
EMPTY: Final = "EMPTY"
|
|
FUNC: Final = "FUNC"
|
|
CLASS: Final = "CLASS"
|
|
EMPTY_CLASS: Final = "EMPTY_CLASS"
|
|
VAR: Final = "VAR"
|
|
NOT_IN_ALL: Final = "NOT_IN_ALL"
|
|
|
|
# Indicates that we failed to generate a reasonable output
|
|
# for a given node. These should be manually replaced by a user.
|
|
|
|
ERROR_MARKER: Final = "<ERROR>"
|
|
|
|
|
|
class AnnotationPrinter(TypeStrVisitor):
|
|
"""Visitor used to print existing annotations in a file.
|
|
|
|
The main difference from TypeStrVisitor is a better treatment of
|
|
unbound types.
|
|
|
|
Notes:
|
|
* This visitor doesn't add imports necessary for annotations, this is done separately
|
|
by ImportTracker.
|
|
* It can print all kinds of types, but the generated strings may not be valid (notably
|
|
callable types) since it prints the same string that reveal_type() does.
|
|
* For Instance types it prints the fully qualified names.
|
|
"""
|
|
|
|
# TODO: Generate valid string representation for callable types.
|
|
# TODO: Use short names for Instances.
|
|
def __init__(self, stubgen: StubGenerator) -> None:
|
|
super().__init__(options=mypy.options.Options())
|
|
self.stubgen = stubgen
|
|
|
|
def visit_any(self, t: AnyType) -> str:
|
|
s = super().visit_any(t)
|
|
self.stubgen.import_tracker.require_name(s)
|
|
return s
|
|
|
|
def visit_unbound_type(self, t: UnboundType) -> str:
|
|
s = t.name
|
|
self.stubgen.import_tracker.require_name(s)
|
|
if t.args:
|
|
s += f"[{self.args_str(t.args)}]"
|
|
return s
|
|
|
|
def visit_none_type(self, t: NoneType) -> str:
|
|
return "None"
|
|
|
|
def visit_type_list(self, t: TypeList) -> str:
|
|
return f"[{self.list_str(t.items)}]"
|
|
|
|
def visit_union_type(self, t: UnionType) -> str:
|
|
return " | ".join([item.accept(self) for item in t.items])
|
|
|
|
def args_str(self, args: Iterable[Type]) -> str:
|
|
"""Convert an array of arguments to strings and join the results with commas.
|
|
|
|
The main difference from list_str is the preservation of quotes for string
|
|
arguments
|
|
"""
|
|
types = ["builtins.bytes", "builtins.str"]
|
|
res = []
|
|
for arg in args:
|
|
arg_str = arg.accept(self)
|
|
if isinstance(arg, UnboundType) and arg.original_str_fallback in types:
|
|
res.append(f"'{arg_str}'")
|
|
else:
|
|
res.append(arg_str)
|
|
return ", ".join(res)
|
|
|
|
|
|
class AliasPrinter(NodeVisitor[str]):
|
|
"""Visitor used to collect type aliases _and_ type variable definitions.
|
|
|
|
Visit r.h.s of the definition to get the string representation of type alias.
|
|
"""
|
|
|
|
def __init__(self, stubgen: StubGenerator) -> None:
|
|
self.stubgen = stubgen
|
|
super().__init__()
|
|
|
|
def visit_call_expr(self, node: CallExpr) -> str:
|
|
# Call expressions are not usually types, but we also treat `X = TypeVar(...)` as a
|
|
# type alias that has to be preserved (even if TypeVar is not the same as an alias)
|
|
callee = node.callee.accept(self)
|
|
args = []
|
|
for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds):
|
|
if kind == ARG_POS:
|
|
args.append(arg.accept(self))
|
|
elif kind == ARG_STAR:
|
|
args.append("*" + arg.accept(self))
|
|
elif kind == ARG_STAR2:
|
|
args.append("**" + arg.accept(self))
|
|
elif kind == ARG_NAMED:
|
|
args.append(f"{name}={arg.accept(self)}")
|
|
else:
|
|
raise ValueError(f"Unknown argument kind {kind} in call")
|
|
return f"{callee}({', '.join(args)})"
|
|
|
|
def visit_name_expr(self, node: NameExpr) -> str:
|
|
self.stubgen.import_tracker.require_name(node.name)
|
|
return node.name
|
|
|
|
def visit_member_expr(self, o: MemberExpr) -> str:
|
|
node: Expression = o
|
|
trailer = ""
|
|
while isinstance(node, MemberExpr):
|
|
trailer = "." + node.name + trailer
|
|
node = node.expr
|
|
if not isinstance(node, NameExpr):
|
|
return ERROR_MARKER
|
|
self.stubgen.import_tracker.require_name(node.name)
|
|
return node.name + trailer
|
|
|
|
def visit_str_expr(self, node: StrExpr) -> str:
|
|
return repr(node.value)
|
|
|
|
def visit_index_expr(self, node: IndexExpr) -> str:
|
|
base = node.base.accept(self)
|
|
index = node.index.accept(self)
|
|
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
|
|
index = index[1:-1]
|
|
return f"{base}[{index}]"
|
|
|
|
def visit_tuple_expr(self, node: TupleExpr) -> str:
|
|
return f"({', '.join(n.accept(self) for n in node.items)})"
|
|
|
|
def visit_list_expr(self, node: ListExpr) -> str:
|
|
return f"[{', '.join(n.accept(self) for n in node.items)}]"
|
|
|
|
def visit_dict_expr(self, o: DictExpr) -> str:
|
|
dict_items = []
|
|
for key, value in o.items:
|
|
# This is currently only used for TypedDict where all keys are strings.
|
|
assert isinstance(key, StrExpr)
|
|
dict_items.append(f"{key.accept(self)}: {value.accept(self)}")
|
|
return f"{{{', '.join(dict_items)}}}"
|
|
|
|
def visit_ellipsis(self, node: EllipsisExpr) -> str:
|
|
return "..."
|
|
|
|
def visit_op_expr(self, o: OpExpr) -> str:
|
|
return f"{o.left.accept(self)} {o.op} {o.right.accept(self)}"
|
|
|
|
|
|
class ImportTracker:
|
|
"""Record necessary imports during stub generation."""
|
|
|
|
def __init__(self) -> None:
|
|
# module_for['foo'] has the module name where 'foo' was imported from, or None if
|
|
# 'foo' is a module imported directly; examples
|
|
# 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m'
|
|
# 'from m import f' ==> module_for['f'] == 'm'
|
|
# 'import m' ==> module_for['m'] == None
|
|
# 'import pkg.m' ==> module_for['pkg.m'] == None
|
|
# ==> module_for['pkg'] == None
|
|
self.module_for: dict[str, str | None] = {}
|
|
|
|
# direct_imports['foo'] is the module path used when the name 'foo' was added to the
|
|
# namespace.
|
|
# import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz'
|
|
# ==> direct_imports['foo.bar'] == 'foo.bar.baz'
|
|
# ==> direct_imports['foo.bar.baz'] == 'foo.bar.baz'
|
|
self.direct_imports: dict[str, str] = {}
|
|
|
|
# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
|
|
# alias; examples
|
|
# 'import numpy as np' ==> reverse_alias['np'] == 'numpy'
|
|
# 'import foo.bar as bar' ==> reverse_alias['bar'] == 'foo.bar'
|
|
# 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal'
|
|
self.reverse_alias: dict[str, str] = {}
|
|
|
|
# required_names is the set of names that are actually used in a type annotation
|
|
self.required_names: set[str] = set()
|
|
|
|
# Names that should be reexported if they come from another module
|
|
self.reexports: set[str] = set()
|
|
|
|
def add_import_from(self, module: str, names: list[tuple[str, str | None]]) -> None:
|
|
for name, alias in names:
|
|
if alias:
|
|
# 'from {module} import {name} as {alias}'
|
|
self.module_for[alias] = module
|
|
self.reverse_alias[alias] = name
|
|
else:
|
|
# 'from {module} import {name}'
|
|
self.module_for[name] = module
|
|
self.reverse_alias.pop(name, None)
|
|
self.direct_imports.pop(alias or name, None)
|
|
|
|
def add_import(self, module: str, alias: str | None = None) -> None:
|
|
if alias:
|
|
# 'import {module} as {alias}'
|
|
self.module_for[alias] = None
|
|
self.reverse_alias[alias] = module
|
|
else:
|
|
# 'import {module}'
|
|
name = module
|
|
# add module and its parent packages
|
|
while name:
|
|
self.module_for[name] = None
|
|
self.direct_imports[name] = module
|
|
self.reverse_alias.pop(name, None)
|
|
name = name.rpartition(".")[0]
|
|
|
|
def require_name(self, name: str) -> None:
|
|
self.required_names.add(name.split(".")[0])
|
|
|
|
def reexport(self, name: str) -> None:
|
|
"""Mark a given non qualified name as needed in __all__.
|
|
|
|
This means that in case it comes from a module, it should be
|
|
imported with an alias even is the alias is the same as the name.
|
|
"""
|
|
self.require_name(name)
|
|
self.reexports.add(name)
|
|
|
|
def import_lines(self) -> list[str]:
|
|
"""The list of required import lines (as strings with python code)."""
|
|
result = []
|
|
|
|
# To summarize multiple names imported from a same module, we collect those
|
|
# in the `module_map` dictionary, mapping a module path to the list of names that should
|
|
# be imported from it. the names can also be alias in the form 'original as alias'
|
|
module_map: Mapping[str, list[str]] = defaultdict(list)
|
|
|
|
for name in sorted(self.required_names):
|
|
# If we haven't seen this name in an import statement, ignore it
|
|
if name not in self.module_for:
|
|
continue
|
|
|
|
m = self.module_for[name]
|
|
if m is not None:
|
|
# This name was found in a from ... import ...
|
|
# Collect the name in the module_map
|
|
if name in self.reverse_alias:
|
|
name = f"{self.reverse_alias[name]} as {name}"
|
|
elif name in self.reexports:
|
|
name = f"{name} as {name}"
|
|
module_map[m].append(name)
|
|
else:
|
|
# This name was found in an import ...
|
|
# We can already generate the import line
|
|
if name in self.reverse_alias:
|
|
source = self.reverse_alias[name]
|
|
result.append(f"import {source} as {name}\n")
|
|
elif name in self.reexports:
|
|
assert "." not in name # Because reexports only has nonqualified names
|
|
result.append(f"import {name} as {name}\n")
|
|
else:
|
|
result.append(f"import {self.direct_imports[name]}\n")
|
|
|
|
# Now generate all the from ... import ... lines collected in module_map
|
|
for module, names in sorted(module_map.items()):
|
|
result.append(f"from {module} import {', '.join(sorted(names))}\n")
|
|
return result
|
|
|
|
|
|
def find_defined_names(file: MypyFile) -> set[str]:
|
|
finder = DefinitionFinder()
|
|
file.accept(finder)
|
|
return finder.names
|
|
|
|
|
|
class DefinitionFinder(mypy.traverser.TraverserVisitor):
|
|
"""Find names of things defined at the top level of a module."""
|
|
|
|
# TODO: Assignment statements etc.
|
|
|
|
def __init__(self) -> None:
|
|
# Short names of things defined at the top level.
|
|
self.names: set[str] = set()
|
|
|
|
def visit_class_def(self, o: ClassDef) -> None:
|
|
# Don't recurse into classes, as we only keep track of top-level definitions.
|
|
self.names.add(o.name)
|
|
|
|
def visit_func_def(self, o: FuncDef) -> None:
|
|
# Don't recurse, as we only keep track of top-level definitions.
|
|
self.names.add(o.name)
|
|
|
|
|
|
def find_referenced_names(file: MypyFile) -> set[str]:
|
|
finder = ReferenceFinder()
|
|
file.accept(finder)
|
|
return finder.refs
|
|
|
|
|
|
class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor):
|
|
"""Find all name references (both local and global)."""
|
|
|
|
# TODO: Filter out local variable and class attribute references
|
|
|
|
def __init__(self) -> None:
|
|
# Short names of things defined at the top level.
|
|
self.refs: set[str] = set()
|
|
|
|
def visit_block(self, block: Block) -> None:
|
|
if not block.is_unreachable:
|
|
super().visit_block(block)
|
|
|
|
def visit_name_expr(self, e: NameExpr) -> None:
|
|
self.refs.add(e.name)
|
|
|
|
def visit_instance(self, t: Instance) -> None:
|
|
self.add_ref(t.type.fullname)
|
|
super().visit_instance(t)
|
|
|
|
def visit_unbound_type(self, t: UnboundType) -> None:
|
|
if t.name:
|
|
self.add_ref(t.name)
|
|
|
|
def visit_tuple_type(self, t: TupleType) -> None:
|
|
# Ignore fallback
|
|
for item in t.items:
|
|
item.accept(self)
|
|
|
|
def visit_callable_type(self, t: CallableType) -> None:
|
|
# Ignore fallback
|
|
for arg in t.arg_types:
|
|
arg.accept(self)
|
|
t.ret_type.accept(self)
|
|
|
|
def add_ref(self, fullname: str) -> None:
|
|
self.refs.add(fullname.split(".")[-1])
|
|
|
|
|
|
class StubGenerator(mypy.traverser.TraverserVisitor):
|
|
"""Generate stub text from a mypy AST."""
|
|
|
|
def __init__(
|
|
self,
|
|
_all_: list[str] | None,
|
|
include_private: bool = False,
|
|
analyzed: bool = False,
|
|
export_less: bool = False,
|
|
include_docstrings: bool = False,
|
|
) -> None:
|
|
# Best known value of __all__.
|
|
self._all_ = _all_
|
|
self._output: list[str] = []
|
|
self._decorators: list[str] = []
|
|
self._import_lines: list[str] = []
|
|
# Current indent level (indent is hardcoded to 4 spaces).
|
|
self._indent = ""
|
|
# Stack of defined variables (per scope).
|
|
self._vars: list[list[str]] = [[]]
|
|
# What was generated previously in the stub file.
|
|
self._state = EMPTY
|
|
self._toplevel_names: list[str] = []
|
|
self._include_private = include_private
|
|
self._include_docstrings = include_docstrings
|
|
self._current_class: ClassDef | None = None
|
|
self.import_tracker = ImportTracker()
|
|
# Was the tree semantically analysed before?
|
|
self.analyzed = analyzed
|
|
# Disable implicit exports of package-internal imports?
|
|
self.export_less = export_less
|
|
# Add imports that could be implicitly generated
|
|
self.import_tracker.add_import_from("typing", [("NamedTuple", None)])
|
|
# Names in __all__ are required
|
|
for name in _all_ or ():
|
|
if name not in IGNORED_DUNDERS:
|
|
self.import_tracker.reexport(name)
|
|
self.defined_names: set[str] = set()
|
|
# Short names of methods defined in the body of the current class
|
|
self.method_names: set[str] = set()
|
|
|
|
def visit_mypy_file(self, o: MypyFile) -> None:
|
|
self.module = o.fullname # Current module being processed
|
|
self.path = o.path
|
|
self.defined_names = find_defined_names(o)
|
|
self.referenced_names = find_referenced_names(o)
|
|
known_imports = {
|
|
"_typeshed": ["Incomplete"],
|
|
"typing": ["Any", "TypeVar", "NamedTuple"],
|
|
"collections.abc": ["Generator"],
|
|
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
|
|
}
|
|
for pkg, imports in known_imports.items():
|
|
for t in imports:
|
|
if t not in self.defined_names:
|
|
alias = None
|
|
else:
|
|
alias = "_" + t
|
|
self.import_tracker.add_import_from(pkg, [(t, alias)])
|
|
super().visit_mypy_file(o)
|
|
undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names]
|
|
if undefined_names:
|
|
if self._state != EMPTY:
|
|
self.add("\n")
|
|
self.add("# Names in __all__ with no definition:\n")
|
|
for name in sorted(undefined_names):
|
|
self.add(f"# {name}\n")
|
|
|
|
def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
|
|
"""@property with setters and getters, @overload chain and some others."""
|
|
overload_chain = False
|
|
for item in o.items:
|
|
if not isinstance(item, Decorator):
|
|
continue
|
|
if self.is_private_name(item.func.name, item.func.fullname):
|
|
continue
|
|
|
|
self.process_decorator(item)
|
|
if not overload_chain:
|
|
self.visit_func_def(item.func)
|
|
if item.func.is_overload:
|
|
overload_chain = True
|
|
elif item.func.is_overload:
|
|
self.visit_func_def(item.func)
|
|
else:
|
|
# skip the overload implementation and clear the decorator we just processed
|
|
self.clear_decorators()
|
|
|
|
def visit_func_def(self, o: FuncDef) -> None:
|
|
if (
|
|
self.is_private_name(o.name, o.fullname)
|
|
or self.is_not_in_all(o.name)
|
|
or (self.is_recorded_name(o.name) and not o.is_overload)
|
|
):
|
|
self.clear_decorators()
|
|
return
|
|
if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine:
|
|
self.add("\n")
|
|
if not self.is_top_level():
|
|
self_inits = find_self_initializers(o)
|
|
for init, value in self_inits:
|
|
if init in self.method_names:
|
|
# Can't have both an attribute and a method/property with the same name.
|
|
continue
|
|
init_code = self.get_init(init, value)
|
|
if init_code:
|
|
self.add(init_code)
|
|
# dump decorators, just before "def ..."
|
|
for s in self._decorators:
|
|
self.add(s)
|
|
self.clear_decorators()
|
|
self.add(f"{self._indent}{'async ' if o.is_coroutine else ''}def {o.name}(")
|
|
self.record_name(o.name)
|
|
args: list[str] = []
|
|
for i, arg_ in enumerate(o.arguments):
|
|
var = arg_.variable
|
|
kind = arg_.kind
|
|
name = var.name
|
|
annotated_type = (
|
|
o.unanalyzed_type.arg_types[i]
|
|
if isinstance(o.unanalyzed_type, CallableType)
|
|
else None
|
|
)
|
|
# I think the name check is incorrect: there are libraries which
|
|
# name their 0th argument other than self/cls
|
|
is_self_arg = i == 0 and name == "self"
|
|
is_cls_arg = i == 0 and name == "cls"
|
|
annotation = ""
|
|
if annotated_type and not is_self_arg and not is_cls_arg:
|
|
# Luckily, an argument explicitly annotated with "Any" has
|
|
# type "UnboundType" and will not match.
|
|
if not isinstance(get_proper_type(annotated_type), AnyType):
|
|
annotation = f": {self.print_annotation(annotated_type)}"
|
|
|
|
if kind.is_named() and not any(arg.startswith("*") for arg in args):
|
|
args.append("*")
|
|
|
|
if arg_.initializer:
|
|
if not annotation:
|
|
typename = self.get_str_type_of_node(arg_.initializer, True, False)
|
|
if typename == "":
|
|
annotation = "=..."
|
|
else:
|
|
annotation = f": {typename} = ..."
|
|
else:
|
|
annotation += " = ..."
|
|
arg = name + annotation
|
|
elif kind == ARG_STAR:
|
|
arg = f"*{name}{annotation}"
|
|
elif kind == ARG_STAR2:
|
|
arg = f"**{name}{annotation}"
|
|
else:
|
|
arg = name + annotation
|
|
args.append(arg)
|
|
retname = None
|
|
if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
|
|
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
|
|
# Luckily, a return type explicitly annotated with "Any" has
|
|
# type "UnboundType" and will enter the else branch.
|
|
retname = None # implicit Any
|
|
else:
|
|
retname = self.print_annotation(o.unanalyzed_type.ret_type)
|
|
elif o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE:
|
|
# Always assume abstract methods return Any unless explicitly annotated. Also
|
|
# some dunder methods should not have a None return type.
|
|
retname = None # implicit Any
|
|
elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES:
|
|
retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name]
|
|
elif has_yield_expression(o) or has_yield_from_expression(o):
|
|
generator_name = self.add_typing_import("Generator")
|
|
yield_name = "None"
|
|
send_name = "None"
|
|
return_name = "None"
|
|
if has_yield_from_expression(o):
|
|
yield_name = send_name = self.add_typing_import("Incomplete")
|
|
else:
|
|
for expr, in_assignment in all_yield_expressions(o):
|
|
if expr.expr is not None and not self.is_none_expr(expr.expr):
|
|
yield_name = self.add_typing_import("Incomplete")
|
|
if in_assignment:
|
|
send_name = self.add_typing_import("Incomplete")
|
|
if has_return_statement(o):
|
|
return_name = self.add_typing_import("Incomplete")
|
|
retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]"
|
|
elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
|
|
retname = "None"
|
|
retfield = ""
|
|
if retname is not None:
|
|
retfield = " -> " + retname
|
|
|
|
self.add(", ".join(args))
|
|
self.add(f"){retfield}:")
|
|
if self._include_docstrings and o.docstring:
|
|
docstring = mypy.util.quote_docstring(o.docstring)
|
|
self.add(f"\n{self._indent} {docstring}\n")
|
|
else:
|
|
self.add(" ...\n")
|
|
|
|
self._state = FUNC
|
|
|
|
def is_none_expr(self, expr: Expression) -> bool:
|
|
return isinstance(expr, NameExpr) and expr.name == "None"
|
|
|
|
def visit_decorator(self, o: Decorator) -> None:
|
|
if self.is_private_name(o.func.name, o.func.fullname):
|
|
return
|
|
self.process_decorator(o)
|
|
self.visit_func_def(o.func)
|
|
|
|
def process_decorator(self, o: Decorator) -> None:
|
|
"""Process a series of decorators.
|
|
|
|
Only preserve certain special decorators such as @abstractmethod.
|
|
"""
|
|
for decorator in o.original_decorators:
|
|
if not isinstance(decorator, (NameExpr, MemberExpr)):
|
|
continue
|
|
qualname = get_qualified_name(decorator)
|
|
fullname = self.get_fullname(decorator)
|
|
if fullname in (
|
|
"builtins.property",
|
|
"builtins.staticmethod",
|
|
"builtins.classmethod",
|
|
"functools.cached_property",
|
|
):
|
|
self.add_decorator(qualname, require_name=True)
|
|
elif fullname in (
|
|
"asyncio.coroutine",
|
|
"asyncio.coroutines.coroutine",
|
|
"types.coroutine",
|
|
):
|
|
o.func.is_awaitable_coroutine = True
|
|
self.add_decorator(qualname, require_name=True)
|
|
elif fullname == "abc.abstractmethod":
|
|
self.add_decorator(qualname, require_name=True)
|
|
o.func.abstract_status = IS_ABSTRACT
|
|
elif fullname in (
|
|
"abc.abstractproperty",
|
|
"abc.abstractstaticmethod",
|
|
"abc.abstractclassmethod",
|
|
):
|
|
abc_module = qualname.rpartition(".")[0]
|
|
if not abc_module:
|
|
self.import_tracker.add_import("abc")
|
|
builtin_decorator_replacement = fullname[len("abc.abstract") :]
|
|
self.add_decorator(builtin_decorator_replacement, require_name=False)
|
|
self.add_decorator(f"{abc_module or 'abc'}.abstractmethod", require_name=True)
|
|
o.func.abstract_status = IS_ABSTRACT
|
|
elif fullname in OVERLOAD_NAMES:
|
|
self.add_decorator(qualname, require_name=True)
|
|
o.func.is_overload = True
|
|
elif qualname.endswith(".setter"):
|
|
self.add_decorator(qualname, require_name=False)
|
|
|
|
def get_fullname(self, expr: Expression) -> str:
|
|
"""Return the full name resolving imports and import aliases."""
|
|
if (
|
|
self.analyzed
|
|
and isinstance(expr, (NameExpr, MemberExpr))
|
|
and expr.fullname
|
|
and not (isinstance(expr.node, Var) and expr.node.is_suppressed_import)
|
|
):
|
|
return expr.fullname
|
|
name = get_qualified_name(expr)
|
|
if "." not in name:
|
|
real_module = self.import_tracker.module_for.get(name)
|
|
real_short = self.import_tracker.reverse_alias.get(name, name)
|
|
if real_module is None and real_short not in self.defined_names:
|
|
real_module = "builtins" # not imported and not defined, must be a builtin
|
|
else:
|
|
name_module, real_short = name.split(".", 1)
|
|
real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
|
|
resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
|
|
return resolved_name
|
|
|
|
def visit_class_def(self, o: ClassDef) -> None:
|
|
self._current_class = o
|
|
self.method_names = find_method_names(o.defs.body)
|
|
sep: int | None = None
|
|
if not self._indent and self._state != EMPTY:
|
|
sep = len(self._output)
|
|
self.add("\n")
|
|
self.add(f"{self._indent}class {o.name}")
|
|
self.record_name(o.name)
|
|
base_types = self.get_base_types(o)
|
|
if base_types:
|
|
for base in base_types:
|
|
self.import_tracker.require_name(base)
|
|
if isinstance(o.metaclass, (NameExpr, MemberExpr)):
|
|
meta = o.metaclass.accept(AliasPrinter(self))
|
|
base_types.append("metaclass=" + meta)
|
|
elif self.analyzed and o.info.is_abstract and not o.info.is_protocol:
|
|
base_types.append("metaclass=abc.ABCMeta")
|
|
self.import_tracker.add_import("abc")
|
|
self.import_tracker.require_name("abc")
|
|
if base_types:
|
|
self.add(f"({', '.join(base_types)})")
|
|
self.add(":\n")
|
|
self._indent += " "
|
|
if self._include_docstrings and o.docstring:
|
|
docstring = mypy.util.quote_docstring(o.docstring)
|
|
self.add(f"{self._indent}{docstring}\n")
|
|
n = len(self._output)
|
|
self._vars.append([])
|
|
super().visit_class_def(o)
|
|
self._indent = self._indent[:-4]
|
|
self._vars.pop()
|
|
self._vars[-1].append(o.name)
|
|
if len(self._output) == n:
|
|
if self._state == EMPTY_CLASS and sep is not None:
|
|
self._output[sep] = ""
|
|
if not (self._include_docstrings and o.docstring):
|
|
self._output[-1] = self._output[-1][:-1] + " ...\n"
|
|
self._state = EMPTY_CLASS
|
|
else:
|
|
self._state = CLASS
|
|
self.method_names = set()
|
|
self._current_class = None
|
|
|
|
def get_base_types(self, cdef: ClassDef) -> list[str]:
|
|
"""Get list of base classes for a class."""
|
|
base_types: list[str] = []
|
|
p = AliasPrinter(self)
|
|
for base in cdef.base_type_exprs + cdef.removed_base_type_exprs:
|
|
if isinstance(base, (NameExpr, MemberExpr)):
|
|
if self.get_fullname(base) != "builtins.object":
|
|
base_types.append(get_qualified_name(base))
|
|
elif isinstance(base, IndexExpr):
|
|
base_types.append(base.accept(p))
|
|
elif isinstance(base, CallExpr):
|
|
# namedtuple(typename, fields), NamedTuple(typename, fields) calls can
|
|
# be used as a base class. The first argument is a string literal that
|
|
# is usually the same as the class name.
|
|
#
|
|
# Note:
|
|
# A call-based named tuple as a base class cannot be safely converted to
|
|
# a class-based NamedTuple definition because class attributes defined
|
|
# in the body of the class inheriting from the named tuple call are not
|
|
# namedtuple fields at runtime.
|
|
if self.is_namedtuple(base):
|
|
nt_fields = self._get_namedtuple_fields(base)
|
|
assert isinstance(base.args[0], StrExpr)
|
|
typename = base.args[0].value
|
|
if nt_fields is None:
|
|
# Invalid namedtuple() call, cannot determine fields
|
|
base_types.append(self.add_typing_import("Incomplete"))
|
|
continue
|
|
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
|
|
namedtuple_name = self.add_typing_import("NamedTuple")
|
|
base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])")
|
|
elif self.is_typed_namedtuple(base):
|
|
base_types.append(base.accept(p))
|
|
else:
|
|
# At this point, we don't know what the base class is, so we
|
|
# just use Incomplete as the base class.
|
|
base_types.append(self.add_typing_import("Incomplete"))
|
|
for name, value in cdef.keywords.items():
|
|
if name == "metaclass":
|
|
continue # handled separately
|
|
base_types.append(f"{name}={value.accept(p)}")
|
|
return base_types
|
|
|
|
def visit_block(self, o: Block) -> None:
|
|
# Unreachable statements may be partially uninitialized and that may
|
|
# cause trouble.
|
|
if not o.is_unreachable:
|
|
super().visit_block(o)
|
|
|
|
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
|
|
foundl = []
|
|
|
|
for lvalue in o.lvalues:
|
|
if isinstance(lvalue, NameExpr) and isinstance(o.rvalue, CallExpr):
|
|
if self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue):
|
|
self.process_namedtuple(lvalue, o.rvalue)
|
|
foundl.append(False) # state is updated in process_namedtuple
|
|
continue
|
|
if self.is_typeddict(o.rvalue):
|
|
self.process_typeddict(lvalue, o.rvalue)
|
|
foundl.append(False) # state is updated in process_typeddict
|
|
continue
|
|
if (
|
|
isinstance(lvalue, NameExpr)
|
|
and not self.is_private_name(lvalue.name)
|
|
# it is never an alias with explicit annotation
|
|
and not o.unanalyzed_type
|
|
and self.is_alias_expression(o.rvalue)
|
|
):
|
|
self.process_typealias(lvalue, o.rvalue)
|
|
continue
|
|
if isinstance(lvalue, (TupleExpr, ListExpr)):
|
|
items = lvalue.items
|
|
if isinstance(o.unanalyzed_type, TupleType): # type: ignore[misc]
|
|
annotations: Iterable[Type | None] = o.unanalyzed_type.items
|
|
else:
|
|
annotations = [None] * len(items)
|
|
else:
|
|
items = [lvalue]
|
|
annotations = [o.unanalyzed_type]
|
|
sep = False
|
|
found = False
|
|
for item, annotation in zip(items, annotations):
|
|
if isinstance(item, NameExpr):
|
|
init = self.get_init(item.name, o.rvalue, annotation)
|
|
if init:
|
|
found = True
|
|
if not sep and not self._indent and self._state not in (EMPTY, VAR):
|
|
init = "\n" + init
|
|
sep = True
|
|
self.add(init)
|
|
self.record_name(item.name)
|
|
foundl.append(found)
|
|
|
|
if all(foundl):
|
|
self._state = VAR
|
|
|
|
def is_namedtuple(self, expr: CallExpr) -> bool:
|
|
return self.get_fullname(expr.callee) == "collections.namedtuple"
|
|
|
|
def is_typed_namedtuple(self, expr: CallExpr) -> bool:
|
|
return self.get_fullname(expr.callee) in TYPED_NAMEDTUPLE_NAMES
|
|
|
|
def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None:
|
|
if self.is_namedtuple(call):
|
|
fields_arg = call.args[1]
|
|
if isinstance(fields_arg, StrExpr):
|
|
field_names = fields_arg.value.replace(",", " ").split()
|
|
elif isinstance(fields_arg, (ListExpr, TupleExpr)):
|
|
field_names = []
|
|
for field in fields_arg.items:
|
|
if not isinstance(field, StrExpr):
|
|
return None
|
|
field_names.append(field.value)
|
|
else:
|
|
return None # Invalid namedtuple fields type
|
|
if not field_names:
|
|
return []
|
|
incomplete = self.add_typing_import("Incomplete")
|
|
return [(field_name, incomplete) for field_name in field_names]
|
|
elif self.is_typed_namedtuple(call):
|
|
fields_arg = call.args[1]
|
|
if not isinstance(fields_arg, (ListExpr, TupleExpr)):
|
|
return None
|
|
fields: list[tuple[str, str]] = []
|
|
p = AliasPrinter(self)
|
|
for field in fields_arg.items:
|
|
if not (isinstance(field, TupleExpr) and len(field.items) == 2):
|
|
return None
|
|
field_name, field_type = field.items
|
|
if not isinstance(field_name, StrExpr):
|
|
return None
|
|
fields.append((field_name.value, field_type.accept(p)))
|
|
return fields
|
|
else:
|
|
return None # Not a named tuple call
|
|
|
|
def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
|
|
if self._state == CLASS:
|
|
self.add("\n")
|
|
|
|
if not isinstance(rvalue.args[0], StrExpr):
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
|
|
fields = self._get_namedtuple_fields(rvalue)
|
|
if fields is None:
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
bases = self.add_typing_import("NamedTuple")
|
|
# TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
|
|
class_def = f"{self._indent}class {lvalue.name}({bases}):"
|
|
if len(fields) == 0:
|
|
self.add(f"{class_def} ...\n")
|
|
self._state = EMPTY_CLASS
|
|
else:
|
|
if self._state not in (EMPTY, CLASS):
|
|
self.add("\n")
|
|
self.add(f"{class_def}\n")
|
|
for f_name, f_type in fields:
|
|
self.add(f"{self._indent} {f_name}: {f_type}\n")
|
|
self._state = CLASS
|
|
|
|
def is_typeddict(self, expr: CallExpr) -> bool:
|
|
return self.get_fullname(expr.callee) in TPDICT_NAMES
|
|
|
|
def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
|
|
if self._state == CLASS:
|
|
self.add("\n")
|
|
|
|
if not isinstance(rvalue.args[0], StrExpr):
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
|
|
items: list[tuple[str, Expression]] = []
|
|
total: Expression | None = None
|
|
if len(rvalue.args) > 1 and rvalue.arg_kinds[1] == ARG_POS:
|
|
if not isinstance(rvalue.args[1], DictExpr):
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
for attr_name, attr_type in rvalue.args[1].items:
|
|
if not isinstance(attr_name, StrExpr):
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
items.append((attr_name.value, attr_type))
|
|
if len(rvalue.args) > 2:
|
|
if rvalue.arg_kinds[2] != ARG_NAMED or rvalue.arg_names[2] != "total":
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
total = rvalue.args[2]
|
|
else:
|
|
for arg_name, arg in zip(rvalue.arg_names[1:], rvalue.args[1:]):
|
|
if not isinstance(arg_name, str):
|
|
self.annotate_as_incomplete(lvalue)
|
|
return
|
|
if arg_name == "total":
|
|
total = arg
|
|
else:
|
|
items.append((arg_name, arg))
|
|
bases = self.add_typing_import("TypedDict")
|
|
p = AliasPrinter(self)
|
|
if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items):
|
|
# Keep the call syntax if there are non-identifier or reserved keyword keys.
|
|
self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n")
|
|
self._state = VAR
|
|
else:
|
|
# TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
|
|
if total is not None:
|
|
bases += f", total={total.accept(p)}"
|
|
class_def = f"{self._indent}class {lvalue.name}({bases}):"
|
|
if len(items) == 0:
|
|
self.add(f"{class_def} ...\n")
|
|
self._state = EMPTY_CLASS
|
|
else:
|
|
if self._state not in (EMPTY, CLASS):
|
|
self.add("\n")
|
|
self.add(f"{class_def}\n")
|
|
for key, key_type in items:
|
|
self.add(f"{self._indent} {key}: {key_type.accept(p)}\n")
|
|
self._state = CLASS
|
|
|
|
def annotate_as_incomplete(self, lvalue: NameExpr) -> None:
|
|
self.add(f"{self._indent}{lvalue.name}: {self.add_typing_import('Incomplete')}\n")
|
|
self._state = VAR
|
|
|
|
def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
|
|
"""Return True for things that look like target for an alias.
|
|
|
|
Used to know if assignments look like type aliases, function alias,
|
|
or module alias.
|
|
"""
|
|
# Assignment of TypeVar(...) and other typevar-likes are passed through
|
|
if isinstance(expr, CallExpr) and self.get_fullname(expr.callee) in (
|
|
"typing.TypeVar",
|
|
"typing_extensions.TypeVar",
|
|
"typing.ParamSpec",
|
|
"typing_extensions.ParamSpec",
|
|
"typing.TypeVarTuple",
|
|
"typing_extensions.TypeVarTuple",
|
|
):
|
|
return True
|
|
elif isinstance(expr, EllipsisExpr):
|
|
return not top_level
|
|
elif isinstance(expr, NameExpr):
|
|
if expr.name in ("True", "False"):
|
|
return False
|
|
elif expr.name == "None":
|
|
return not top_level
|
|
else:
|
|
return not self.is_private_name(expr.name)
|
|
elif isinstance(expr, MemberExpr) and self.analyzed:
|
|
# Also add function and module aliases.
|
|
return (
|
|
top_level
|
|
and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
|
|
or isinstance(expr.node, TypeInfo)
|
|
) and not self.is_private_member(expr.node.fullname)
|
|
elif (
|
|
isinstance(expr, IndexExpr)
|
|
and isinstance(expr.base, NameExpr)
|
|
and not self.is_private_name(expr.base.name)
|
|
):
|
|
if isinstance(expr.index, TupleExpr):
|
|
indices = expr.index.items
|
|
else:
|
|
indices = [expr.index]
|
|
if expr.base.name == "Callable" and len(indices) == 2:
|
|
args, ret = indices
|
|
if isinstance(args, EllipsisExpr):
|
|
indices = [ret]
|
|
elif isinstance(args, ListExpr):
|
|
indices = args.items + [ret]
|
|
else:
|
|
return False
|
|
return all(self.is_alias_expression(i, top_level=False) for i in indices)
|
|
else:
|
|
return False
|
|
|
|
def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None:
|
|
p = AliasPrinter(self)
|
|
self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n")
|
|
self.record_name(lvalue.name)
|
|
self._vars[-1].append(lvalue.name)
|
|
|
|
def visit_if_stmt(self, o: IfStmt) -> None:
|
|
# Ignore if __name__ == '__main__'.
|
|
expr = o.expr[0]
|
|
if (
|
|
isinstance(expr, ComparisonExpr)
|
|
and isinstance(expr.operands[0], NameExpr)
|
|
and isinstance(expr.operands[1], StrExpr)
|
|
and expr.operands[0].name == "__name__"
|
|
and "__main__" in expr.operands[1].value
|
|
):
|
|
return
|
|
super().visit_if_stmt(o)
|
|
|
|
def visit_import_all(self, o: ImportAll) -> None:
|
|
self.add_import_line(f"from {'.' * o.relative}{o.id} import *\n")
|
|
|
|
def visit_import_from(self, o: ImportFrom) -> None:
|
|
exported_names: set[str] = set()
|
|
import_names = []
|
|
module, relative = translate_module_name(o.id, o.relative)
|
|
if self.module:
|
|
full_module, ok = mypy.util.correct_relative_import(
|
|
self.module, relative, module, self.path.endswith(".__init__.py")
|
|
)
|
|
if not ok:
|
|
full_module = module
|
|
else:
|
|
full_module = module
|
|
if module == "__future__":
|
|
return # Not preserved
|
|
for name, as_name in o.names:
|
|
if name == "six":
|
|
# Vendored six -- translate into plain 'import six'.
|
|
self.visit_import(Import([("six", None)]))
|
|
continue
|
|
exported = False
|
|
if as_name is None and self.module and (self.module + "." + name) in EXTRA_EXPORTED:
|
|
# Special case certain names that should be exported, against our general rules.
|
|
exported = True
|
|
is_private = self.is_private_name(name, full_module + "." + name)
|
|
if (
|
|
as_name is None
|
|
and name not in self.referenced_names
|
|
and (not self._all_ or name in IGNORED_DUNDERS)
|
|
and not is_private
|
|
and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES
|
|
):
|
|
# An imported name that is never referenced in the module is assumed to be
|
|
# exported, unless there is an explicit __all__. Note that we need to special
|
|
# case 'abc' since some references are deleted during semantic analysis.
|
|
exported = True
|
|
top_level = full_module.split(".")[0]
|
|
if (
|
|
as_name is None
|
|
and not self.export_less
|
|
and (not self._all_ or name in IGNORED_DUNDERS)
|
|
and self.module
|
|
and not is_private
|
|
and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0])
|
|
):
|
|
# Export imports from the same package, since we can't reliably tell whether they
|
|
# are part of the public API.
|
|
exported = True
|
|
if exported:
|
|
self.import_tracker.reexport(name)
|
|
as_name = name
|
|
import_names.append((name, as_name))
|
|
self.import_tracker.add_import_from("." * relative + module, import_names)
|
|
self._vars[-1].extend(alias or name for name, alias in import_names)
|
|
for name, alias in import_names:
|
|
self.record_name(alias or name)
|
|
|
|
if self._all_:
|
|
# Include "import from"s that import names defined in __all__.
|
|
names = [
|
|
name
|
|
for name, alias in o.names
|
|
if name in self._all_ and alias is None and name not in IGNORED_DUNDERS
|
|
]
|
|
exported_names.update(names)
|
|
|
|
def visit_import(self, o: Import) -> None:
|
|
for id, as_id in o.ids:
|
|
self.import_tracker.add_import(id, as_id)
|
|
if as_id is None:
|
|
target_name = id.split(".")[0]
|
|
else:
|
|
target_name = as_id
|
|
self._vars[-1].append(target_name)
|
|
self.record_name(target_name)
|
|
|
|
def get_init(
|
|
self, lvalue: str, rvalue: Expression, annotation: Type | None = None
|
|
) -> str | None:
|
|
"""Return initializer for a variable.
|
|
|
|
Return None if we've generated one already or if the variable is internal.
|
|
"""
|
|
if lvalue in self._vars[-1]:
|
|
# We've generated an initializer already for this variable.
|
|
return None
|
|
# TODO: Only do this at module top level.
|
|
if self.is_private_name(lvalue) or self.is_not_in_all(lvalue):
|
|
return None
|
|
self._vars[-1].append(lvalue)
|
|
if annotation is not None:
|
|
typename = self.print_annotation(annotation)
|
|
if (
|
|
isinstance(annotation, UnboundType)
|
|
and not annotation.args
|
|
and annotation.name == "Final"
|
|
and self.import_tracker.module_for.get("Final") in TYPING_MODULE_NAMES
|
|
):
|
|
# Final without type argument is invalid in stubs.
|
|
final_arg = self.get_str_type_of_node(rvalue)
|
|
typename += f"[{final_arg}]"
|
|
else:
|
|
typename = self.get_str_type_of_node(rvalue)
|
|
initializer = self.get_assign_initializer(rvalue)
|
|
return f"{self._indent}{lvalue}: {typename}{initializer}\n"
|
|
|
|
def get_assign_initializer(self, rvalue: Expression) -> str:
|
|
"""Does this rvalue need some special initializer value?"""
|
|
if self._current_class and self._current_class.info:
|
|
# Current rules
|
|
# 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
|
|
if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode):
|
|
return " = ..."
|
|
# TODO: support other possible cases, where initializer is important
|
|
|
|
# By default, no initializer is required:
|
|
return ""
|
|
|
|
def add(self, string: str) -> None:
|
|
"""Add text to generated stub."""
|
|
self._output.append(string)
|
|
|
|
def add_decorator(self, name: str, require_name: bool = False) -> None:
|
|
if require_name:
|
|
self.import_tracker.require_name(name)
|
|
if not self._indent and self._state not in (EMPTY, FUNC):
|
|
self._decorators.append("\n")
|
|
self._decorators.append(f"{self._indent}@{name}\n")
|
|
|
|
def clear_decorators(self) -> None:
|
|
self._decorators.clear()
|
|
|
|
def typing_name(self, name: str) -> str:
|
|
if name in self.defined_names:
|
|
# Avoid name clash between name from typing and a name defined in stub.
|
|
return "_" + name
|
|
else:
|
|
return name
|
|
|
|
def add_typing_import(self, name: str) -> str:
|
|
"""Add a name to be imported for typing, unless it's imported already.
|
|
|
|
The import will be internal to the stub.
|
|
"""
|
|
name = self.typing_name(name)
|
|
self.import_tracker.require_name(name)
|
|
return name
|
|
|
|
def add_import_line(self, line: str) -> None:
|
|
"""Add a line of text to the import section, unless it's already there."""
|
|
if line not in self._import_lines:
|
|
self._import_lines.append(line)
|
|
|
|
def output(self) -> str:
|
|
"""Return the text for the stub."""
|
|
imports = ""
|
|
if self._import_lines:
|
|
imports += "".join(self._import_lines)
|
|
imports += "".join(self.import_tracker.import_lines())
|
|
if imports and self._output:
|
|
imports += "\n"
|
|
return imports + "".join(self._output)
|
|
|
|
def is_not_in_all(self, name: str) -> bool:
|
|
if self.is_private_name(name):
|
|
return False
|
|
if self._all_:
|
|
return self.is_top_level() and name not in self._all_
|
|
return False
|
|
|
|
def is_private_name(self, name: str, fullname: str | None = None) -> bool:
|
|
if self._include_private:
|
|
return False
|
|
if fullname in EXTRA_EXPORTED:
|
|
return False
|
|
return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)
|
|
|
|
def is_private_member(self, fullname: str) -> bool:
|
|
parts = fullname.split(".")
|
|
return any(self.is_private_name(part) for part in parts)
|
|
|
|
def get_str_type_of_node(
|
|
self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True
|
|
) -> str:
|
|
rvalue = self.maybe_unwrap_unary_expr(rvalue)
|
|
|
|
if isinstance(rvalue, IntExpr):
|
|
return "int"
|
|
if isinstance(rvalue, StrExpr):
|
|
return "str"
|
|
if isinstance(rvalue, BytesExpr):
|
|
return "bytes"
|
|
if isinstance(rvalue, FloatExpr):
|
|
return "float"
|
|
if isinstance(rvalue, ComplexExpr): # 1j
|
|
return "complex"
|
|
if isinstance(rvalue, OpExpr) and rvalue.op in ("-", "+"): # -1j + 1
|
|
if isinstance(self.maybe_unwrap_unary_expr(rvalue.left), ComplexExpr) or isinstance(
|
|
self.maybe_unwrap_unary_expr(rvalue.right), ComplexExpr
|
|
):
|
|
return "complex"
|
|
if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"):
|
|
return "bool"
|
|
if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None":
|
|
return f"{self.add_typing_import('Incomplete')} | None"
|
|
if can_be_any:
|
|
return self.add_typing_import("Incomplete")
|
|
else:
|
|
return ""
|
|
|
|
def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
|
|
"""Unwrap (possibly nested) unary expressions.
|
|
|
|
But, some unary expressions can change the type of expression.
|
|
While we want to preserve it. For example, `~True` is `int`.
|
|
So, we only allow a subset of unary expressions to be unwrapped.
|
|
"""
|
|
if not isinstance(expr, UnaryExpr):
|
|
return expr
|
|
|
|
# First, try to unwrap `[+-]+ (int|float|complex)` expr:
|
|
math_ops = ("+", "-")
|
|
if expr.op in math_ops:
|
|
while isinstance(expr, UnaryExpr):
|
|
if expr.op not in math_ops or not isinstance(
|
|
expr.expr, (IntExpr, FloatExpr, ComplexExpr, UnaryExpr)
|
|
):
|
|
break
|
|
expr = expr.expr
|
|
return expr
|
|
|
|
# Next, try `not bool` expr:
|
|
if expr.op == "not":
|
|
while isinstance(expr, UnaryExpr):
|
|
if expr.op != "not" or not isinstance(expr.expr, (NameExpr, UnaryExpr)):
|
|
break
|
|
if isinstance(expr.expr, NameExpr) and expr.expr.name not in ("True", "False"):
|
|
break
|
|
expr = expr.expr
|
|
return expr
|
|
|
|
# This is some other unary expr, we cannot do anything with it (yet?).
|
|
return expr
|
|
|
|
def print_annotation(self, t: Type) -> str:
|
|
printer = AnnotationPrinter(self)
|
|
return t.accept(printer)
|
|
|
|
def is_top_level(self) -> bool:
|
|
"""Are we processing the top level of a file?"""
|
|
return self._indent == ""
|
|
|
|
def record_name(self, name: str) -> None:
|
|
"""Mark a name as defined.
|
|
|
|
This only does anything if at the top level of a module.
|
|
"""
|
|
if self.is_top_level():
|
|
self._toplevel_names.append(name)
|
|
|
|
def is_recorded_name(self, name: str) -> bool:
|
|
"""Has this name been recorded previously?"""
|
|
return self.is_top_level() and name in self._toplevel_names
|
|
|
|
|
|
def find_method_names(defs: list[Statement]) -> set[str]:
|
|
# TODO: Traverse into nested definitions
|
|
result = set()
|
|
for defn in defs:
|
|
if isinstance(defn, FuncDef):
|
|
result.add(defn.name)
|
|
elif isinstance(defn, Decorator):
|
|
result.add(defn.func.name)
|
|
elif isinstance(defn, OverloadedFuncDef):
|
|
for item in defn.items:
|
|
result.update(find_method_names([item]))
|
|
return result
|
|
|
|
|
|
class SelfTraverser(mypy.traverser.TraverserVisitor):
|
|
def __init__(self) -> None:
|
|
self.results: list[tuple[str, Expression]] = []
|
|
|
|
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
|
|
lvalue = o.lvalues[0]
|
|
if (
|
|
isinstance(lvalue, MemberExpr)
|
|
and isinstance(lvalue.expr, NameExpr)
|
|
and lvalue.expr.name == "self"
|
|
):
|
|
self.results.append((lvalue.name, o.rvalue))
|
|
|
|
|
|
def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression]]:
|
|
"""Find attribute initializers in a method.
|
|
|
|
Return a list of pairs (attribute name, r.h.s. expression).
|
|
"""
|
|
traverser = SelfTraverser()
|
|
fdef.accept(traverser)
|
|
return traverser.results
|
|
|
|
|
|
def get_qualified_name(o: Expression) -> str:
|
|
if isinstance(o, NameExpr):
|
|
return o.name
|
|
elif isinstance(o, MemberExpr):
|
|
return f"{get_qualified_name(o.expr)}.{o.name}"
|
|
else:
|
|
return ERROR_MARKER
|
|
|
|
|
|
def remove_blacklisted_modules(modules: list[StubSource]) -> list[StubSource]:
|
|
return [
|
|
module for module in modules if module.path is None or not is_blacklisted_path(module.path)
|
|
]
|
|
|
|
|
|
def is_blacklisted_path(path: str) -> bool:
|
|
return any(substr in (normalize_path_separators(path) + "\n") for substr in BLACKLIST)
|
|
|
|
|
|
def normalize_path_separators(path: str) -> str:
|
|
if sys.platform == "win32":
|
|
return path.replace("\\", "/")
|
|
return path
|
|
|
|
|
|
def collect_build_targets(
|
|
options: Options, mypy_opts: MypyOptions
|
|
) -> tuple[list[StubSource], list[StubSource]]:
|
|
"""Collect files for which we need to generate stubs.
|
|
|
|
Return list of Python modules and C modules.
|
|
"""
|
|
if options.packages or options.modules:
|
|
if options.no_import:
|
|
py_modules = find_module_paths_using_search(
|
|
options.modules, options.packages, options.search_path, options.pyversion
|
|
)
|
|
c_modules: list[StubSource] = []
|
|
else:
|
|
# Using imports is the default, since we can also find C modules.
|
|
py_modules, c_modules = find_module_paths_using_imports(
|
|
options.modules, options.packages, options.verbose, options.quiet
|
|
)
|
|
else:
|
|
# Use mypy native source collection for files and directories.
|
|
try:
|
|
source_list = create_source_list(options.files, mypy_opts)
|
|
except InvalidSourceList as e:
|
|
raise SystemExit(str(e)) from e
|
|
py_modules = [StubSource(m.module, m.path) for m in source_list]
|
|
c_modules = []
|
|
|
|
py_modules = remove_blacklisted_modules(py_modules)
|
|
|
|
return py_modules, c_modules
|
|
|
|
|
|
def find_module_paths_using_imports(
|
|
modules: list[str], packages: list[str], verbose: bool, quiet: bool
|
|
) -> tuple[list[StubSource], list[StubSource]]:
|
|
"""Find path and runtime value of __all__ (if possible) for modules and packages.
|
|
|
|
This function uses runtime Python imports to get the information.
|
|
"""
|
|
with ModuleInspect() as inspect:
|
|
py_modules: list[StubSource] = []
|
|
c_modules: list[StubSource] = []
|
|
found = list(walk_packages(inspect, packages, verbose))
|
|
modules = modules + found
|
|
modules = [
|
|
mod for mod in modules if not is_non_library_module(mod)
|
|
] # We don't want to run any tests or scripts
|
|
for mod in modules:
|
|
try:
|
|
result = find_module_path_and_all_py3(inspect, mod, verbose)
|
|
except CantImport as e:
|
|
tb = traceback.format_exc()
|
|
if verbose:
|
|
sys.stdout.write(tb)
|
|
if not quiet:
|
|
report_missing(mod, e.message, tb)
|
|
continue
|
|
if not result:
|
|
c_modules.append(StubSource(mod))
|
|
else:
|
|
path, runtime_all = result
|
|
py_modules.append(StubSource(mod, path, runtime_all))
|
|
return py_modules, c_modules
|
|
|
|
|
|
def is_non_library_module(module: str) -> bool:
|
|
"""Does module look like a test module or a script?"""
|
|
if module.endswith(
|
|
(
|
|
".tests",
|
|
".test",
|
|
".testing",
|
|
"_tests",
|
|
"_test_suite",
|
|
"test_util",
|
|
"test_utils",
|
|
"test_base",
|
|
".__main__",
|
|
".conftest", # Used by pytest
|
|
".setup", # Typically an install script
|
|
)
|
|
):
|
|
return True
|
|
if module.split(".")[-1].startswith("test_"):
|
|
return True
|
|
if (
|
|
".tests." in module
|
|
or ".test." in module
|
|
or ".testing." in module
|
|
or ".SelfTest." in module
|
|
):
|
|
return True
|
|
return False
|
|
|
|
|
|
def translate_module_name(module: str, relative: int) -> tuple[str, int]:
|
|
for pkg in VENDOR_PACKAGES:
|
|
for alt in "six.moves", "six":
|
|
substr = f"{pkg}.{alt}"
|
|
if module.endswith("." + substr) or (module == substr and relative):
|
|
return alt, 0
|
|
if "." + substr + "." in module:
|
|
return alt + "." + module.partition("." + substr + ".")[2], 0
|
|
return module, relative
|
|
|
|
|
|
def find_module_paths_using_search(
|
|
modules: list[str], packages: list[str], search_path: list[str], pyversion: tuple[int, int]
|
|
) -> list[StubSource]:
|
|
"""Find sources for modules and packages requested.
|
|
|
|
This function just looks for source files at the file system level.
|
|
This is used if user passes --no-import, and will not find C modules.
|
|
Exit if some of the modules or packages can't be found.
|
|
"""
|
|
result: list[StubSource] = []
|
|
typeshed_path = default_lib_path(mypy.build.default_data_dir(), pyversion, None)
|
|
search_paths = SearchPaths((".",) + tuple(search_path), (), (), tuple(typeshed_path))
|
|
cache = FindModuleCache(search_paths, fscache=None, options=None)
|
|
for module in modules:
|
|
m_result = cache.find_module(module)
|
|
if isinstance(m_result, ModuleNotFoundReason):
|
|
fail_missing(module, m_result)
|
|
module_path = None
|
|
else:
|
|
module_path = m_result
|
|
result.append(StubSource(module, module_path))
|
|
for package in packages:
|
|
p_result = cache.find_modules_recursive(package)
|
|
if p_result:
|
|
fail_missing(package, ModuleNotFoundReason.NOT_FOUND)
|
|
sources = [StubSource(m.module, m.path) for m in p_result]
|
|
result.extend(sources)
|
|
|
|
result = [m for m in result if not is_non_library_module(m.module)]
|
|
|
|
return result
|
|
|
|
|
|
def mypy_options(stubgen_options: Options) -> MypyOptions:
|
|
"""Generate mypy options using the flag passed by user."""
|
|
options = MypyOptions()
|
|
options.follow_imports = "skip"
|
|
options.incremental = False
|
|
options.ignore_errors = True
|
|
options.semantic_analysis_only = True
|
|
options.python_version = stubgen_options.pyversion
|
|
options.show_traceback = True
|
|
options.transform_source = remove_misplaced_type_comments
|
|
options.preserve_asts = True
|
|
options.include_docstrings = stubgen_options.include_docstrings
|
|
|
|
# Override cache_dir if provided in the environment
|
|
environ_cache_dir = os.getenv("MYPY_CACHE_DIR", "")
|
|
if environ_cache_dir.strip():
|
|
options.cache_dir = environ_cache_dir
|
|
options.cache_dir = os.path.expanduser(options.cache_dir)
|
|
|
|
return options
|
|
|
|
|
|
def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None:
|
|
"""Parse a source file.
|
|
|
|
On success, store AST in the corresponding attribute of the stub source.
|
|
If there are syntax errors, print them and exit.
|
|
"""
|
|
assert mod.path is not None, "Not found module was not skipped"
|
|
with open(mod.path, "rb") as f:
|
|
data = f.read()
|
|
source = mypy.util.decode_python_encoding(data)
|
|
errors = Errors(mypy_options)
|
|
mod.ast = mypy.parse.parse(
|
|
source, fnam=mod.path, module=mod.module, errors=errors, options=mypy_options
|
|
)
|
|
mod.ast._fullname = mod.module
|
|
if errors.is_blockers():
|
|
# Syntax error!
|
|
for m in errors.new_messages():
|
|
sys.stderr.write(f"{m}\n")
|
|
sys.exit(1)
|
|
|
|
|
|
def generate_asts_for_modules(
|
|
py_modules: list[StubSource], parse_only: bool, mypy_options: MypyOptions, verbose: bool
|
|
) -> None:
|
|
"""Use mypy to parse (and optionally analyze) source files."""
|
|
if not py_modules:
|
|
return # Nothing to do here, but there may be C modules
|
|
if verbose:
|
|
print(f"Processing {len(py_modules)} files...")
|
|
if parse_only:
|
|
for mod in py_modules:
|
|
parse_source_file(mod, mypy_options)
|
|
return
|
|
# Perform full semantic analysis of the source set.
|
|
try:
|
|
res = build([module.source for module in py_modules], mypy_options)
|
|
except CompileError as e:
|
|
raise SystemExit(f"Critical error during semantic analysis: {e}") from e
|
|
|
|
for mod in py_modules:
|
|
mod.ast = res.graph[mod.module].tree
|
|
# Use statically inferred __all__ if there is no runtime one.
|
|
if mod.runtime_all is None:
|
|
mod.runtime_all = res.manager.semantic_analyzer.export_map[mod.module]
|
|
|
|
|
|
def generate_stub_from_ast(
|
|
mod: StubSource,
|
|
target: str,
|
|
parse_only: bool = False,
|
|
include_private: bool = False,
|
|
export_less: bool = False,
|
|
include_docstrings: bool = False,
|
|
) -> None:
|
|
"""Use analysed (or just parsed) AST to generate type stub for single file.
|
|
|
|
If directory for target doesn't exist it will created. Existing stub
|
|
will be overwritten.
|
|
"""
|
|
gen = StubGenerator(
|
|
mod.runtime_all,
|
|
include_private=include_private,
|
|
analyzed=not parse_only,
|
|
export_less=export_less,
|
|
include_docstrings=include_docstrings,
|
|
)
|
|
assert mod.ast is not None, "This function must be used only with analyzed modules"
|
|
mod.ast.accept(gen)
|
|
|
|
# Write output to file.
|
|
subdir = os.path.dirname(target)
|
|
if subdir and not os.path.isdir(subdir):
|
|
os.makedirs(subdir)
|
|
with open(target, "w") as file:
|
|
file.write("".join(gen.output()))
|
|
|
|
|
|
def get_sig_generators(options: Options) -> list[SignatureGenerator]:
|
|
sig_generators: list[SignatureGenerator] = [
|
|
DocstringSignatureGenerator(),
|
|
FallbackSignatureGenerator(),
|
|
]
|
|
if options.doc_dir:
|
|
# Collect info from docs (if given). Always check these first.
|
|
sigs, class_sigs = collect_docs_signatures(options.doc_dir)
|
|
sig_generators.insert(0, ExternalSignatureGenerator(sigs, class_sigs))
|
|
return sig_generators
|
|
|
|
|
|
def collect_docs_signatures(doc_dir: str) -> tuple[dict[str, str], dict[str, str]]:
|
|
"""Gather all function and class signatures in the docs.
|
|
|
|
Return a tuple (function signatures, class signatures).
|
|
Currently only used for C modules.
|
|
"""
|
|
all_sigs: list[Sig] = []
|
|
all_class_sigs: list[Sig] = []
|
|
for path in glob.glob(f"{doc_dir}/*.rst"):
|
|
with open(path) as f:
|
|
loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines())
|
|
all_sigs += loc_sigs
|
|
all_class_sigs += loc_class_sigs
|
|
sigs = dict(find_unique_signatures(all_sigs))
|
|
class_sigs = dict(find_unique_signatures(all_class_sigs))
|
|
return sigs, class_sigs
|
|
|
|
|
|
def generate_stubs(options: Options) -> None:
|
|
"""Main entry point for the program."""
|
|
mypy_opts = mypy_options(options)
|
|
py_modules, c_modules = collect_build_targets(options, mypy_opts)
|
|
sig_generators = get_sig_generators(options)
|
|
# Use parsed sources to generate stubs for Python modules.
|
|
generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose)
|
|
files = []
|
|
for mod in py_modules:
|
|
assert mod.path is not None, "Not found module was not skipped"
|
|
target = mod.module.replace(".", "/")
|
|
if os.path.basename(mod.path) == "__init__.py":
|
|
target += "/__init__.pyi"
|
|
else:
|
|
target += ".pyi"
|
|
target = os.path.join(options.output_dir, target)
|
|
files.append(target)
|
|
with generate_guarded(mod.module, target, options.ignore_errors, options.verbose):
|
|
generate_stub_from_ast(
|
|
mod,
|
|
target,
|
|
options.parse_only,
|
|
options.include_private,
|
|
options.export_less,
|
|
include_docstrings=options.include_docstrings,
|
|
)
|
|
|
|
# Separately analyse C modules using different logic.
|
|
all_modules = sorted(m.module for m in (py_modules + c_modules))
|
|
for mod in c_modules:
|
|
if any(py_mod.module.startswith(mod.module + ".") for py_mod in py_modules + c_modules):
|
|
target = mod.module.replace(".", "/") + "/__init__.pyi"
|
|
else:
|
|
target = mod.module.replace(".", "/") + ".pyi"
|
|
target = os.path.join(options.output_dir, target)
|
|
files.append(target)
|
|
with generate_guarded(mod.module, target, options.ignore_errors, options.verbose):
|
|
generate_stub_for_c_module(
|
|
mod.module,
|
|
target,
|
|
known_modules=all_modules,
|
|
sig_generators=sig_generators,
|
|
include_docstrings=options.include_docstrings,
|
|
)
|
|
num_modules = len(py_modules) + len(c_modules)
|
|
if not options.quiet and num_modules > 0:
|
|
print("Processed %d modules" % num_modules)
|
|
if len(files) == 1:
|
|
print(f"Generated {files[0]}")
|
|
else:
|
|
print(f"Generated files under {common_dir_prefix(files)}" + os.sep)
|
|
|
|
|
|
HEADER = """%(prog)s [-h] [more options, see -h]
|
|
[-m MODULE] [-p PACKAGE] [files ...]"""
|
|
|
|
DESCRIPTION = """
|
|
Generate draft stubs for modules.
|
|
|
|
Stubs are generated in directory ./out, to avoid overriding files with
|
|
manual changes. This directory is assumed to exist.
|
|
"""
|
|
|
|
|
|
def parse_options(args: list[str]) -> Options:
|
|
parser = argparse.ArgumentParser(prog="stubgen", usage=HEADER, description=DESCRIPTION)
|
|
|
|
parser.add_argument(
|
|
"--ignore-errors",
|
|
action="store_true",
|
|
help="ignore errors when trying to generate stubs for modules",
|
|
)
|
|
parser.add_argument(
|
|
"--no-import",
|
|
action="store_true",
|
|
help="don't import the modules, just parse and analyze them "
|
|
"(doesn't work with C extension modules and might not "
|
|
"respect __all__)",
|
|
)
|
|
parser.add_argument(
|
|
"--parse-only",
|
|
action="store_true",
|
|
help="don't perform semantic analysis of sources, just parse them "
|
|
"(only applies to Python modules, might affect quality of stubs)",
|
|
)
|
|
parser.add_argument(
|
|
"--include-private",
|
|
action="store_true",
|
|
help="generate stubs for objects and members considered private "
|
|
"(single leading underscore and no trailing underscores)",
|
|
)
|
|
parser.add_argument(
|
|
"--export-less",
|
|
action="store_true",
|
|
help="don't implicitly export all names imported from other modules in the same package",
|
|
)
|
|
parser.add_argument(
|
|
"--include-docstrings",
|
|
action="store_true",
|
|
help="include existing docstrings with the stubs",
|
|
)
|
|
parser.add_argument("-v", "--verbose", action="store_true", help="show more verbose messages")
|
|
parser.add_argument("-q", "--quiet", action="store_true", help="show fewer messages")
|
|
parser.add_argument(
|
|
"--doc-dir",
|
|
metavar="PATH",
|
|
default="",
|
|
help="use .rst documentation in PATH (this may result in "
|
|
"better stubs in some cases; consider setting this to "
|
|
"DIR/Python-X.Y.Z/Doc/library)",
|
|
)
|
|
parser.add_argument(
|
|
"--search-path",
|
|
metavar="PATH",
|
|
default="",
|
|
help="specify module search directories, separated by ':' "
|
|
"(currently only used if --no-import is given)",
|
|
)
|
|
parser.add_argument(
|
|
"-o",
|
|
"--output",
|
|
metavar="PATH",
|
|
dest="output_dir",
|
|
default="out",
|
|
help="change the output directory [default: %(default)s]",
|
|
)
|
|
parser.add_argument(
|
|
"-m",
|
|
"--module",
|
|
action="append",
|
|
metavar="MODULE",
|
|
dest="modules",
|
|
default=[],
|
|
help="generate stub for module; can repeat for more modules",
|
|
)
|
|
parser.add_argument(
|
|
"-p",
|
|
"--package",
|
|
action="append",
|
|
metavar="PACKAGE",
|
|
dest="packages",
|
|
default=[],
|
|
help="generate stubs for package recursively; can be repeated",
|
|
)
|
|
parser.add_argument(
|
|
metavar="files",
|
|
nargs="*",
|
|
dest="files",
|
|
help="generate stubs for given files or directories",
|
|
)
|
|
|
|
ns = parser.parse_args(args)
|
|
|
|
pyversion = sys.version_info[:2]
|
|
ns.interpreter = sys.executable
|
|
|
|
if ns.modules + ns.packages and ns.files:
|
|
parser.error("May only specify one of: modules/packages or files.")
|
|
if ns.quiet and ns.verbose:
|
|
parser.error("Cannot specify both quiet and verbose messages")
|
|
|
|
# Create the output folder if it doesn't already exist.
|
|
if not os.path.exists(ns.output_dir):
|
|
os.makedirs(ns.output_dir)
|
|
|
|
return Options(
|
|
pyversion=pyversion,
|
|
no_import=ns.no_import,
|
|
doc_dir=ns.doc_dir,
|
|
search_path=ns.search_path.split(":"),
|
|
interpreter=ns.interpreter,
|
|
ignore_errors=ns.ignore_errors,
|
|
parse_only=ns.parse_only,
|
|
include_private=ns.include_private,
|
|
output_dir=ns.output_dir,
|
|
modules=ns.modules,
|
|
packages=ns.packages,
|
|
files=ns.files,
|
|
verbose=ns.verbose,
|
|
quiet=ns.quiet,
|
|
export_less=ns.export_less,
|
|
include_docstrings=ns.include_docstrings,
|
|
)
|
|
|
|
|
|
def main(args: list[str] | None = None) -> None:
|
|
mypy.util.check_python_version("stubgen")
|
|
# Make sure that the current directory is in sys.path so that
|
|
# stubgen can be run on packages in the current directory.
|
|
if not ("" in sys.path or "." in sys.path):
|
|
sys.path.insert(0, "")
|
|
|
|
options = parse_options(sys.argv[1:] if args is None else args)
|
|
generate_stubs(options)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|