236 lines
8.7 KiB
Python
236 lines
8.7 KiB
Python
|
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
|
||
|
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
|
||
|
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
|
||
|
|
||
|
"""Handle diagram generation options for class diagram or default diagrams."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import argparse
|
||
|
from collections.abc import Generator
|
||
|
from typing import Any
|
||
|
|
||
|
import astroid
|
||
|
from astroid import nodes
|
||
|
from astroid.modutils import is_stdlib_module
|
||
|
|
||
|
from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram
|
||
|
from pylint.pyreverse.inspector import Linker, Project
|
||
|
from pylint.pyreverse.utils import LocalsVisitor
|
||
|
|
||
|
# diagram generators ##########################################################
|
||
|
|
||
|
|
||
|
class DiaDefGenerator:
    """Base class holding the options shared by the diagram generators.

    The search depths for ancestors and associated classes are resolved once,
    at construction time, from the handler's configuration.
    """

    def __init__(self, linker: Linker, handler: DiadefsHandler) -> None:
        """Keep the configuration and linker, then resolve the default options."""
        self.config = handler.config
        self.linker = linker
        self.module_names: bool = False
        self._set_default_options()
        self.classdiagram: ClassDiagram  # defined by subclasses

    def get_title(self, node: nodes.ClassDef) -> str:
        """Return the node name, prefixed by its root's name if module names are on."""
        name = node.name
        if not self.module_names:
            return name  # type: ignore[no-any-return]
        return f"{node.root().name}.{name}"

    def _set_option(self, option: bool | None) -> bool:
        """Resolve a tri-state option to a boolean.

        An explicit ``True``/``False`` is returned unchanged. ``None`` (not
        set) falls back to whether ``config.classes`` is non-empty, so that
        requesting specific classes turns on the more detailed defaults.
        """
        return bool(self.config.classes) if option is None else option

    def _set_default_options(self) -> None:
        """Compute ``module_names``, ``anc_level`` and ``association_level``.

        A level of ``-1`` is used when all ancestors / associated classes are
        requested and ``0`` otherwise; an explicit ``show_ancestors`` or
        ``show_associated`` value in the configuration overrides that default.
        """
        cfg = self.config
        self.module_names = self._set_option(cfg.module_names)
        ancestors_depth = -1 if self._set_option(cfg.all_ancestors) else 0
        associated_depth = -1 if self._set_option(cfg.all_associated) else 0

        # Explicit depths take precedence over the all/none switches.
        if cfg.show_ancestors is not None:
            ancestors_depth = cfg.show_ancestors
        if cfg.show_associated is not None:
            associated_depth = cfg.show_associated

        self.anc_level = ancestors_depth
        self.association_level = associated_depth

    def _get_levels(self) -> tuple[int, int]:
        """Return the ``(ancestor level, association level)`` pair."""
        return (self.anc_level, self.association_level)

    def show_node(self, node: nodes.ClassDef) -> bool:
        """Tell whether the configuration allows this node to be displayed.

        Nodes rooted in ``builtins`` follow ``config.show_builtin``, nodes from
        a standard library module follow ``config.show_stdlib``, and every
        other node is shown.
        """
        root_name = node.root().name
        if root_name == "builtins":
            return self.config.show_builtin  # type: ignore[no-any-return]
        if is_stdlib_module(root_name):
            return self.config.show_stdlib  # type: ignore[no-any-return]
        return True

    def add_class(self, node: nodes.ClassDef) -> None:
        """Run the linker on the class, then register it in the class diagram."""
        self.linker.visit(node)
        self.classdiagram.add_object(self.get_title(node), node)

    def get_ancestors(
        self, node: nodes.ClassDef, level: int
    ) -> Generator[nodes.ClassDef, None, None]:
        """Yield the displayable direct ancestors of a class node.

        Nothing is yielded when ``level`` is 0.
        """
        if level != 0:
            for parent in node.ancestors(recurs=False):
                if self.show_node(parent):
                    yield parent

    def get_associated(
        self, klass_node: nodes.ClassDef, level: int
    ) -> Generator[nodes.ClassDef, None, None]:
        """Yield the displayable classes associated with a class node.

        Candidates come from the values of ``instance_attrs_type`` followed by
        those of ``locals_type``. Nothing is yielded when ``level`` is 0.
        """
        if level == 0:
            return
        # Snapshot both mappings up front: consumers may run the linker on
        # other nodes while this generator is suspended.
        groups = [
            *klass_node.instance_attrs_type.values(),
            *klass_node.locals_type.values(),
        ]
        for group in groups:
            for candidate in group:
                # An instance is represented by the class it proxies.
                if isinstance(candidate, astroid.Instance):
                    candidate = candidate._proxied
                if isinstance(candidate, nodes.ClassDef) and self.show_node(candidate):
                    yield candidate

    def extract_classes(
        self, klass_node: nodes.ClassDef, anc_level: int, association_level: int
    ) -> None:
        """Add ``klass_node`` and, recursively, its related classes to the diagram.

        Each step towards an ancestor decrements ``anc_level`` and each step
        towards an associated class decrements ``association_level``. Classes
        already in the diagram, or hidden by the configuration, stop the
        recursion.
        """
        if self.classdiagram.has_node(klass_node):
            return
        if not self.show_node(klass_node):
            return
        self.add_class(klass_node)

        for parent in self.get_ancestors(klass_node, anc_level):
            self.extract_classes(parent, anc_level - 1, association_level)

        for related in self.get_associated(klass_node, association_level):
            self.extract_classes(related, anc_level, association_level - 1)
|
||
|
|
||
|
|
||
|
class DefaultDiadefGenerator(LocalsVisitor, DiaDefGenerator):
    """Generate the minimum diagram definitions for a project:

    * a package diagram including the project's modules (only when the
      project has more than one module)
    * a class diagram including the project's classes
    """

    def __init__(self, linker: Linker, handler: DiadefsHandler) -> None:
        """Initialize both base classes explicitly."""
        DiaDefGenerator.__init__(self, linker, handler)
        LocalsVisitor.__init__(self)

    def visit_project(self, node: Project) -> None:
        """Visit a Project node and create the empty diagrams.

        The package diagram is only created for projects with several
        modules; otherwise ``pkgdiagram`` is set to ``None``.
        """
        mode = self.config.mode
        several_modules = len(node.modules) > 1
        self.pkgdiagram: PackageDiagram | None = (
            PackageDiagram(f"packages {node.name}", mode) if several_modules else None
        )
        self.classdiagram = ClassDiagram(f"classes {node.name}", mode)

    def leave_project(self, _: Project) -> Any:
        """Leave the Project node and return the generated diagrams.

        The result is ``(pkgdiagram, classdiagram)`` when a package diagram
        exists and the 1-tuple ``(classdiagram,)`` otherwise.
        """
        if not self.pkgdiagram:
            return (self.classdiagram,)
        return self.pkgdiagram, self.classdiagram

    def visit_module(self, node: nodes.Module) -> None:
        """Visit a Module node and add it to the package diagram, if any."""
        if not self.pkgdiagram:
            return
        self.linker.visit(node)
        self.pkgdiagram.add_object(node.name, node)

    def visit_classdef(self, node: nodes.ClassDef) -> None:
        """Visit a ClassDef node and add it, with related classes, to the class diagram."""
        self.extract_classes(node, *self._get_levels())

    def visit_importfrom(self, node: nodes.ImportFrom) -> None:
        """Visit an ImportFrom node and record the dependency in the package diagram, if any."""
        if not self.pkgdiagram:
            return
        self.pkgdiagram.add_from_depend(node, node.modname)
|
||
|
|
||
|
|
||
|
class ClassDiadefGenerator(DiaDefGenerator):
    """Generate a class diagram definition including all classes related to a
    given class.
    """

    def class_diagram(self, project: Project, klass: str) -> ClassDiagram:
        """Return a class diagram definition for the class and related classes.

        :param project: The project whose modules are searched for the class.
        :param klass: Name of the root class; it is also used verbatim as the
            diagram title. When the project has several modules the name must
            be qualified (``module.Class``) so the module can be located. With
            a single module only the last dotted component is looked up.
        :returns: The class diagram populated from the root class.
        :raises ValueError: If the project has several modules and ``klass``
            is not a dotted name.
        """
        self.classdiagram = ClassDiagram(klass, self.config.mode)
        if len(project.modules) > 1:
            module_name, separator, class_name = klass.rpartition(".")
            if not separator:
                # Previously surfaced as an opaque tuple-unpacking ValueError.
                raise ValueError(
                    f"Class name {klass!r} must be qualified with its module "
                    "(module.Class) when the project has several modules"
                )
            module = project.get_module(module_name)
        else:
            module = project.modules[0]
            class_name = klass.rpartition(".")[2]
        # NOTE(review): only the first lookup result is used; a failed lookup
        # propagates whatever ``ilookup``/``next`` raises -- confirm desired.
        class_node = next(module.ilookup(class_name))

        anc_level, association_level = self._get_levels()
        self.extract_classes(class_node, anc_level, association_level)
        return self.classdiagram
|
||
|
|
||
|
|
||
|
# diagram handler #############################################################
|
||
|
|
||
|
|
||
|
class DiadefsHandler:
    """Build the diagram definitions requested by the configuration."""

    def __init__(self, config: argparse.Namespace) -> None:
        """Keep a reference to the parsed configuration."""
        self.config = config

    def get_diadefs(self, project: Project, linker: Linker) -> list[ClassDiagram]:
        """Get the diagram definitions for a project.

        One class diagram is generated per entry of ``config.classes``. When
        that yields nothing, the default generator is run over the whole
        project instead. Relationships are extracted on every diagram before
        returning.

        :param project: The pyreverse project
        :param linker: The linker
        :returns: The diagram definitions
        """
        class_generator = ClassDiadefGenerator(linker, self)
        diagrams = [
            class_generator.class_diagram(project, class_name)
            for class_name in self.config.classes
        ]
        if not diagrams:
            # No explicit class requested: fall back to the default diagrams.
            diagrams = DefaultDiadefGenerator(linker, self).visit(project)
        for diagram in diagrams:
            diagram.extract_relationships()
        return diagrams
|