Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.onnx_pb import TensorProto 

9from ._op import OpRun 

10 

11 

class Cast(OpRun):
    """
    Runtime implementation of the ONNX *Cast* operator: converts
    the input tensor to the element type selected by attribute *to*.
    """

    atts = {'to': None}

    # Maps an ONNX element type (a ``TensorProto`` enum value) to the
    # numpy type used with ``ndarray.astype``.
    # Type ``help(TensorProto)`` to see all the possible values.
    _onnx_to_numpy = {
        TensorProto.FLOAT: numpy.float32,
        TensorProto.DOUBLE: numpy.float64,
        TensorProto.UINT8: numpy.uint8,
        TensorProto.INT8: numpy.int8,
        TensorProto.INT16: numpy.int16,
        TensorProto.INT32: numpy.int32,
        TensorProto.INT64: numpy.int64,
        TensorProto.UINT16: numpy.uint16,
        TensorProto.UINT32: numpy.uint32,
        TensorProto.UINT64: numpy.uint64,
        TensorProto.BOOL: numpy.bool_,
        TensorProto.STRING: numpy.str_,
        TensorProto.FLOAT16: numpy.float16,
        TensorProto.COMPLEX64: numpy.complex64,
        TensorProto.COMPLEX128: numpy.complex128,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Initializes the operator and resolves the target numpy type.

        :param onnx_node: ONNX node describing the operator
        :param desc: optional description, forwarded to ``OpRun``
        :param options: additional options, forwarded to ``OpRun``
        :raises ValueError: if attribute *to* is not one of the element
            types listed in ``Cast._onnx_to_numpy``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Cast.atts,
                       **options)
        dtype = Cast._onnx_to_numpy.get(self.to, None)
        if dtype is None:
            raise ValueError(  # pragma: no cover
                "Unexpected value for to='{}'.".format(
                    self.to))
        self._dtype = dtype

    def _cast(self, x):
        """
        Returns *x* converted to the target type with ``astype``.

        A regular method rather than a lambda stored on the instance.
        This keeps the operator picklable and avoids a self-referencing
        closure.
        """
        return x.astype(self._dtype)

    def _run(self, x):  # pylint: disable=W0221
        """
        Casts *x*. When input 0 is flagged as in-place in
        ``self.inplaces``, the in-place variant is used instead.
        """
        if self.inplaces.get(0, False):
            return self._run_inplace(x)
        return (self._cast(x), )

    def _run_inplace(self, x):
        """
        Returns *x* itself when it already has the target dtype.
        Otherwise returns a converted copy.
        """
        if x.dtype == self._dtype:
            return (x, )
        return (self._cast(x), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a copy of the input shape description with its dtype
        replaced by the target type.
        """
        # presumably x is a ShapeObject exposing copy(dtype=...) — verify
        # against OpRun callers
        return (x.copy(dtype=self._dtype), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """Returns the target numpy type regardless of the input type."""
        return (self._dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting
        ``temp=0`` to the outputs returned by ``self.run``.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res