Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator *Where* (ONNX).

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class Where(OpRun):
    """
    Runtime implementation of the ONNX *Where* operator: the output takes
    elements of *x* where *condition* is true and elements of *y*
    elsewhere. The result is cast to the dtype of *x*.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _run(self, condition, x, y):  # pylint: disable=W0221
        """
        Selects elements from *x* or *y* depending on *condition*.

        :param condition: array used as the selection mask
        :param x: values picked where *condition* is true
        :param y: values picked where *condition* is false
        :return: 1-tuple holding the selected array, cast to ``x.dtype``
        :raises RuntimeError: if *x* and *y* have different dtypes
            (unless *x* has dtype object), or if the shapes of *x*
            and *y* cannot be broadcast together
        """
        if x.dtype != y.dtype and x.dtype not in (numpy.object_, ):
            raise RuntimeError(  # pragma: no cover
                "x and y should share the same dtype {} != {}".format(
                    x.dtype, y.dtype))
        # ONNX Where supports multidirectional (numpy-style) broadcasting,
        # so any pair of broadcastable shapes is valid -- not only equal
        # shapes or a (1, ) operand (0-d scalars included).
        try:
            numpy.broadcast(x, y)
        except ValueError as e:
            raise RuntimeError(  # pragma: no cover
                "x and y should have broadcastable shapes {} != {}".format(
                    x.shape, y.shape)) from e
        # copy=False skips a redundant copy when numpy.where already
        # produced an array of dtype x.dtype.
        return (numpy.where(condition, x, y).astype(x.dtype, copy=False), )

    def _infer_shapes(self, condition, x, y):  # pylint: disable=W0221
        """
        Returns the shape description received for *x*, unchanged,
        as the output shape.

        NOTE(review): when *condition*, *x* and *y* broadcast to a larger
        shape, the real output shape differs from the shape of *x*;
        confirm whether the shape objects passed here support a
        broadcast operation that should be used instead.
        """
        return (x, )

    def _infer_types(self, condition, x, y):  # pylint: disable=W0221
        """
        Returns the type description received for *x*, unchanged,
        matching ``_run`` which casts its result to ``x.dtype``.
        """
        return (x, )