# Source code for mlprodict.onnxrt.ops_cpu.op_where
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
class Where(OpRun):
    """
    Runtime implementation of the ONNX operator *Where*: picks every
    output element from *x* or from *y* depending on *condition*.
    The computation is delegated to :func:`numpy.where`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards every argument unchanged to ``OpRun.__init__``.

        :param onnx_node: passed through to ``OpRun.__init__``
        :param desc: passed through to ``OpRun.__init__``
        :param options: passed through to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _run(self, condition, x, y):  # pylint: disable=W0221
        """
        Computes ``numpy.where(condition, x, y)``.

        :param condition: selector, elements of *x* are taken where it
            is true, elements of *y* elsewhere
        :param x: array (must expose ``dtype`` and ``shape``)
        :param y: array (must expose ``dtype`` and ``shape``),
            same dtype as *x*
        :return: tuple with a single array cast to the dtype of *x*
        :raises RuntimeError: if *x* and *y* have different dtypes or
            if their shapes are different and cannot be broadcast
            together
        """
        if x.dtype != y.dtype:
            raise RuntimeError(  # pragma: no cover
                "x and y should share the same dtype {} != {}".format(
                    x.dtype, y.dtype))
        if x.shape != y.shape:
            # The ONNX specification allows numpy-style broadcasting
            # between the inputs: differing shapes are only an error
            # when numpy cannot broadcast them.
            try:
                numpy.broadcast(x, y)
            except ValueError as e:
                raise RuntimeError(  # pragma: no cover
                    "x and y should share the same shape or be "
                    "broadcastable {} != {}".format(
                        x.shape, y.shape)) from e
        return (numpy.where(condition, x, y).astype(x.dtype), )

    def _infer_shapes(self, condition, x, y):  # pylint: disable=W0221
        """
        Returns the shape received for *x* as the shape of the output.

        NOTE(review): when the inputs are broadcast, the real output
        shape may differ from the shape of *x*; confirm whether the
        shape objects offer a way to express the broadcast result.
        """
        return (x, )