Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Documentation helper.
4"""
5import keyword
6import textwrap
7import re
8from jinja2 import Template
9from jinja2.runtime import Undefined
10from onnx.defs import OpSchema
11from ...tools import change_style
def type_mapping(name):
    """
    Mapping between types name and type integer value.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import type_mapping
        import pprint
        pprint.pprint(type_mapping(None))
        print(type_mapping("INT"))
        print(type_mapping(2))

    :param name: None to get the whole mapping as a dictionary,
        a type name (str) to get its integer value, or an integer
        value to get the corresponding type name
    :return: dict, int or str depending on *name*
    :raises KeyError: if *name* is an unknown type name or value
    """
    # Values of ``AttributeProto.AttributeType`` in onnx.proto, including
    # SPARSE_TENSORS, TYPE_PROTO and TYPE_PROTOS so that attributes of
    # those types can be rendered too.
    di = dict(FLOAT=1, FLOATS=6, GRAPH=5, GRAPHS=10, INT=2,
              INTS=7, STRING=3, STRINGS=8, TENSOR=4,
              TENSORS=9, UNDEFINED=0, SPARSE_TENSOR=11,
              SPARSE_TENSORS=12, TYPE_PROTO=13, TYPE_PROTOS=14)
    if name is None:
        return di
    if isinstance(name, str):
        return di[name]
    # Any non-string key is looked up as an integer value.
    rev = {v: k for k, v in di.items()}
    return rev[name]
def _get_doc_template():
    """
    Returns the :epkg:`jinja2` template which renders the operator
    schemas given in variable ``schemas`` as RST.

    The caller must pass every helper the template calls
    (``format_name_with_domain``, ``process_documentation``,
    ``process_attribute_doc``, ``process_default_value``,
    ``type_mapping``, ``getname``, ``format_option``,
    ``getconstraint``, ``build_doc_url``, ``change_style``),
    ``OpSchema``, and builtins ``len``, ``sorted`` and ``enumerate``.
    The section listing the other versions of an operator iterates
    over ``sch.versions``, the same attribute its guard tests.
    """
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* *{{attr.name}}*{%
        if attr.required %} (required){% endif %}: {{
        process_attribute_doc(attr.description)}} {%
        if attr.default_value %} {{
        process_default_value(attr.default_value)
        }} ({{type_mapping(attr.type)}}){% endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * *{{getname(inp, ii)}}*{{format_option(inp)}}{{inp.typeStr}}: {{
        inp.description}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * *{{getname(out, ii)}}*{{format_option(out)}}{{out.typeStr}}: {{
        out.description}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{getconstraint(type_constraint, ii)}}: {{
        type_constraint.description}}
        {% endfor %}
        {% endif %}

        **Version**

        *Onnx name:* `{{sch.name}} <{{build_doc_url(sch)}}{{sch.name}}>`_

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %} since
        version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.versions[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Runtime implementation:**
        :class:`{{sch.name}}
        <mlprodict.onnxrt.ops_cpu.op_{{change_style(sch.name)}}.{{sch.name}}>`

        {% endfor %}
        """))


# Built once at import time; get_rst_doc renders this shared instance.
_template_operator = _get_doc_template()
class NewOperatorSchema:
    """
    Minimal stand-in for an operator schema, used for operators
    implemented in this package such as
    @see cl TreeEnsembleRegressorDouble.
    Only attributes ``name`` and ``domain`` are set.
    """

    def __init__(self, name):
        # name of the operator being documented
        self.name = name
        # every operator described by this class belongs to this domain
        self.domain = 'mlprodict'
def get_rst_doc(op_name):
    """
    Returns a documentation in RST format
    for all :epkg:`OnnxOperator`.

    :param op_name: operator name, or None to document every schema
        registered in ``_schemas`` (sorted by name); a name missing
        from ``_schemas`` is documented through
        @see cl NewOperatorSchema
    :return: string

    The rendering is done with module :epkg:`jinja2` and the template
    stored in ``_template_operator``.
    """
    from ..ops_cpu._op import _schemas
    if op_name is None:
        # Sorted so that the generated documentation is deterministic.
        schemas = [_schemas[name] for name in sorted(_schemas)]
    else:
        schemas = [_schemas.get(op_name, NewOperatorSchema(op_name))]

    def format_name_with_domain(sch):
        # "name (domain)", or only the name when the domain is empty.
        if sch.domain:
            return '{} ({})'.format(sch.name, sch.domain)
        return sch.name

    def format_option(obj):
        # Qualifiers of an input or output surrounded by spaces, such as
        # " (optional) ", or a single space when there is none. The
        # template inserts this string between the emphasized name and
        # the type string, and RST only ends emphasis before whitespace
        # or punctuation, so a separator is always needed.
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
            # A variadic parameter is heterogeneous when its values may
            # have different types, i.e. when it is *not* homogeneous.
            if not getattr(obj, 'isHomogeneous', True):
                opts.append('heterogeneous')
        if opts:
            return " (%s) " % ", ".join(opts)
        return " "

    def getconstraint(const, ii):
        # Type parameter name (its position if unnamed) followed by the
        # allowed type strings.
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)  # pragma: no cover
        if const.allowed_type_strs:
            name += " " + ", ".join(const.allowed_type_strs)
        return name

    def getname(obj, i):
        # Parameter name, or its position when the name is empty.
        name = obj.name
        if len(name) == 0:
            return str(i)  # pragma: no cover
        return name

    def process_documentation(doc):
        # Converts the markdown/HTML flavoured documentation of a schema
        # into RST.
        if doc is None:
            doc = ''  # pragma: no cover
        if isinstance(doc, Undefined):
            doc = ''  # pragma: no cover
        if not isinstance(doc, str):
            raise TypeError(  # pragma: no cover
                "Unexpected type {} for {}".format(type(doc), doc))
        doc = textwrap.dedent(doc)
        main_docs_url = "https://github.com/onnx/onnx/blob/master/"
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        # Fenced code blocks become indented RST literal blocks. This must
        # happen before ``` is rewritten into `` (below), otherwise no line
        # starts with ``` anymore and the fences are never detected.
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    move -= 4
                    lines.append("\n")
                else:
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        # Remaining ``` markers become RST inline literals.
        return "\n".join(lines).replace('```', '``')

    def process_attribute_doc(doc):
        # Attribute descriptions are rendered on a single bullet line.
        return doc.replace("<br>", " ")

    def build_doc_url(sch):
        # URL prefix of the operator anchor in the ONNX documentation,
        # the template appends the operator name.
        domain = sch.domain or ''
        doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators"
        if "ml" in domain:
            doc_url += "-ml"
        doc_url += ".md"
        doc_url += "#"
        if domain not in ('', 'ai.onnx'):
            doc_url += domain + "."
        return doc_url

    def process_default_value(value):
        # Keeps only ASCII letters, digits and a few punctuation characters
        # of the textual representation of the default value.
        if value is None:
            return ''  # pragma: no cover
        res = []
        for c in str(value):
            if ('A' <= c <= 'Z') or ('a' <= c <= 'z') or ('0' <= c <= '9'):
                res.append(c)
                continue
            if c in '[]-+(),.?':
                res.append(c)
                continue
        if len(res) == 0:
            return "*default value cannot be automatically retrieved*"
        return "Default value is ``" + ''.join(res) + "``"

    fnwd = format_name_with_domain
    tmpl = _template_operator
    docs = tmpl.render(schemas=schemas, OpSchema=OpSchema,
                       len=len, getattr=getattr, sorted=sorted,
                       format_option=format_option,
                       getconstraint=getconstraint,
                       getname=getname, enumerate=enumerate,
                       format_name_with_domain=fnwd,
                       process_documentation=process_documentation,
                       build_doc_url=build_doc_url, str=str,
                       type_mapping=type_mapping,
                       process_attribute_doc=process_attribute_doc,
                       process_default_value=process_default_value,
                       change_style=change_style)
    return docs.replace(" Default value is ````", "")
def debug_onnx_object(obj, depth=3):
    """
    ``__dict__`` is not in most of :epkg:`onnx` objects.
    This function uses function *dir* to explore this object.

    :param obj: object to describe
    :param depth: remaining recursion depth, the function returns None
        when it is zero or negative
    :return: a string starting with the type of *obj* followed by one
        row per attribute returned by *dir* (except bound slots and
        built-in methods), nested values are described under their
        attribute with an extra indentation, or None (see *depth*)

    Strings and bytes are printed but not explored element by element.
    """
    if depth <= 0:
        return None

    rows = [str(type(obj))]
    if isinstance(obj, (int, str, float, bool)):
        return "\n".join(rows)

    def iterable(o):
        # True if iter(o) succeeds.
        try:
            iter(o)
            return True
        except TypeError:
            return False

    def append_nested(sub):
        # Describes *sub* one level deeper and appends its rows indented.
        res = debug_onnx_object(sub, depth - 1)
        if res is not None:
            rows.extend(" " + line for line in res.split("\n"))

    for k in sorted(dir(obj)):
        try:
            val = getattr(obj, k)
            sval = str(val).replace("\n", " ")
        except (AttributeError, ValueError) as e:  # pragma: no cover
            sval = "ERRROR-" + str(e)
            val = None

        if 'method-wrapper' in sval or "built-in method" in sval:
            continue

        rows.append("- {}: {}".format(k, sval))
        # Special attributes and None values are listed but not explored.
        if k.startswith('__') and k.endswith('__'):
            continue
        if val is None:
            continue

        if isinstance(val, dict):
            try:
                items = sorted(val.items())
            except TypeError:  # pragma: no cover
                # keys (or values) cannot be compared
                items = list(val.items())
            for kk, vv in items:
                rows.append(" - [%s]: %s" % (str(kk), str(vv)))
                append_nested(vv)
        elif isinstance(val, (str, bytes, bytearray)):
            # Already fully printed above; iterating a bytes object would
            # list every byte as an integer.
            continue
        elif iterable(val):
            # Empty iterables and sequences of single characters are skipped.
            if all(isinstance(o, (str, bytes)) and len(o) == 1 for o in val):
                continue
            for i, vv in enumerate(val):
                rows.append(" - [%d]: %s" % (i, str(vv)))
                append_nested(vv)
        elif not callable(val):
            append_nested(val)

    return "\n".join(rows)
def visual_rst_template():
    """
    Returns the text of a :epkg:`jinja2` template which renders, for one
    converter from :epkg:`sklearn-onnx`, an RST section made of a label,
    a title, the model as a literal block, a table and the DOT graph.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import visual_rst_template
        print(visual_rst_template())

    The result is a plain string, not a compiled template. Rendering it
    requires variables ``link``, ``title``, ``kind``, ``method``,
    ``output_index``, ``optim_param``, ``model``, ``table``, ``dot``
    and functions ``len`` and ``indent``.
    """
    raw = """

    .. _l-{{link}}:

    {{ title }}
    {{ '=' * len(title) }}

    Fitted on a problem type *{{ kind }}*
    (see :func:`find_suitable_problem
    <mlprodict.onnxrt.validate.validate_problems.find_suitable_problem>`),
    method {{ method }} matches output {{ output_index }}.
    {{ optim_param }}

    ::

        {{ indent(model, " ") }}

    {{ table }}

    .. gdot::

        {{ indent(dot, " ") }}
    """
    # Removes the indentation common to every line of the literal.
    return textwrap.dedent(raw)