Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from ._op_helper import _get_typed_class_attribute 

10from ._op import OpRunUnaryNum, RuntimeTypeError 

11from ._new_ops import OperatorSchema 

12from .op_svm_regressor_ import ( # pylint: disable=E0611,E0401 

13 RuntimeSVMRegressorFloat, 

14 RuntimeSVMRegressorDouble, 

15) 

16 

17 

class SVMRegressorCommon(OpRunUnaryNum):
    """
    Shared implementation of the *SVMRegressor* runtime operators.

    Concrete subclasses define the class attribute ``atts`` (an ordered
    mapping of ONNX attribute names to default values) and pass the numpy
    dtype used to select the C++ backend.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)
        self._init(dtype=dtype)

    def _get_typed_attributes(self, k):
        # Delegates to the shared helper with this class's own ``atts``
        # table (the float32 or float64 defaults of the concrete subclass).
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator name, only ``'SVMRegressorDouble'`` is known
        :return: an instance of @see cl SVMRegressorDoubleSchema
        :raises RuntimeError: for any other operator name
        """
        if op_name == "SVMRegressorDouble":
            return SVMRegressorDoubleSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _init(self, dtype):
        """
        Instantiates the C++ runtime matching *dtype* and initializes it
        with the node attributes.

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :raises RuntimeTypeError: for any other dtype
        """
        # NOTE(review): the meaning of the constructor argument 50 is defined
        # by the C++ extension op_svm_regressor_ — confirm there.
        if dtype == numpy.float32:
            self.rt_ = RuntimeSVMRegressorFloat(50)
        elif dtype == numpy.float64:
            self.rt_ = RuntimeSVMRegressorDouble(50)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))
        # The C++ ``init`` receives the attribute values positionally, in
        # the key order of the concrete class's ``atts`` table. Iterating
        # that table (instead of a hard-coded sibling class) keeps this
        # consistent with ``_get_typed_attributes``.
        atts = [self._get_typed_attributes(k)
                for k in self.__class__.atts]
        self.rt_.init(*atts)

    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `svm_regressor.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_regressor.cc>`_.
        See class :class:`RuntimeSVMRegressor
        <mlprodict.onnxrt.ops_cpu.op_svm_regressor_.RuntimeSVMRegressor>`.
        """
        pred = self.rt_.compute(x)
        # When the first dimension of the result differs from the number of
        # input rows, the result is treated as a flat buffer and reshaped to
        # (n_rows, -1) — presumably several scores per row; confirm in C++.
        if pred.shape[0] != x.shape[0]:
            pred = pred.reshape(x.shape[0], pred.shape[0] // x.shape[0])
        return (pred, )


class SVMRegressor(SVMRegressorCommon):
    """
    float32 flavour of the *SVMRegressor* runtime operator: attribute
    defaults are float32 arrays and the single-precision C++ runtime
    is selected.
    """

    # ONNX attribute name -> default value. The common base class feeds the
    # values to the C++ ``init`` positionally, so the key order matters.
    atts = OrderedDict(
        [
            ('coefficients', numpy.empty(0, dtype=numpy.float32)),
            ('kernel_params', numpy.empty(0, dtype=numpy.float32)),
            ('kernel_type', b'NONE'),
            ('n_supports', 0),
            ('one_class', 0),
            ('post_transform', b'NONE'),
            ('rho', numpy.empty(0, dtype=numpy.float32)),
            ('support_vectors', numpy.empty(0, dtype=numpy.float32)),
        ]
    )

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            numpy.float32, onnx_node,
            desc=desc, expected_attributes=SVMRegressor.atts,
            **options)


class SVMRegressorDouble(SVMRegressorCommon):
    """
    float64 flavour of the *SVMRegressor* runtime operator: attribute
    defaults are float64 arrays and the double-precision C++ runtime
    is selected.
    """

    # ONNX attribute name -> default value. Keep the keys in the same order
    # as SVMRegressor.atts: the values are handed positionally to the
    # C++ ``init`` by the common base class.
    atts = OrderedDict(
        [
            ('coefficients', numpy.empty(0, dtype=numpy.float64)),
            ('kernel_params', numpy.empty(0, dtype=numpy.float64)),
            ('kernel_type', b'NONE'),
            ('n_supports', 0),
            ('one_class', 0),
            ('post_transform', b'NONE'),
            ('rho', numpy.empty(0, dtype=numpy.float64)),
            ('support_vectors', numpy.empty(0, dtype=numpy.float64)),
        ]
    )

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            numpy.float64, onnx_node,
            desc=desc, expected_attributes=SVMRegressorDouble.atts,
            **options)


class SVMRegressorDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator *SVMRegressorDouble* introduced by this
    package; its attributes are those of @see cl SVMRegressorDouble.
    """

    def __init__(self):
        super().__init__('SVMRegressorDouble')
        # Reuse the attribute table (names and float64 defaults) of the
        # runtime class so both stay in sync.
        self.attributes = SVMRegressorDouble.atts