Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10 

11 

class LabelEncoder(OpRun):
    """
    Runtime for operator *LabelEncoder* using the ``keys_*`` / ``values_*``
    attributes. Every element of the input is looked up in a dictionary
    built from one non-empty ``keys_*`` attribute and one non-empty
    ``values_*`` attribute; elements that are not found are replaced by the
    ``default_*`` value matching the output type. Version 1 of the operator
    (attribute ``classes_strings``) is not supported.
    """

    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the lookup table ``classes_``, the fallback value
        ``default_`` and the output type ``dtype_`` from the node attributes.

        :param onnx_node: ONNX node holding the attributes
        :param desc: forwarded to :class:`OpRun`
        :param options: forwarded to :class:`OpRun`
        :raises RuntimeError: if no supported (keys, values) pair is defined
            or if the resulting table is empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)
        if len(self.keys_floats) > 0 and len(self.values_floats) > 0:
            self.classes_ = dict(zip(self.keys_floats, self.values_floats))
            self.default_ = self.default_float
            self.dtype_ = numpy.float32
        elif len(self.keys_floats) > 0 and len(self.values_int64s) > 0:
            self.classes_ = dict(zip(self.keys_floats, self.values_int64s))
            self.default_ = self.default_int64
            self.dtype_ = numpy.int64
        elif len(self.keys_int64s) > 0 and len(self.values_int64s) > 0:
            self.classes_ = dict(zip(self.keys_int64s, self.values_int64s))
            self.default_ = self.default_int64
            self.dtype_ = numpy.int64
        elif len(self.keys_int64s) > 0 and len(self.values_floats) > 0:
            self.classes_ = dict(zip(self.keys_int64s, self.values_floats))
            # The output is float, so the float default applies
            # (the int64 default was used here before).
            self.default_ = self.default_float
            self.dtype_ = numpy.float32
        elif len(self.keys_strings) > 0 and len(self.values_int64s) > 0:
            self.classes_ = {
                self._as_text(k): v
                for k, v in zip(self.keys_strings, self.values_int64s)}
            self.default_ = self.default_int64
            self.dtype_ = numpy.int64
        elif len(self.keys_strings) > 0 and len(self.values_strings) > 0:
            self.classes_ = {
                self._as_text(k): self._as_text(v)
                for k, v in zip(self.keys_strings, self.values_strings)}
            # Decode the default like the values so the output never mixes
            # str and bytes.
            self.default_ = self._as_text(self.default_string)
            # Object dtype avoids truncating strings (a fixed-width '<U'
            # dtype would cut any default longer than the longest value).
            self.dtype_ = numpy.dtype(object)
        elif len(self.keys_floats) > 0 and len(self.values_strings) > 0:
            self.classes_ = {
                k: self._as_text(v)
                for k, v in zip(self.keys_floats, self.values_strings)}
            self.default_ = self._as_text(self.default_string)
            self.dtype_ = numpy.dtype(object)
        elif len(self.keys_int64s) > 0 and len(self.values_strings) > 0:
            self.classes_ = {
                k: self._as_text(v)
                for k, v in zip(self.keys_int64s, self.values_strings)}
            self.default_ = self._as_text(self.default_string)
            self.dtype_ = numpy.dtype(object)
        elif hasattr(self, 'classes_strings'):
            raise RuntimeError(  # pragma: no cover
                "This runtime does not implement version 1 of "
                "operator LabelEncoder.")
        else:
            raise RuntimeError(
                "No encoding was defined in {}.".format(onnx_node))
        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    @staticmethod
    def _as_text(value):
        """
        Returns *value* decoded from UTF-8 when it is a bytes object,
        otherwise *value* unchanged (string attributes may come as bytes).
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _run(self, x):  # pylint: disable=W0221
        """
        Maps every element of *x* through ``classes_``.

        Inputs with more than one dimension are squeezed first, so a
        ``(N, 1)`` input gives a ``(N,)`` output.

        :param x: input array
        :return: tuple with one 1-D array of type ``dtype_``
        """
        if len(x.shape) > 1:
            x = numpy.squeeze(x)
        # squeeze turns a (1, 1) input into a 0-d array, and 0-d inputs have
        # no first dimension; make sure there is one to iterate over.
        x = numpy.atleast_1d(x)
        lookup = self.classes_.get
        default = self.default_
        res = numpy.empty((x.shape[0], ), dtype=self.dtype_)
        for i, value in enumerate(x):
            res[i] = lookup(value, default)
        return (res, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the inferred output shape.

        NOTE(review): this reports ``(x[0], number of classes)`` while
        ``_run`` returns a 1-D array of length ``x[0]``; confirm against
        the ShapeObject conventions before changing it.
        """
        nb = len(self.classes_.values())
        return (ShapeObject((x[0], nb), dtype=self.dtype_,
                            name="{}-1".format(self.__class__.__name__)), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """Returns the output type, ``dtype_``."""
        return (self.dtype_, )