Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Documentation helper. 

4""" 

5import keyword 

6import textwrap 

7import re 

8from jinja2 import Template 

9from jinja2.runtime import Undefined 

10from onnx.defs import OpSchema 

11from ...tools import change_style 

12 

13 

def type_mapping(name):
    """
    Converts between the name of an attribute type and its integer code.

    :param name: *None* to retrieve the whole name-to-code table,
        a string to retrieve the integer code of that name,
        any other value to retrieve the name registered for that code
    :return: dictionary, integer or string depending on *name*

    An unknown name or an unknown code raises a *KeyError*.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import type_mapping
        import pprint
        pprint.pprint(type_mapping(None))
        print(type_mapping("INT"))
        print(type_mapping(2))
    """
    codes = {
        'FLOAT': 1, 'FLOATS': 6, 'GRAPH': 5, 'GRAPHS': 10, 'INT': 2,
        'INTS': 7, 'STRING': 3, 'STRINGS': 8, 'TENSOR': 4,
        'TENSORS': 9, 'UNDEFINED': 0, 'SPARSE_TENSOR': 11,
    }
    if name is None:
        return codes
    if not isinstance(name, str):
        # Reverse lookup: integer code -> type name.
        names = {code: label for label, code in codes.items()}
        return names[name]
    return codes[name]

37 

38 

def _get_doc_template():
    """
    Builds the :epkg:`jinja2` template which renders a list of
    operator schemas into RST.

    The template reads the variable ``schemas`` and calls the following
    names, all of which must be supplied to ``render``:
    ``format_name_with_domain``, ``process_documentation``,
    ``process_attribute_doc``, ``process_default_value``,
    ``type_mapping``, ``getname``, ``format_option``, ``getconstraint``,
    ``build_doc_url``, ``change_style``, ``OpSchema``, ``len``,
    ``sorted`` and ``enumerate``.

    :return: instance of ``jinja2.Template``
    """
    # The list of other versions iterates over ``sch.versions``, the same
    # attribute the enclosing ``if`` tests.
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* *{{attr.name}}*{%
          if attr.required %} (required){% endif %}: {{
          process_attribute_doc(attr.description)}} {%
          if attr.default_value %} {{
          process_default_value(attr.default_value)
          }} ({{type_mapping(attr.type)}}){% endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * *{{getname(inp, ii)}}*{{format_option(inp)}}{{inp.typeStr}}: {{
        inp.description}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * *{{getname(out, ii)}}*{{format_option(out)}}{{out.typeStr}}: {{
        out.description}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{getconstraint(type_constraint, ii)}}: {{
        type_constraint.description}}
        {% endfor %}
        {% endif %}

        **Version**

        *Onnx name:* `{{sch.name}} <{{build_doc_url(sch)}}{{sch.name}}>`_

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %} since
        version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.versions[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Runtime implementation:**
        :class:`{{sch.name}}
        <mlprodict.onnxrt.ops_cpu.op_{{change_style(sch.name)}}.{{sch.name}}>`

        {% endfor %}
    """))

115 

116 

# Operator template, built once when the module is imported and
# rendered by get_rst_doc on every call.
_template_operator = _get_doc_template()

118 

119 

class NewOperatorSchema:
    """
    Minimal schema for an operator introduced by this package,
    for instance @see cl TreeEnsembleRegressorDouble.
    It only holds the operator name and the domain ``'mlprodict'``.

    :param name: operator name
    """

    def __init__(self, name):
        self.name, self.domain = name, 'mlprodict'

129 

130 

def get_rst_doc(op_name):
    """
    Returns the documentation in RST format of one operator
    (:epkg:`OnnxOperator`).

    :param op_name: operator name, a name missing from the registered
        schemas is documented through a :class:`NewOperatorSchema`
    :return: string

    The function renders the module level :epkg:`jinja2` template
    ``_template_operator`` with a single schema.
    """
    from ..ops_cpu._op import _schemas
    schemas = [_schemas.get(op_name, NewOperatorSchema(op_name))]

    def format_name_with_domain(sch):
        # Title of the section: ``name (domain)`` or ``name`` alone.
        if sch.domain:
            return '{} ({})'.format(sch.name, sch.domain)
        return sch.name

    def format_option(obj):
        # Builds the parenthesised list of options of an input or output.
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
        # NOTE(review): the flag tested is ``isHomogeneous`` while the label
        # emitted is 'heterogeneous' -- confirm the intended wording.
        if getattr(obj, 'isHomogeneous', False):
            opts.append('heterogeneous')
        if opts:
            return " (%s)" % ", ".join(opts)
        return ""  # pragma: no cover

    def getconstraint(const, ii):
        # Name of the type constraint followed by the allowed types.
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)  # pragma: no cover
        if const.allowed_type_strs:
            name += " " + ", ".join(const.allowed_type_strs)
        return name

    def getname(obj, i):
        # Falls back on the position when the input or output is unnamed.
        name = obj.name
        if not name:
            return str(i)  # pragma: no cover
        return name

    def process_documentation(doc):
        # Converts the markdown/HTML snippets of an operator description
        # into RST.
        if doc is None or isinstance(doc, Undefined):
            doc = ''  # pragma: no cover
        if not isinstance(doc, str):
            raise TypeError(  # pragma: no cover
                "Unexpected type {} for {}".format(type(doc), doc))
        doc = textwrap.dedent(doc)
        main_docs_url = "https://github.com/onnx/onnx/blob/master/"
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
            '```': '``',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        # NOTE(review): the table above already rewrites every ``` into ``,
        # so a line can only start with ``` here when the text held four or
        # more consecutive backticks; confirm whether fenced blocks were
        # meant to become ``::`` literal blocks.
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    move -= 4
                    lines.append("\n")
                else:
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        return "\n".join(lines)

    def process_attribute_doc(doc):
        # Attribute descriptions are rendered on a single bullet line.
        return doc.replace("<br>", " ")

    def build_doc_url(sch):
        # A missing domain (None) is handled like the empty one instead of
        # failing on the membership test below.
        domain = sch.domain or ''
        doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators"
        if "ml" in domain:
            doc_url += "-ml"
        doc_url += ".md#"
        if domain not in ('', 'ai.onnx'):
            doc_url += domain + "."
        return doc_url

    def process_default_value(value):
        # Keeps ASCII letters, digits and a few symbols of the textual
        # representation of the default value.
        if value is None:
            return ''  # pragma: no cover
        allowed_symbols = '[]-+(),.?'
        kept = [c for c in str(value)
                if ('A' <= c <= 'Z' or 'a' <= c <= 'z' or '0' <= c <= '9' or
                    c in allowed_symbols)]
        if not kept:
            return "*default value cannot be automatically retrieved*"
        return "Default value is ``" + ''.join(kept) + "``"

    docs = _template_operator.render(
        schemas=schemas, OpSchema=OpSchema,
        len=len, getattr=getattr, sorted=sorted,
        format_option=format_option,
        getconstraint=getconstraint,
        getname=getname, enumerate=enumerate,
        format_name_with_domain=format_name_with_domain,
        process_documentation=process_documentation,
        build_doc_url=build_doc_url, str=str,
        type_mapping=type_mapping,
        process_attribute_doc=process_attribute_doc,
        process_default_value=process_default_value,
        change_style=change_style)
    # Drops the sentence produced by an empty default value.
    return docs.replace(" Default value is ````", "")

264 

265 

def debug_onnx_object(obj, depth=3):
    """
    ``__dict__`` is not in most of :epkg:`onnx` objects.
    This function uses function *dir* to explore this object
    and returns a text description of its attributes.

    :param obj: object to explore
    :param depth: maximum number of nested levels to describe
    :return: string, or None if *depth* is not positive

    Attributes whose string representation mentions ``method-wrapper``
    or ``built-in method`` are skipped, attributes named ``__*__``
    are listed but not explored further.
    """
    if depth <= 0:
        return None

    def is_iterable(value):
        try:
            iter(value)
        except TypeError:
            return False
        return True

    def add_nested(rows, value):
        # Describes *value* one level deeper and appends the result
        # with one more space of indentation.
        sub = debug_onnx_object(value, depth - 1)
        if sub is not None:
            rows.extend(" " + line for line in sub.split("\n"))

    rows = [str(type(obj))]
    if isinstance(obj, (int, str, float, bool)):
        return "\n".join(rows)

    for k in sorted(dir(obj)):
        try:
            val = getattr(obj, k)
            sval = str(val).replace("\n", " ")
        except (AttributeError, ValueError) as e:  # pragma: no cover
            sval = "ERRROR-" + str(e)
            val = None

        if 'method-wrapper' in sval or "built-in method" in sval:
            continue

        rows.append("- {}: {}".format(k, sval))
        if (k.startswith('__') and k.endswith('__')) or val is None:
            continue

        if isinstance(val, dict):
            try:
                pairs = sorted(val.items())
            except TypeError:  # pragma: no cover
                # keys or values which cannot be ordered
                pairs = list(val.items())
            for kk, vv in pairs:
                rows.append(" - [%s]: %s" % (str(kk), str(vv)))
                add_nested(rows, vv)
        elif is_iterable(val):
            # The values are materialised once: checking them and then
            # enumerating the same one-shot iterator would drop the
            # elements consumed by the check.
            items = list(val)
            if all(isinstance(o, (str, bytes)) and len(o) == 1
                   for o in items):
                # a string (or a sequence of single characters)
                continue
            for i, vv in enumerate(items):
                rows.append(" - [%d]: %s" % (i, str(vv)))
                add_nested(rows, vv)
        elif not callable(val):
            add_nested(rows, val)

    return "\n".join(rows)

331 

332 

def visual_rst_template():
    """
    Returns a :epkg:`jinja2` template to display DOT graph for each
    converter from :epkg:`sklearn-onnx`.

    :return: string, the template source

    The template reads the variables ``link``, ``title``, ``kind``,
    ``method``, ``output_index``, ``optim_param``, ``model``, ``table``
    and ``dot``, and calls ``len`` and ``indent``.
    The lines below ``::`` and ``.. gdot::`` are indented so that they
    belong to the literal block and to the directive.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import visual_rst_template
        print(visual_rst_template())
    """
    return textwrap.dedent("""

    .. _l-{{link}}:

    {{ title }}
    {{ '=' * len(title) }}

    Fitted on a problem type *{{ kind }}*
    (see :func:`find_suitable_problem
    <mlprodict.onnxrt.validate.validate_problems.find_suitable_problem>`),
    method {{ method }} matches output {{ output_index }}.
    {{ optim_param }}

    ::

        {{ indent(model, " ") }}

    {{ table }}

    .. gdot::

        {{ indent(dot, " ") }}
    """)