"""Profiling helpers of onnx_array_api.profiling, based on cProfile and pstats."""

import cProfile
import json
import math
import os
import site
from collections import OrderedDict, deque
from io import StringIO
from pstats import SortKey, Stats
from typing import Any, Callable, Dict, List, Optional, Tuple


class ProfileNode:
    """
    Graph structure to represent a profiling.

    Each node holds the statistics of one function, the nodes calling it
    (*called_by*) and the nodes it calls (*calls_to*); the statistics of
    every outgoing edge are stored in *calls_to_elements*.

    :param filename: filename
    :param line: line number
    :param func_name: function name
    :param nc1: number of calls 1 (first count of the pstats entry)
    :param nc2: number of calls 2 (second count of the pstats entry)
    :param tin: time spent in the function
    :param tall: time spent in the function and in the sub functions
    :raises RuntimeError: if *func_name* is the profiler's own
        ``disable`` method
    """

    def __init__(
        self,
        filename: str,
        line: int,
        func_name: str,
        nc1: int,
        nc2: int,
        tin: float,
        tall: float,
    ):
        if "method 'disable' of '_lsprof.Profiler'" in func_name:
            raise RuntimeError(f"Function not allowed in the profiling: {func_name!r}.")
        self.filename = filename
        self.line = line
        self.func_name = func_name
        self.nc1 = nc1
        self.nc2 = nc2
        self.tin = tin
        self.tall = tall
        self.called_by = []
        self.calls_to = []
        self.calls_to_elements = []

    def add_called_by(self, pnode: "ProfileNode"):
        "Registers *pnode* as a caller of this node."
        self.called_by.append(pnode)

    def add_calls_to(self, pnode: "ProfileNode", time_elements):
        """
        Registers *pnode* as a function called by this node.

        :param pnode: called node
        :param time_elements: statistics of this edge, read by the
            renderers as ``(nc1, nc2, tin, tall)``
        """
        self.calls_to.append(pnode)
        self.calls_to_elements.append(time_elements)

    @staticmethod
    def _key(filename: str, line: int, fct: str) -> str:
        "Builds the unique key ``filename:line:function``."
        key = "%s:%d:%s" % (filename, line, fct)
        return key

    @property
    def key(self):
        "Returns `file:line:function`."
        return ProfileNode._key(self.filename, self.line, self.func_name)

    def get_root(self):
        """
        Returns the root of the graph.

        Walks up the *called_by* edges. When a node has several callers,
        a caller not explored yet is followed. If every path was explored
        without reaching a node with no caller, the visited node with the
        highest *tall* is returned.
        """
        done = set()

        def _get_root(node, stor=None):
            if stor is not None:
                stor.append(node)
            if len(node.called_by) == 0:
                return node
            if len(node.called_by) == 1:
                return _get_root(node.called_by[0], stor=stor)
            res = None
            for ct in node.called_by:
                k = id(node), id(ct)
                if k in done:
                    continue
                res = ct
                break
            if res is None:
                # All paths have been explored and no entry point was found.
                return None
            done.add((id(node), id(res)))
            return _get_root(res, stor=stor)

        root = _get_root(self)
        if root is None:
            # Choosing the most consuming function among the visited ones.
            candidates = []
            _get_root(self, stor=candidates)
            # A key function avoids comparing ProfileNode instances
            # (they define no ordering) when two *tall* values are equal.
            root = max(candidates, key=lambda n: n.tall)
        return root

    def __repr__(self) -> str:
        "usual"
        return "%s(%r, %r, %r, %r, %r, %r, %r)  # %d-%d" % (
            self.__class__.__name__,
            self.filename,
            self.line,
            self.func_name,
            self.nc1,
            self.nc2,
            self.tin,
            self.tall,
            len(self.called_by),
            len(self.calls_to),
        )

    def __iter__(self):
        """
        Iterates over all nodes reachable from this one through *calls_to*,
        breadth-first, every key being returned once.
        """
        done = set()
        stack = deque()
        stack.append(self)
        while len(stack) > 0:
            node = stack.popleft()
            if node.key in done:
                continue
            yield node
            done.add(node.key)
            stack.extend(node.calls_to)

    # Files whose cheap nodes are hidden by filter_node_.
    # NOTE(review): the leading space in " _globals.py" looks like a typo,
    # confirm before relying on it.
    _modules_ = {
        "~",
        "subprocess.py",
        "posixpath.py",
        "os.py",
        "<frozen importlib._bootstrap>",
        "inspect.py",
        "version.py",
        "typing.py",
        "warnings.py",
        "errors.py",
        "numbers.py",
        "ast.py",
        "threading.py",
        "_collections_abc.py",
        "datetime.py",
        "abc.py",
        "argparse.py",
        "__future__.py",
        "functools.py",
        "six.py",
        "sre_parse.py",
        "contextlib.py",
        " _globals.py",
        "_ios.py",
        "types.py",
    }

    @staticmethod
    def filter_node_(node, info=None) -> bool:
        """
        Filters out node to be displayed by default.

        A node is removed when its filename belongs to ``_modules_`` and
        it has at most 10 calls (both counts) and *tall* <= 1e-4.

        :param node: node
        :param info: if the node is called by a function,
            this dictionary can be used to overwrite the attributes
            held by the node
        :return: boolean (True to keep, False to forget)
        """
        if node.filename in ProfileNode._modules_:
            if info is None:
                if node.nc1 <= 10 and node.nc2 <= 10 and node.tall <= 1e-4:
                    return False
            else:
                if info["nc1"] <= 10 and info["nc2"] <= 10 and info["tall"] <= 1e-4:
                    return False

        return True

    @staticmethod
    def _get_sort_function(sort_key) -> Callable:
        """
        Returns the key function sorting nodes or ``(node, time_elements)``
        pairs according to *sort_key*.

        :raises NotImplementedError: if *sort_key* is not one of
            ``SortKey.LINE``, ``SortKey.CUMULATIVE``, ``SortKey.TIME``
        """

        def sort_key_line(dr):
            if isinstance(dr, tuple):
                return (dr[0].filename, dr[0].line)
            return (dr.filename, dr.line)

        def sort_key_tin(dr):
            if isinstance(dr, tuple):
                return -dr[1][2]
            return -dr.tin

        def sort_key_tall(dr):
            if isinstance(dr, tuple):
                return -dr[1][3]
            return -dr.tall

        if sort_key == SortKey.LINE:
            return sort_key_line
        if sort_key == SortKey.CUMULATIVE:
            return sort_key_tall
        if sort_key == SortKey.TIME:
            return sort_key_tin
        raise NotImplementedError(
            f"Unable to sort subcalls with this key {sort_key!r}."
        )

    @staticmethod
    def _n_digits(value) -> int:
        """
        Returns the field width used to print a number of calls,
        ``int(log10(value) + 1.5)``. Values below 1 are treated as 1
        so that a zero count does not fail in :func:`math.log`.
        """
        return int(math.log(max(value, 1)) / math.log(10) + 1.5)

    def as_dict(self, filter_node=None, sort_key=SortKey.LINE):
        """
        Renders the results of a profiling interpreted with
        function @fn profile2graph. It can then be loaded with
        a dataframe.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by...
        :return: rows
        :raises NotImplementedError: if *sort_key* is not supported
        """
        sortk = ProfileNode._get_sort_function(sort_key)
        if filter_node is None:
            filter_node = ProfileNode.filter_node_

        def depth_first(node, roots_keys, indent=0):
            text = {
                "fct": node.func_name,
                "where": node.key,
                "nc1": node.nc1,
                "nc2": node.nc2,
                "tin": node.tin,
                "tall": node.tall,
                "indent": indent,
                "ncalls": len(node.calls_to),
                "debug": "A",
            }
            yield text
            for n, nel in sorted(zip(node.calls_to, node.calls_to_elements), key=sortk):
                if n.key in roots_keys:
                    # A root is rendered on its own, only the edge is shown here.
                    text = {
                        "fct": n.func_name,
                        "where": n.key,
                        "nc1": nel[0],
                        "nc2": nel[1],
                        "tin": nel[2],
                        "tall": nel[3],
                        "indent": indent + 1,
                        "ncalls": len(n.calls_to),
                        "more": "+",
                        "debug": "B",
                    }
                    if not filter_node(n, info=text):
                        continue
                    yield text
                else:
                    if not filter_node(n):
                        continue
                    yield from depth_first(n, roots_keys, indent + 1)

        nodes = list(self)
        # Nodes with exactly one caller are rendered inside their caller.
        roots = [node for node in nodes if len(node.called_by) != 1]
        roots_key = {r.key: r for r in roots}
        rows = []
        for root in sorted(roots, key=sortk):
            if not filter_node(root):
                continue
            rows.extend(depth_first(root, roots_key))
        return rows

    def to_text(self, filter_node=None, sort_key=SortKey.LINE, fct_width=60) -> str:
        """
        Prints the profiling to text.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by...
        :param fct_width: width of the function name column (indentation
            included); longer names are shortened with ``...``, no
            alignment if the remaining width is <= 0
        :return: one line per row returned by :meth:`as_dict`,
            an empty string if there is no row
        """

        def align_text(text, size):
            if size <= 0:
                return text
            if len(text) <= size:
                return text + " " * (size - len(text))
            head = size // 2 - 1
            tail = head - 1
            if tail <= 0:
                # Too narrow for "head...tail": plain truncation.
                return text[:size]
            return text[:head] + "..." + text[-tail:]

        dicts = self.as_dict(filter_node=filter_node, sort_key=sort_key)
        if not dicts:
            return ""
        max_nc = max(max(row["nc1"], row["nc2"]) for row in dicts)
        dg = ProfileNode._n_digits(max_nc)
        line_format = (
            "{indent}{fct} -- {nc1: %dd} {nc2: %dd} -- {tin:1.5f} {tall:1.5f}"
            " -- {name} ({fct2})" % (dg, dg)
        )
        text = []
        for row in dicts:
            line = line_format.format(
                indent=" " * (row["indent"] * 4),
                fct=align_text(row["fct"], fct_width - row["indent"] * 4),
                nc1=row["nc1"],
                nc2=row["nc2"],
                tin=row["tin"],
                tall=row["tall"],
                name=row["where"],
                fct2=row["fct"],
            )
            if row.get("more", "") == "+":
                line += " +++"
            text.append(line)
        return "\n".join(text)

    def to_json(
        self, filter_node=None, sort_key=SortKey.LINE, as_str=True, **kwargs
    ) -> str:
        """
        Renders the results of a profiling interpreted with
        function @fn profile2graph as :epkg:`JSON`.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by...
        :param as_str: converts the json into a string
        :param kwargs: see :func:`json.dumps`
        :return: a JSON string if *as_str* is True, otherwise
            the dictionary ``{"profile": ...}``
        :raises NotImplementedError: if *sort_key* is not supported
        """
        sortk = ProfileNode._get_sort_function(sort_key)
        if filter_node is None:
            filter_node = ProfileNode.filter_node_

        def format_keys(items):
            # Keys are (number of calls, label); the count is left-aligned
            # on a width derived from the largest count.
            mx = max((k[0] for k in items), default=1)
            form = f"%-{ProfileNode._n_digits(mx)}d-%s"
            return OrderedDict((form % k, v) for k, v in items.items())

        def walk(node, roots_keys, indent=0):
            item = {
                "details": {
                    "fct": node.func_name,
                    "where": node.key,
                    "nc1": node.nc1,
                    "nc2": node.nc2,
                    "tin": node.tin,
                    "tall": node.tall,
                    "indent": indent,
                    "ncalls": len(node.calls_to),
                }
            }

            child = OrderedDict()
            for n, nel in sorted(zip(node.calls_to, node.calls_to_elements), key=sortk):
                key = (nel[0], f"{nel[3]:1.5f}:{n.func_name}")
                if n.key in roots_keys:
                    # A root is rendered on its own, only the edge is shown here;
                    # indent and ncalls describe the child, as in as_dict.
                    details = {
                        "fct": n.func_name,
                        "where": n.key,
                        "nc1": nel[0],
                        "nc2": nel[1],
                        "tin": nel[2],
                        "tall": nel[3],
                        "indent": indent + 1,
                        "ncalls": len(n.calls_to),
                    }
                    if not filter_node(n, info=details):
                        continue
                    child[key] = {"details": details}
                else:
                    if not filter_node(n):
                        continue
                    child[key] = walk(n, roots_keys, indent + 1)

            if len(child) > 0:
                item["calls"] = format_keys(child)
            return item

        nodes = list(self)
        roots = [node for node in nodes if len(node.called_by) != 1]
        roots_key = {r.key: r for r in roots}
        rows = OrderedDict()
        for root in sorted(roots, key=sortk):
            if not filter_node(root):
                continue
            key = (root.nc1, f"{root.tall:1.5f}:::{root.func_name}")
            rows[key] = walk(root, roots_key)
        rows = format_keys(rows)
        if as_str:
            return json.dumps({"profile": rows}, **kwargs)
        return {"profile": rows}


def _process_pstats(
    ps: Stats,
    clean_text: Optional[Callable] = None,
    verbose: bool = False,
    fLOG: Optional[Callable] = None,
) -> List[Dict[str, Any]]:
    """
    Converts class `Stats <https://docs.python.org/3/library/
    profile.html#pstats.Stats>`_ into something
    readable for a dataframe.

    :param ps: instance of type :func:`pstats.Stats`
    :param clean_text: function to clean function names
    :param verbose: change verbosity
    :param fLOG: logging function
    :return: list of rows
    """
    if clean_text is None:
        clean_text = lambda x: x

    def add_rows(rows, d):
        tt1, tt2 = 0, 0
        for k, v in d.items():
            stin = 0
            stall = 0
            if verbose and fLOG is not None:
                fLOG(
                    "[pstats] %s=%r"
                    % ((clean_text(k[0].replace("\\", "/")),) + k[1:], v)
                )
            if len(v) < 5:
                continue
            row = {
                "file": "%s:%d" % (clean_text(k[0].replace("\\", "/")), k[1]),
                "fct": k[2],
                "ncalls1": v[0],
                "ncalls2": v[1],
                "tin": v[2],
                "tall": v[3],
            }
            stin += v[2]
            stall += v[3]
            if len(v) == 5:
                t1, t2 = add_rows(rows, v[-1])
                stin += t1
                stall += t2
            row["cum_tin"] = stin
            row["cum_tall"] = stall
            rows.append(row)
            tt1 += stin
            tt2 += stall
        return tt1, tt2

    rows = []
    add_rows(rows, ps.stats)
    return rows


def profile2df(
    ps: Stats,
    as_df: bool = True,
    clean_text: Optional[Callable] = None,
    verbose: bool = False,
    fLOG=None,
):
    """
    Converts profiling statistics into a Dataframe.

    :param ps: an instance of `pstats
        <https://docs.python.org/3/library/profile.html#pstats.Stats>`_
    :param as_df: returns the results as a dataframe (True)
        or a list of dictionaries (False)
    :param clean_text: function to clean file names
    :param verbose: verbosity
    :param fLOG: logging function
    :return: a DataFrame grouped by function and file and sorted by
        decreasing ``cum_tall``, or the list of rows if *as_df* is False

    ::

        import pstats
        from onnx_array_api.profiling import profile2df

        ps = pstats.Stats('bench_ortmodule_nn_gpu6.prof')
        df = profile2df(ps)
        print(df)
    """
    rows = _process_pstats(ps, clean_text, verbose=verbose, fLOG=fLOG)
    if not as_df:
        return rows

    import pandas

    columns = ["fct", "file", "ncalls1", "ncalls2", "tin", "cum_tin", "tall", "cum_tall"]
    # Passing the columns keeps them present (and ordered) even when
    # there is no row, where selecting them afterwards would fail.
    df = pandas.DataFrame(rows, columns=columns)
    df = (
        df.groupby(["fct", "file"], as_index=False)
        .sum()
        .sort_values("cum_tall", ascending=False)
        .reset_index(drop=True)
    )
    return df.copy()
def profile(
    fct: Callable,
    sort: str = "cumulative",
    rootrem: Optional[str] = None,
    as_df: bool = False,
    return_results=False,
    **kwargs,
) -> Tuple[Any, ...]:
    """
    Profiles the execution of a function.

    :param fct: function to profile
    :param sort: see `sort_stats <https://docs.python.org/3/library/
        profile.html#pstats.Stats.sort_stats>`_
    :param rootrem: root to remove in filenames, either a string or a
        list of strings (removed) and pairs ``(old, new)`` (replaced)
    :param as_df: return the results as a dataframe and not text
    :param return_results: if True, return results as well
        (in the first position)
    :param kwargs: additional parameters used to create the profiler
    :return: ``(stats, text)`` or ``(stats, dataframe)`` if *as_df* is True,
        prefixed with the result of *fct* if *return_results* is True
    :raises TypeError: if *rootrem* contains something other than
        strings or pairs

    .. plot::

        import matplotlib.pyplot as plt
        from onnx_array_api.profiling import profile

        def fctm():
            return sum(sorted(range(100), key=lambda x: -x))

        pr, df = profile(lambda: [fctm() for i in range(0, 1000)], as_df=True)
        ax = df[['namefct', 'cum_tall']].head(n=15).set_index(
            'namefct').plot(kind='bar', figsize=(8, 3), rot=30)
        ax.set_title("example of a graph")
        for la in ax.get_xticklabels():
            la.set_horizontalalignment('right');
        plt.show()
    """
    pr = cProfile.Profile(**kwargs)
    pr.enable()
    try:
        fct_res = fct()
    finally:
        # Stop the profiler even if fct raises; a profiler left enabled
        # may prevent later profilers from starting.
        pr.disable()
    s = StringIO()
    ps = Stats(pr, stream=s).sort_stats(sort)
    ps.print_stats()
    res = s.getvalue()

    try:
        pack = site.getsitepackages()
    except AttributeError:
        pack = []
    if not pack:
        # Fallback when site-packages cannot be listed: the folder
        # containing numpy.
        import numpy

        pack = [
            os.path.normpath(
                os.path.abspath(os.path.join(os.path.dirname(numpy.__file__), ".."))
            )
        ]
    pack_ = os.path.normpath(os.path.join(pack[-1], ".."))

    def clean_text(res):
        res = res.replace(pack[-1], "site-packages")
        res = res.replace(pack_, "lib")
        if rootrem is not None:
            if isinstance(rootrem, str):
                res = res.replace(rootrem, "")
            else:
                for sub in rootrem:
                    if isinstance(sub, str):
                        res = res.replace(sub, "")
                    elif isinstance(sub, tuple) and len(sub) == 2:
                        res = res.replace(sub[0], sub[1])
                    else:
                        raise TypeError(
                            f"rootrem must contain strings or tuple not {rootrem}."
                        )
        return res

    if as_df:

        def better_name(row):
            if len(row["fct"]) > 15:
                return f"{row['file'].split(':')[-1]}-{row['fct']}"
            name = row["file"].replace("\\", "/")
            return f"{name.split('/')[-1]}-{row['fct']}"

        rows = _process_pstats(ps, clean_text)
        import pandas

        df = pandas.DataFrame(rows)
        df = df[
            [
                "fct",
                "file",
                "ncalls1",
                "ncalls2",
                "tin",
                "cum_tin",
                "tall",
                "cum_tall",
            ]
        ]
        df["namefct"] = df.apply(better_name, axis=1)
        df = (
            df.groupby(["namefct", "file"], as_index=False)
            .sum()
            .sort_values("cum_tall", ascending=False)
            .reset_index(drop=True)
        )
        if return_results:
            return fct_res, ps, df
        return ps, df

    res = clean_text(res)
    if return_results:
        return fct_res, ps, res
    return ps, res
def profile2graph(
    ps: Stats,
    clean_text: Optional[Callable] = None,
    verbose: bool = False,
    fLOG: Optional[Callable] = None,
) -> Tuple["ProfileNode", Dict[str, "ProfileNode"]]:
    """
    Converts profiling statistics into a graphs.

    :param ps: an instance of `pstats
        <https://docs.python.org/3/library/profile.html#pstats.Stats>`_
    :param clean_text: function to clean file names
    :param verbose: verbosity
    :param fLOG: logging function
    :return: a tuple ``(root, nodes)``, *root* being an instance of class
        @see cl ProfileNode (None if there is no statistics) and *nodes*
        a dictionary of all nodes indexed by their key
    :raises RuntimeError: if two entries share the same key after
        cleaning, or if a caller cannot be found among the nodes

    :epkg:`pyinstrument` has a nice display to show time spent and call
    stack at the same time. This function tries to replicate that display
    based on the results produced by module :mod:`cProfile`. Here is an
    example.

    .. runpython::
        :showcode:

        import time
        from onnx_array_api.profiling import profile, profile2graph


        def fct0(t):
            time.sleep(t)


        def fct1(t):
            time.sleep(t)


        def fct2():
            fct1(0.1)
            fct1(0.01)


        def fct3():
            fct0(0.2)
            fct1(0.5)


        def fct4():
            fct2()
            fct3()


        ps = profile(fct4)[0]
        root, nodes = profile2graph(ps, clean_text=lambda x: x.split('/')[-1])
        text = root.to_text()
        print(text)
    """
    cleaner = clean_text if clean_text is not None else (lambda x: x)
    # The profiler's own disable method shows up in the statistics.
    disable_name = "method 'disable' of '_lsprof.Profiler'"

    def _clean_filename(filename):
        return cleaner(filename.replace("\\", "/"))

    def _is_kept(k, v):
        # Entries skipped here get no node, both loops must agree.
        return len(v) >= 5 and disable_name not in k[2]

    nodes = {}
    for k, v in ps.stats.items():
        if verbose and fLOG is not None:
            fLOG(f"[pstats] {k}={v!r}")
        if not _is_kept(k, v):
            continue
        node = ProfileNode(
            filename=_clean_filename(k[0]),
            line=k[1],
            func_name=k[2],
            nc1=v[0],
            nc2=v[1],
            tin=v[2],
            tall=v[3],
        )
        if node.key in nodes:
            raise RuntimeError(f"Key {node.key!r} is already present, node={node!r}.")
        nodes[node.key] = node

    for k, v in ps.stats.items():
        if not _is_kept(k, v):
            continue
        node = nodes[ProfileNode._key(_clean_filename(k[0]), k[1], k[2])]
        # v[4] maps every caller of k to the statistics of that edge.
        for f, vv in v[4].items():
            if disable_name in f[2]:
                continue
            key = ProfileNode._key(_clean_filename(f[0]), f[1], f[2])
            if key not in nodes:
                raise RuntimeError(
                    "Unable to find key %r into\n%s" % (key, "\n".join(sorted(nodes)))
                )
            caller = nodes[key]
            node.add_called_by(caller)
            caller.add_calls_to(node, vv)

    root = next(iter(nodes.values())).get_root() if nodes else None
    return root, nodes