# Source code for mlprodict.testing.script_testing
"""
Utilities to test scripts from :epkg:`scikit-learn` documentation.
:githublink:`%|py|5`
"""
import os
from io import StringIO
from contextlib import redirect_stdout, redirect_stderr
import pprint
import numpy
from sklearn.base import BaseEstimator
from .verify_code import verify_code
class MissingVariableError(RuntimeError):
    """
    Raised when a variable is missing.

    In this module it signals that a model or a dataset name extracted
    from a script cannot be found among the variables the script
    produced, or that no fitted model was detected at all.
    """
[docs]def _clean_script(content):
"""
Comments out all lines containing ``.show()``.
:githublink:`%|py|24`
"""
new_lines = []
for line in content.split('\n'):
if '.show()' in line or 'sys.exit' in line:
new_lines.append("# " + line)
else:
new_lines.append(line)
return "\n".join(new_lines)
[docs]def _enumerate_fit_info(fits):
"""
Extracts the name of the fitted models and the data
used to train it.
:githublink:`%|py|38`
"""
for fit in fits:
chs = fit['children']
if len(chs) < 2:
# unable to extract the needed information
continue # pragma: no cover
model = chs[0]['str']
if model.endswith('.fit'):
model = model[:-4]
args = [ch['str'] for ch in chs[1:]]
yield model, args
def _try_onnx(loc, model_name, args_name, **options):
    """
    Tries onnx conversion.

    :param loc: available variables
    :param model_name: model name among these variables
    :param args_name: arguments name among these variables,
        only the first one is used, it is the data given to the converter
    :param options: additional options for the conversion,
        ``dtype`` (default is ``numpy.float32``) is also used
        to cast the data before the conversion
    :return: tuple *(onnx model, dict)*, the dictionary contains
        the onnx model (``onx``), the model (``model``)
        and the converted data (``X``)
    :raises MissingVariableError: if the model or the data
        cannot be found in *loc*, or if *args_name* is empty
    """
    from ..onnx_conv import to_onnx

    if model_name not in loc:
        raise MissingVariableError("Unable to find model '{}' in {}".format(
            model_name, ", ".join(sorted(loc))))
    if not args_name:
        # an empty list used to fail with IndexError on args_name[0]
        raise MissingVariableError(
            "No data argument for model '{}'".format(model_name))
    data_name = args_name[0]
    if data_name not in loc:
        raise MissingVariableError("Unable to find data '{}' in {}".format(
            data_name, ", ".join(sorted(loc))))

    model = loc[model_name]
    dtype = options.get('dtype', numpy.float32)
    Xt = loc[data_name].astype(dtype)
    # options, including dtype if specified, are forwarded as is
    onx = to_onnx(model, Xt, **options)
    args = dict(onx=onx, model=model, X=Xt)
    return onx, args
def verify_script(file_or_name, try_onnx=True, existing_loc=None,
                  **options):
    """
    Checks that models fitted in an example from :epkg:`scikit-learn`
    documentation can be converted into :epkg:`ONNX`.

    :param file_or_name: file or string
    :param try_onnx: try the onnx conversion
    :param existing_loc: existing local variables
    :param options: conversion options
    :return: dictionary with the following keys:
        ``locals`` (estimators and arrays created by the script plus
        the converted models stored as ``<model name>_onnx``),
        ``globals`` (global variables used for the execution
        without ``__builtins__``), ``stdout``, ``stderr``
        (captured outputs), ``onx_info`` (one dictionary
        per converted model)
    :raises MissingVariableError: if *try_onnx* is True and no fitted
        model was detected or one of the detected variables is missing
    """
    if '\n' not in file_or_name and os.path.exists(file_or_name):
        filename = file_or_name
        with open(file_or_name, 'r', encoding='utf-8') as f:
            content = f.read()
    else:  # pragma: no cover
        content = file_or_name
        filename = "<string>"

    # neutralises .show() and sys.exit
    content = _clean_script(content)

    # look for fit or predict expressions
    _, node = verify_code(content, exc=False)
    fits = node._fits
    models_args = list(_enumerate_fit_info(fits))

    # execution
    # NOTE(review): the script is run with exec, it must be trusted.
    # NOTE(review): globals and locals are two distinct dictionaries,
    # exec then runs the code as if it was in a class body: names
    # assigned at the top level of the script go to *loc* and cannot
    # be seen from functions the script defines -- confirm it is intended.
    obj = compile(content, filename, 'exec')
    glo = globals().copy()
    loc = {}
    if existing_loc is not None:
        loc.update(existing_loc)
        glo.update(existing_loc)
    out = StringIO()
    err = StringIO()

    with redirect_stdout(out), redirect_stderr(err):
        exec(obj, glo, loc)  # pylint: disable=W0122

    # filter out values
    kept_types = (BaseEstimator, numpy.ndarray)
    loc_fil = {k: v for k, v in loc.items() if isinstance(v, kept_types)}
    glo_fil = {k: v for k, v in glo.items() if k != '__builtins__'}
    onx_info = []

    # onnx
    if try_onnx:
        if not models_args:
            raise MissingVariableError(
                "No detected trained model in '{}'\n{}\n--LOCALS--\n{}".format(
                    filename, content, pprint.pformat(loc_fil)))
        for model_args in models_args:
            try:
                onx, args = _try_onnx(loc_fil, *model_args, **options)
            except MissingVariableError as e:
                raise MissingVariableError(
                    "Unable to find variable in '{}'\n{}".format(
                        filename, pprint.pformat(fits))) from e
            # the converted model is made available next to the others
            loc_fil[model_args[0] + "_onnx"] = onx
            onx_info.append(args)

    # final results
    return dict(locals=loc_fil, globals=glo_fil,
                stdout=out.getvalue(),
                stderr=err.getvalue(),
                onx_info=onx_info)