Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ..shape_object import ShapeObject
9from ._op import OpRun
class LabelEncoder(OpRun):
    """
    Runtime for ONNX-ML operator *LabelEncoder* in its ``keys_*`` /
    ``values_*`` form: every input element is looked up in the dictionary
    built from one ``keys_*`` attribute and one ``values_*`` attribute,
    and elements missing from it are replaced by the ``default_*``
    attribute matching the output type. The older form based on
    ``classes_strings`` is not supported.
    """

    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    # Supported (keys attribute, values attribute, default attribute,
    # output dtype) combinations, tried in this order. String outputs use
    # an object array so that strings of any length fit without truncation.
    _encodings = (
        ('keys_floats', 'values_floats', 'default_float', numpy.float32),
        ('keys_floats', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_int64s', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_int64s', 'values_floats', 'default_float', numpy.float32),
        ('keys_strings', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_strings', 'values_strings', 'default_string',
         numpy.dtype(object)),
        ('keys_floats', 'values_strings', 'default_string',
         numpy.dtype(object)),
        ('keys_int64s', 'values_strings', 'default_string',
         numpy.dtype(object)),
        ('keys_strings', 'values_floats', 'default_float', numpy.float32),
    )

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)
        encoding = self._find_encoding()
        if encoding is None:
            if hasattr(self, 'classes_strings'):
                raise RuntimeError(  # pragma: no cover
                    "This runtime does not implement version 1 of "
                    "operator LabelEncoder.")
            raise RuntimeError(
                "No encoding was defined in {}.".format(onnx_node))
        keys, values, default, dtype = encoding
        # String attributes may be stored as bytes (``default_string``
        # defaults to b''): decode keys so they match str inputs, and
        # decode values *and* the default so string outputs are all str.
        self.classes_ = {self._decode(k): self._decode(v)
                         for k, v in zip(keys, values)}
        self.default_ = self._decode(default)
        self.dtype_ = dtype
        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    def _find_encoding(self):
        """
        Returns ``(keys, values, default, dtype)`` for the first entry of
        ``_encodings`` whose keys and values attributes are both non-empty,
        or None when no such pair is defined.
        """
        for keys_att, values_att, default_att, dtype in self._encodings:
            keys = getattr(self, keys_att)
            values = getattr(self, values_att)
            if len(keys) > 0 and len(values) > 0:
                return keys, values, getattr(self, default_att), dtype
        return None

    @staticmethod
    def _decode(value):
        """Returns *value* decoded as UTF-8 if it is bytes, unchanged
        otherwise (numbers and str pass through)."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _run(self, x):  # pylint: disable=W0221
        if len(x.shape) > 1:
            # Column (N, 1) or row (1, N) inputs become 1-D.
            x = numpy.squeeze(x)
        # Squeezing a (1, 1) input gives a 0-d array, and a scalar input is
        # 0-d as well: treat both as a single element instead of failing.
        x = numpy.atleast_1d(x)
        flat = x.ravel()
        res = numpy.empty(flat.shape, dtype=self.dtype_)
        lookup = self.classes_.get
        default = self.default_
        for i, value in enumerate(flat):
            res[i] = lookup(value, default)
        # Inputs still N-D after squeezing are mapped element-wise and keep
        # their shape (previously the lookup received a whole row).
        return (res.reshape(x.shape), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        # Each element is mapped to exactly one value: like ``_run``, the
        # output keeps the first input dimension and does not gain a
        # dimension per class.
        return (ShapeObject((x[0], ), dtype=self.dtype_,
                            name="{}-1".format(self.__class__.__name__)), )

    def _infer_types(self, x):  # pylint: disable=W0221
        return (self.dtype_, )