# Source code for mermaidmro

# coding: utf-8

from __future__ import annotations

__all__ = [
    "get_mermaid_text",
    "get_style_text",
    "get_relations",
    "encode_text",
    "encode_json",
    "download_graph",
    "get_default_name_func",
]

import os
import zlib
import base64
import json
import importlib
import fnmatch
import collections
import urllib.request
from typing import Callable

try:
    import requests
    HAS_REQUESTS = True
except ModuleNotFoundError:
    HAS_REQUESTS = False

# package infos
from mermaidmro.__meta__ import (  # noqa
    __doc__,
    __author__,
    __email__,
    __copyright__,
    __credits__,
    __contact__,
    __license__,
    __status__,
    __version__,
)


# mermaid endpoints
# The "{}" placeholders are filled positionally via str.format: the first one takes the encoded
# graph and, for the image urls, the second one takes the requested file type.
# NOTE(review): URL_STATIC is not referenced in the visible code; presumably it is meant for the
# plain base64 text produced by encode_text -- confirm.
URL_STATIC = "https://mermaid.ink/img/{}?type={}"
# image url that is formatted with the output of encode_json and a file type
URL_STATIC_JSON = "https://mermaid.ink/img/pako:{}?type={}"
# live editor url that is formatted with the output of encode_json
URL_EDIT_JSON = "https://mermaid.live/edit#pako:{}"

#: Relation between two classes including a depth value and the mro index with respect to the
#: requested root class (namedtuple). Fields: *cls* is the derived class, *base_cls* one of its
#: direct bases, *root_cls* the class the lookup started from, *depth* the distance of
#: *base_cls* from *root_cls* (starting at 1 for direct bases of the root), and *mro* the index
#: of *base_cls* in ``root_cls.__mro__`` (-1 if it is not contained).
Relation = collections.namedtuple("Relation", ["cls", "base_cls", "root_cls", "depth", "mro"])

#: Container object with attributes to define css styles for one or multiple classes (namedtuple).
#: Fields: *name* is the name of the style, *cls* a single class (or class name) or a list,
#: tuple or set of them, and *css* either a single css string or a sequence of css strings.
Style = collections.namedtuple("Style", ["name", "cls", "css"])


def get_relations(
    root_cls: type,
    max_depth: int = -1,
) -> list[Relation]:
    """
    Extracts base classes of a *root_cls* down to a maximum depth *max_depth* and returns them in
    a list of :py:class:`Relation` objects. When *max_depth* is negative, the lookup is fully
    recursive, possibly down to ``object``. When *max_depth* is zero, an empty list is returned.

    Classes are visited breadth-first, so relations are ordered by increasing depth and, within
    the same class, by the order of its ``__bases__``. Each class is expanded only once, even if
    it is reachable through multiple inheritance paths.

    :param root_cls: The root class to use.
    :param max_depth: Maximum recursion depth.
    :return: The list of found :py:class:`Relation` objects.
    """
    # stop early
    if max_depth == 0:
        return []

    # map each class in the mro of the root class to its index
    mro = {cls: i for i, cls in enumerate(root_cls.__mro__)}

    # breadth-first traversal, using a deque so that taking the next item from the left is O(1)
    # (list.pop(0) would shift all remaining items on every iteration)
    lookup = collections.deque([(root_cls, 0)])
    seen = set()
    relations = []
    while lookup:
        cls, depth = lookup.popleft()

        # skip when cls already handled, e.g. when reached through a second inheritance path
        if cls in seen:
            continue

        # handle base classes
        for base_cls in cls.__bases__:
            # add class relation, starting at depth 1
            relations.append(Relation(cls, base_cls, root_cls, depth + 1, mro.get(base_cls, -1)))

            # amend lookup when depth below maximum
            if max_depth < 0 or depth + 1 < max_depth:
                lookup.append((base_cls, depth + 1))

        # mark as seen
        seen.add(cls)

    return relations
[docs]def get_default_name_func( skip_modules: list[str] | set[str] | None = None, ) -> Callable[[type], str]: """ Returns a function that takes an arbitrary class and extracts its name representation, usually in the format ``"module_name.class_name"``. *skip_modules* can be a sequence of modules names of patterns matching module names that are not preprended. Please note that ``builtins`` and ``__main__`` are always skipped. :param skip_modules: Optional squence of module names (or patterns) to skip. :return: Function that takes a class and returns the name for visualization. """ # default list of module names to skip skip_modules = set(skip_modules or []) # always skip certain modules skip_modules |= {"builtins", "__main__"} def name_func(cls: type) -> str: if isinstance(cls, str): return cls if cls.__module__ and not any(fnmatch.fnmatch(cls.__module__, m) for m in skip_modules): return f"{cls.__module__}.{cls.__qualname__}" return cls.__qualname__ return name_func
def get_style_text(
    styles: list[Style | tuple],
    indentation: str = " ",
    name_func: Callable[[type], str] | None = None,
    skip_modules: list[str] | set[str] | None = None,
    join_lines: bool = True,
) -> str | list[str]:
    """
    Builds the style statements of a mermaid graph: first one ``classDef`` line per style, then
    an empty line, then one ``class`` line per class that a style is assigned to. *styles* should
    be a sequence of :py:class:`Style` objects (or tuples that can be interpreted as such)
    containing the name of the style, the name(s) of the class(es) it is applied to, and one or
    multiple css-like strings, e.g. ``"stroke-width: 3px"``. Multiple css strings are joined with
    ``", "``. Example (leading *indentation* of the lines omitted):

    .. code-block:: python

        get_style_text([
            Style(name="Bold", cls=["ClassA", "ClassB"], css=["stroke-width: 3px"]),
            Style(name="Colored", cls="ClassA", css=["stroke-width: 3px", "stroke: #83b"]),
        ])
        # classDef Bold stroke-width: 3px
        # classDef Colored stroke-width: 3px, stroke: #83b
        #
        # class ClassA Bold
        # class ClassB Bold
        # class ClassA Colored

    :param styles: Sequence of :py:class:`Style` objects or tuples that can be interpreted as
        such.
    :param indentation: The indentation of lines.
    :param name_func: A function to extract the string representation of a class, defaulting to
        the return value of :py:func:`get_default_name_func` passing *skip_modules*.
    :param skip_modules: Sequence of module names (or patterns) to skip when no *name_func* is
        set.
    :param join_lines: Whether generated lines should be joined to a string.
    :return: The style as a text representation or as single lines in a list.
    """
    # fall back to the default name function
    if name_func is None:
        name_func = get_default_name_func(skip_modules=skip_modules)

    # convert plain tuples into style objects
    normalized = []
    for entry in styles:
        if not isinstance(entry, Style):
            entry = Style(*entry)
        normalized.append(entry)

    # one definition line per style
    definitions = []
    for style in normalized:
        css = style.css
        if not isinstance(css, str):
            css = ", ".join(css)
        definitions.append(f"{indentation}classDef {style.name} {css}")

    # one assignment line per class and style
    assignments = [
        f"{indentation}class {name_func(cls)} {style.name}"
        for style in normalized
        for cls in (style.cls if isinstance(style.cls, (list, tuple, set)) else [style.cls])
    ]

    # definitions and assignments are separated by an empty line
    lines = [*definitions, "", *assignments]

    if join_lines:
        return "\n".join(lines)
    return lines
def get_mermaid_text(
    root_cls: type,
    max_depth: int = -1,
    styles: list[Style | tuple] | None = None,
    show_mro: bool = True,
    graph_type: str = "TD",
    arrow_type: str = "-->",
    indentation: str = " ",
    skip_func: Callable[[type, Callable], bool] | None = None,
    name_func: Callable[[type], str] | None = None,
    skip_modules: list[str] | set[str] | None = None,
    join_lines: bool = True,
) -> str | list[str]:
    """
    Builds the mermaid text of the inheritance graph of a *root_cls*, down to a maximum recursion
    depth *max_depth*. The text starts with a ``graph`` line, optionally followed by node labels
    containing mro indices, the relations between classes, and optional style statements.

    When *styles* is given, the representation contains style statements generated via
    :py:func:`get_style_text`. When *show_mro* is *True*, mro indices with respect to *root_cls*
    are shown. The type of the graph and style of arrows can be controlled with *graph_type* and
    *arrow_type*. Example (leading *indentation* of the lines omitted):

    .. code-block:: python

        class A(object):
            pass

        class B(object):
            pass

        class C(A):
            pass

        class D(C, B):
            pass

        get_mermaid_text(D)
        # graph TD
        # D("D (0)")
        # C("C (1)")
        # A("A (2)")
        # B("B (3)")
        # object("object (4)")
        #
        # C --> D
        # B --> D
        # A --> C
        # object --> A
        # object --> B

        get_mermaid_text(D, show_mro=False)
        # graph TD
        # C --> D
        # B --> D
        # A --> C
        # object --> A
        # object --> B

    :param root_cls: The root class to use.
    :param max_depth: Maximum recursion depth for the lookup in :py:func:`get_relations`.
    :param styles: Sequence of :py:class:`Style` objects or tuples that can be interpreted as
        such.
    :param show_mro: Whether mro indices should be included.
    :param graph_type: The mermaid graph type to use, e.g. ``"TD"`` or ``"LR"``.
    :param arrow_type: The default arrow type to use between classes, e.g. ``"-->"``.
    :param indentation: The indentation of lines.
    :param skip_func: A function to decide whether a specific base class should be skipped given
        the class itself and the *name_func* as arguments.
    :param name_func: A function to extract the string representation of a class, defaulting to
        the return value of :py:func:`get_default_name_func` passing *skip_modules*.
    :param skip_modules: Sequence of module names (or patterns) to skip when no *name_func* is
        set.
    :param join_lines: Whether generated lines should be joined to a string.
    :return: The graph as a text representation or as single lines in a list.
    """
    # fall back to the default name function
    if name_func is None:
        name_func = get_default_name_func(skip_modules=skip_modules)

    def node_label(cls: type, index: int) -> str:
        # node definition of the form: name("name (index)")
        return f"{indentation}{name_func(cls)}(\"{name_func(cls)} ({index})\")"

    def by_mro_index(pair: tuple) -> int:
        return pair[1]

    # look up all relations between classes
    relations = get_relations(root_cls, max_depth=max_depth)

    # the graph header comes first
    lines = [f"graph {graph_type}"]

    # node labels with mro indices, the root class always has index 0
    if show_mro:
        pairs = {(rel.base_cls, rel.mro) for rel in relations}
        pairs.add((root_cls, 0))
        for cls, index in sorted(pairs, key=by_mro_index):
            lines.append(node_label(cls, index))
        lines.append("")

    # one arrow per relation, pointing from the base class to the derived class
    can_skip = callable(skip_func)
    for rel in relations:
        if can_skip and skip_func(rel.base_cls, name_func):
            continue
        lines.append(f"{indentation}{name_func(rel.base_cls)} {arrow_type} {name_func(rel.cls)}")

    # style statements, separated by an empty line
    if styles:
        style_lines = get_style_text(
            styles,
            indentation=indentation,
            name_func=name_func,
            join_lines=False,
        )
        lines.append("")
        lines.extend(style_lines)

    if join_lines:
        return "\n".join(lines)
    return lines
def encode_text(
    mermaid_text: str,
) -> str:
    """
    Encodes a mermaid graph given in *mermaid_text* with the url-safe base64 alphabet, as it is
    usually used in static urls of the mermaidjs service.

    :param mermaid_text: The graph as a string representation.
    :return: The base64 encoded representation of the text.
    """
    # text -> utf-8 bytes -> url-safe base64 bytes -> text
    raw = mermaid_text.encode("utf-8")
    encoded = base64.urlsafe_b64encode(raw)
    return str(encoded, "utf-8")
[docs]def encode_json( mermaid_text: str, theme: str | None = "default", ) -> str: r""" Returns a base64 encoded and compressed variant of a mermaid graph given in *mermaid_text* and additional configuration options such as *theme*, that is usually used in urls of the mermaidjs live editing service. The structured data that is compressed has the format .. code-block:: json { "code": "graph TD\n....", "mermaid": "{\"theme\": ...}" } as expected by mermaidjs. :param mermaid_text: The graph as a string representation. :param theme: Name of the theme to use. :return: The base64 encoded and compressed representation of the structured data containing the graph and configuration options. """ data = json.dumps({ "code": mermaid_text, "mermaid": json.dumps({"theme": theme} if theme else {}), }) return base64.urlsafe_b64encode(zlib.compress(data.encode("utf-8"), level=9)).decode("utf-8")
def download_graph(
    mermaid_text: str,
    path: str,
    file_type: str = "jpg",
    theme: str | None = "default",
) -> str:
    """
    Downloads a mermaid graph represented by *mermaid_text* from the mermaidjs service to a *path*
    in a specific *file_type*. Missing intermediate directories are created first.

    The download uses :py:mod:`requests` when available and :py:mod:`urllib.request` otherwise.
    In both cases an unsuccessful http status raises an exception and the file at *path* is only
    written after the content was received completely.

    :param mermaid_text: The graph as a string representation.
    :param path: The path where the downloaded file should be saved. User and environment
        variables are expanded.
    :param file_type: The file type to write, usually ``"jpg"`` or ``"png"``.
    :param theme: Name of the theme to use.
    :raises requests.HTTPError: When requests is used and the service responds with an error
        status.
    :raises urllib.error.HTTPError: When urllib is used and the service responds with an error
        status.
    :return: The normalized and expanded path. Note that a relative *path* stays relative.
    """
    # normalize path
    path = os.path.normpath(os.path.expandvars(os.path.expanduser(path)))

    # ensure parent directory exists, exist_ok avoids a race between checking and creating
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # download
    url = URL_STATIC_JSON.format(encode_json(mermaid_text, theme=theme), file_type)
    if HAS_REQUESTS:
        r = requests.get(url, allow_redirects=True)
        # do not silently save an error page as the graph file
        r.raise_for_status()
        content = r.content
    else:
        # use a local opener instead of installing one globally, which would change the
        # behavior of urllib for the entire process
        opener = urllib.request.build_opener()
        opener.addheaders = [("User-Agent", f"mermaidmro/{__version__}")]
        with opener.open(url) as response:
            content = response.read()

    # write only after a successful download so that failures do not leave an empty file
    with open(path, "wb") as f:
        f.write(content)

    return path
def _import_class( cid: str, ) -> type: # parse the class identifier if ":" not in cid: raise ValueError(f"invalid format, cannot import '{cid}'") module_name, cls_name = cid.split(":", 1) # import the module mod = importlib.import_module(module_name) # get the cls cls = getattr(mod, cls_name, None) if not cls: raise AttributeError(f"not class named '{cls_name}' in module {mod}") return cls def main( cli_args: list[str] | None = None, test: bool = False, ) -> None | list[str] | str: """ Main entry hook of the mermaidmro cli. The arguments to parse can be configured via *cli_args* and default to ``sys.argv[1:]``. When *test* is *True*, no command is executed by created texts and / or commands are returned instead for testing purposes. :param cli_args: Custom cli arguments. :param test: Whether texts and or commands are returned for testing purposes. :return: Texts or commands if *test* is *True* and *None* otherwise. """ import subprocess import tempfile import shlex import argparse # setup arguments parser = argparse.ArgumentParser( description="visualize class inheritance structures with mermaidjs using the mro", prog="mermaidmro", ) parser.add_argument( "cls", help="the root class to visualize in the format 'module.to.import:class'", ) parser.add_argument( "--max-depth", "-m", metavar="VALUE", help="the maximum depth of the graph; default: -1", type=int, default=-1, ) parser.add_argument( "--no-mro", "-n", action="store_true", help="do not show mro indices", ) parser.add_argument( "--graph-type", "-g", help="the graph type; default: 'TD'", default="TD", ) parser.add_argument( "--arrow-type", "-a", help="the arrow type; default: '-->'", default="-->", ) parser.add_argument( "--cmd", "-c", metavar="CMD", help="an executable to open the generated url", ) parser.add_argument( "--edit", "-e", action="store_true", help="whether to open the mermaid live editor instead of a static image when --cmd is set", ) parser.add_argument( "--download", "-d", metavar="PATH", 
help="path for downloading the graph file instead", ) parser.add_argument( "--visualize", "-v", metavar="CMD", help="executable for visualizing the graph from the (temporarily) downloaded file", ) parser.add_argument( "--file-type", "-f", metavar="TYPE", help="the file type to open or download; has no effect when --edit is set", choices=["png", "jpg"], default="png", ) parser.add_argument( "--args", help="additional arguments to be added to the commands given via --cmd or --visualize", type=shlex.split, ) args = parser.parse_args(cli_args) # overwrite the file type when downloading if args.download: args.file_type = os.path.splitext(args.download)[-1].strip(".") or args.file_type # import the class cls = _import_class(args.cls) # generate the mermaid text mermaid_text = get_mermaid_text( cls, max_depth=args.max_depth, show_mro=not args.no_mro, graph_type=args.graph_type.strip(), arrow_type=args.arrow_type.strip(), ) # trigger actions show_text = True # download and / or visualize if args.download or args.visualize: with tempfile.NamedTemporaryFile(suffix=f".{args.file_type}") as f: vis_path = download_graph( mermaid_text, args.download or f.name, file_type=args.file_type, ) if args.visualize: cmd = [args.visualize, vis_path] + (args.args or []) if test: return cmd subprocess.run(cmd) show_text = False if args.cmd: # open an url mermaid_json = encode_json(mermaid_text) if args.edit: url = URL_EDIT_JSON.format(mermaid_json) else: url = URL_STATIC_JSON.format(mermaid_json, args.file_type) # run the command cmd = [args.cmd, url] + (args.args or []) if test: return cmd subprocess.run(cmd) show_text = False if show_text: # just print the text if test: return mermaid_text print(mermaid_text) # entry hook if __name__ == "__main__": main()