Source code for mlprodict.onnxrt.ops_cpu.op_solve
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator *Solve*, which solves the linear system ``a x = b`` with :func:`scipy.linalg.solve`.
:githublink:`%|py|7`
"""
from scipy.linalg import solve
from ._op import OpRunBinaryNum
from ._new_ops import OperatorSchema
class Solve(OpRunBinaryNum):
    """
    Runtime for the custom *Solve* operator: computes *x* such that
    ``a x = b`` by delegating to :func:`scipy.linalg.solve`.

    Attributes (both default to ``False``) are forwarded verbatim to
    :func:`scipy.linalg.solve` as keyword arguments of the same name.
    """

    # Expected node attributes and their default values.
    atts = {'lower': False,
            'transposed': False}

    def __init__(self, onnx_node, desc=None, **options):
        # Explicit base-class call (kept as in the original) so that the
        # base runtime knows which attributes this node expects.
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=Solve.atts,
                                **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.
        Only ``"Solve"`` is known here; any other name raises
        :class:`RuntimeError`.
        """
        if op_name != "Solve":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return SolveSchema()

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Solves ``a x = b`` and returns a 1-tuple holding *x*.

        When input 1 (*b*) is flagged in ``self.inplaces``, scipy is
        allowed to overwrite *b* (``overwrite_b=True``) to avoid a copy.
        ``self.lower`` / ``self.transposed`` are presumably populated by
        the base class from the node attributes — confirm in ``OpRun``.
        """
        solve_kwargs = {'lower': self.lower, 'transposed': self.transposed}
        if self.inplaces.get(1, False):
            solve_kwargs['overwrite_b'] = True
        return (solve(a, b, **solve_kwargs), )

    def _infer_shapes(self, a, b):  # pylint: disable=W0221
        """
        Returns the output shape information: the solution of
        ``a x = b`` has the shape of *b*, so *b* is passed through.
        """
        return (b, )

    def to_python(self, inputs):
        """
        Returns a pair ``(import statement, return statement)`` that
        reproduces this node in pure Python; *inputs* holds the names of
        the two input variables.
        """
        import_line = 'from scipy.linalg import solve'
        body_line = "return solve({}, {}, lower={}, transposed={})".format(
            inputs[0], inputs[1], self.lower, self.transposed)
        return (import_line, body_line)
class SolveSchema(OperatorSchema):
    """
    Schema of the custom operator :class:`Solve` defined in this module.
    It declares the operator name ``'Solve'`` and the same attributes
    (with their defaults) as the runtime class.
    """

    def __init__(self):
        op_name = 'Solve'
        OperatorSchema.__init__(self, op_name)
        # Share the runtime's attribute table so schema and runtime
        # always advertise the same attributes.
        self.attributes = Solve.atts