Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.onnx_pb import TensorProto
9from ._op import OpRun
class Cast(OpRun):
    """
    Runtime for the ONNX *Cast* operator: converts its single input
    to the numpy dtype selected by the node attribute *to*
    (an ONNX element type, ``TensorProto.*``).
    """

    atts = {'to': None}

    # ONNX element type (type ``help(TensorProto)`` to see all the
    # possible values) -> numpy dtype produced by the cast.
    # Built once at class creation instead of walking an if/elif chain
    # for every node instance.
    _onnx_to_numpy = {
        TensorProto.FLOAT: numpy.float32,
        TensorProto.DOUBLE: numpy.float64,
        TensorProto.UINT8: numpy.uint8,
        TensorProto.INT8: numpy.int8,
        TensorProto.INT16: numpy.int16,
        TensorProto.INT32: numpy.int32,
        TensorProto.INT64: numpy.int64,
        TensorProto.UINT16: numpy.uint16,
        TensorProto.UINT32: numpy.uint32,
        TensorProto.UINT64: numpy.uint64,
        TensorProto.BOOL: numpy.bool_,
        TensorProto.STRING: numpy.str_,
        TensorProto.FLOAT16: numpy.float16,
        TensorProto.COMPLEX64: numpy.complex64,
        TensorProto.COMPLEX128: numpy.complex128,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node describing the Cast
        :param desc: forwarded to :class:`OpRun`
        :param options: forwarded to :class:`OpRun`
        :raises ValueError: if attribute *to* is not a supported
            ONNX element type
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Cast.atts,
                       **options)
        try:
            self._dtype = self._onnx_to_numpy[self.to]
        except KeyError:
            raise ValueError(  # pragma: no cover
                "Unexpected value for to='{}'.".format(
                    self.to)) from None

    def _cast(self, x):
        """
        Returns *x* converted to the target dtype.
        A regular method (instead of a lambda stored on the instance)
        keeps the operator picklable.
        """
        return x.astype(self._dtype)

    def _run(self, x):  # pylint: disable=W0221
        """
        Casts *x*. When input 0 is flagged as in-place,
        :meth:`_run_inplace` may return *x* itself.
        """
        if self.inplaces.get(0, False):
            return self._run_inplace(x)
        return (self._cast(x), )

    def _run_inplace(self, x):
        """
        Returns *x* unchanged when its dtype already compares equal to
        the target dtype, otherwise a converted copy.
        """
        # NOTE(review): for STRING targets, a sized unicode dtype such as
        # '<U5' likely does not compare equal to numpy.str_, so a copy
        # is made anyway - confirm if this matters.
        if x.dtype == self._dtype:
            return (x, )
        return (self._cast(x), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a copy of the input shape description with the dtype
        replaced by the target dtype (*x* presumably a shape object
        exposing ``copy(dtype=...)`` - defined outside this file).
        """
        return (x.copy(dtype=self._dtype), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        The output type is always the target dtype, whatever *x* is.
        """
        return (self._dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs
        (presumably meaning no temporary buffer is required).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res