Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import DimensionObject 

10 

11 

class OneHotEncoder(OpRun):
    """
    Runtime for the :epkg:`ONNX` ML operator *OneHotEncoder*.

    Every element of the input is looked up in the categories
    (attribute *cats_int64s* or *cats_strings*) and replaced by a
    vector of length ``len(categories)`` holding ``1.`` at the
    position of its category and ``0.`` elsewhere. The output therefore
    has one more dimension than the input. An element which belongs to
    no category yields a vector of zeros, unless attribute *zeros* is 0,
    in which case an exception is raised.

    The :epkg:`ONNX` specifications do not mention
    the possibility to change the output type
    (sparse, dense, float, double): the output is always a dense
    ``numpy.float32`` array.
    """

    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'zeros': 1,
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHotEncoder.atts,
                       **options)
        # classes_ maps every category to its column in the output.
        if len(self.cats_int64s) > 0:
            self.classes_ = {v: i for i, v in enumerate(self.cats_int64s)}
        elif len(self.cats_strings) > 0:
            # String categories may come as bytes (decoded here) or
            # already as str (the attribute default is a numpy.str_
            # array): accept both instead of failing on str.decode.
            self.classes_ = {
                (v.decode('utf-8') if isinstance(v, bytes) else v): i
                for i, v in enumerate(self.cats_strings)}
        else:
            raise RuntimeError("No encoding was defined.")  # pragma: no cover

    def _run(self, x):  # pylint: disable=W0221
        """
        Encodes *x*.

        :param x: array of any shape whose elements are looked up
            in the categories
        :return: tuple with one ``float32`` array of shape
            ``x.shape + (len(categories), )``
        :raises RuntimeError: if attribute *zeros* is 0 and at least one
            element of *x* does not belong to any category
        """
        res = numpy.zeros(x.shape + (len(self.classes_), ),
                          dtype=numpy.float32)
        # ndenumerate visits every element whatever the rank of x,
        # the one-hot flag goes in the extra last axis.
        for index, value in numpy.ndenumerate(x):
            j = self.classes_.get(value, -1)
            if j >= 0:
                res[index + (j, )] = 1.

        if not self.zeros:
            # Positions (in x) whose one-hot vector is all zeros.
            # argwhere works for any rank and returns nothing for an
            # empty input (numpy.min would fail on it).
            missing = numpy.argwhere(res.sum(axis=-1) == 0)
            if len(missing) > 0:
                rows = []
                for position in missing[:6]:
                    index = tuple(int(k) for k in position)
                    rows.append(dict(
                        row=index[0] if len(index) == 1 else index,
                        value=x[index]))
                # A 0-d input cannot be sliced.
                head = x[:5] if x.ndim > 0 else x
                raise RuntimeError(  # pragma: no cover
                    "One observation did not have any defined category.\n"
                    "classes: {}\nfirst rows:\n{}\nres:\n{}\nx:\n{}".format(
                        self.classes_, "\n".join(str(_) for _ in rows),
                        res[:5], head))

        return (res, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Infers the output shape: a copy of *x* with one more dimension
        whose size is the number of categories, typed ``float32``.

        :param x: input shape (assumed to be a shape object from
            ``..shape_object`` exposing ``copy`` and ``append``)
        :return: tuple with the output shape
        """
        new_shape = x.copy()
        dim = DimensionObject(len(self.classes_))
        new_shape.append(dim)
        new_shape._dtype = numpy.float32
        new_shape.name = self.onnx_node.name
        return (new_shape, )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        The output is always ``numpy.float32`` whatever the input type.
        """
        return (numpy.float32, )