# Source code for mlprodict.onnxrt.ops_cpu.op_where

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun


class Where(OpRun):
    """
    Runtime implementation of the ONNX ``Where`` operator.

    Picks elements from *x* or *y* depending on *condition* by
    delegating to :func:`numpy.where`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        """
        Forwards every argument unchanged to ``OpRun.__init__``.

        :param onnx_node: passed through to ``OpRun.__init__``
        :param desc: passed through to ``OpRun.__init__``
        :param options: passed through to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, condition, x, y):  # pylint: disable=W0221
        """
        Computes ``numpy.where(condition, x, y)``.

        *condition*, *x* and *y* only need to be broadcastable to a
        common shape (the ONNX specification of ``Where`` allows
        numpy-style multidirectional broadcasting), they do not have
        to share the exact same shape.

        :param condition: mask, entries evaluating to true select
            the value from *x*, the others select the value from *y*
        :param x: array of values used where *condition* is true
        :param y: array of values used where *condition* is false,
            it must have the same dtype as *x*
        :return: tuple with a single array whose dtype is ``x.dtype``
        :raises RuntimeError: if *x* and *y* have different dtypes or
            if the three inputs cannot be broadcast together
        """
        if x.dtype != y.dtype:
            raise RuntimeError(  # pragma: no cover
                "x and y should share the same dtype {} != {}".format(
                    x.dtype, y.dtype))
        try:
            # Only validates that the shapes are compatible, the
            # selection itself is done by numpy.where below.
            numpy.broadcast(condition, x, y)
        except ValueError as exc:
            # RuntimeError is kept for incompatible shapes so that
            # callers relying on the previous exception type still work.
            raise RuntimeError(  # pragma: no cover
                "condition, x and y cannot be broadcast together "
                "{} {} {}".format(
                    numpy.shape(condition), x.shape, y.shape)) from exc
        selected = numpy.where(condition, x, y)
        # x and y share the same dtype, copy=False avoids duplicating
        # the result when it already has the expected dtype.
        return (selected.astype(x.dtype, copy=False), )

    def _infer_shapes(self, condition, x, y):  # pylint: disable=W0221
        """
        Returns the shape of *x* as the shape of the output.

        NOTE(review): when the inputs are broadcast at run time, the
        actual output shape may differ from the shape of *x*; confirm
        whether shape inference should take broadcasting into account.

        :param condition: shape of the condition (unused)
        :param x: shape of the first input, returned unchanged
        :param y: shape of the second input (unused)
        :return: tuple with a single element, *x*
        """
        return (x, )