Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401 

11 array_feature_extractor_double, 

12 array_feature_extractor_int64, 

13 array_feature_extractor_float) 

14 

15 

def _array_feature_extrator(data, indices):
    """
    Implementation of operator *ArrayFeatureExtractor*
    with :epkg:`numpy`.

    The values of *indices* are flattened into a list of positions which
    are gathered along the last axis of *data*. The result keeps every
    leading axis of *data* and replaces the last one by the number of
    positions; a one-dimensional *data* yields a matrix with a single row.

    :param data: array to extract values from (any dtype, strings included)
    :param indices: array of positions along the last axis of *data*
    :return: array of gathered values
    """
    rank = len(indices.shape)
    positions = indices.tolist() if rank == 1 else indices.ravel().tolist()

    if rank == 1 or (rank == 2 and indices.shape[0] == 1):
        # vector or single-row matrix: one output column per position
        count = len(positions)
    else:
        # any other layout: the output width is the total number of indices
        count = 1
        for extent in indices.shape:
            count *= extent

    target = ((1, count) if len(data.shape) == 1
              else list(data.shape[:-1]) + [count])
    return data[..., positions].reshape(target)

39 

40 

def sizeof_dtype(dty):
    """
    Returns the size in bytes of one element of a supported type.

    :param dty: numpy scalar type or dtype; only *float64*, *float32*
        and *int64* are supported
    :return: 8 for *float64*, 4 for *float32*, 8 for *int64*
    :raises ValueError: for any other type
    """
    if dty == numpy.float64:
        return 8
    if dty == numpy.float32:
        return 4
    if dty == numpy.int64:
        return 8
    # Report the offending type itself (not the generic numpy.dtype class)
    # so the error tells the caller which type is unsupported.
    raise ValueError(
        "Unable to get bytes size for type {}.".format(dty))

50 

51 

class ArrayFeatureExtractor(OpRun):
    """
    Python runtime for operator *ArrayFeatureExtractor*.
    Dispatches to compiled implementations for *float64*, *float32*
    and *int64* data and falls back to a numpy implementation otherwise.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        """
        Runtime for operator *ArrayFeatureExtractor*.

        .. warning::
            ONNX specifications may be imprecise in some cases.
            When the input data is a vector (one dimension),
            the output has still two like a matrix with one row.
            The implementation follows what :epkg:`onnxruntime` does in
            `array_feature_extractor.cc
            <https://github.com/microsoft/onnxruntime/blob/master/
            onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc#L84>`_.

        :param data: input array
        :param indices: positions to extract
        :return: tuple with one array
        """
        if data.dtype == numpy.float64:
            res = array_feature_extractor_double(data, indices)
        elif data.dtype == numpy.float32:
            res = array_feature_extractor_float(data, indices)
        elif data.dtype == numpy.int64:
            res = array_feature_extractor_int64(data, indices)
        else:
            # for strings, still not C++
            res = _array_feature_extrator(data, indices)
        return (res, )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        """
        Infer the shapes for the output.

        The output keeps every dimension of *data* except the last one,
        which is replaced by the number of indices (this is what the numpy
        implementation ``_array_feature_extrator`` returns). A
        one-dimensional *data* gives a matrix with a single row.

        :param data: ShapeObject of the input data
        :param indices: ShapeObject of the indices
        :return: tuple with one ShapeObject
        """
        add = indices.product()

        if len(data) == 1:
            return (ShapeObject((1, add), dtype=data.dtype), )
        # Replace (rather than extend) the last dimension: appending *add*
        # to the full shape would produce one axis too many compared with
        # the array computed at runtime.
        # NOTE(review): relies on ShapeObject supporting integer indexing
        # and construction from a tuple of dimensions — confirm.
        leading = tuple(data[i] for i in range(len(data) - 1))
        return (ShapeObject(leading + (add, ), dtype=data.dtype), )

    def _infer_types(self, data, indices):  # pylint: disable=W0221
        """
        Returns the type of the output, which is the type of *data*.
        """
        return (data, )