Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from ._op_helper import _get_typed_class_attribute 

10from ._op import OpRunClassifierProb, RuntimeTypeError 

11from ._op_classifier_string import _ClassifierCommon 

12from ._new_ops import OperatorSchema 

13from .op_svm_classifier_ import ( # pylint: disable=E0611,E0401 

14 RuntimeSVMClassifierFloat, 

15 RuntimeSVMClassifierDouble, 

16) 

17 

18 

class SVMClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Shared implementation of the SVM classifier runtime operators
    (@see cl SVMClassifier and @see cl SVMClassifierDouble).

    Subclasses must define a class attribute ``atts``, an
    :class:`OrderedDict` mapping every attribute name to its default
    value. The order of ``atts`` is the order in which the typed
    attribute values are passed to the runtime's ``init`` method.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        :param dtype: ``numpy.float32`` or ``numpy.float64``, selects
            which compiled runtime computes the predictions
        :param onnx_node: ONNX node
        :param desc: optional node description
        :param expected_attributes: expected attributes with their
            default values
        :param options: additional options forwarded to
            :class:`OpRunClassifierProb`
        :raises RuntimeTypeError: if *dtype* is not supported
        """
        OpRunClassifierProb.__init__(self, onnx_node, desc=desc,
                                     expected_attributes=expected_attributes,
                                     **options)
        self._init(dtype=dtype)

    def _get_typed_attributes(self, k):
        """
        Returns the value of attribute *k*, converted by
        ``_get_typed_class_attribute`` according to the default value
        declared for *k* in the class attribute ``atts``.
        """
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator name
        :return: schema instance (only ``SVMClassifierDouble`` is known)
        :raises RuntimeError: if *op_name* has no schema here
        """
        if op_name == "SVMClassifierDouble":
            return SVMClassifierDoubleSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _init(self, dtype):
        """
        Post-processes the label attributes, instantiates the runtime
        matching *dtype* and initializes it with the node attributes.

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :raises RuntimeTypeError: if *dtype* is not supported
        """
        self._post_process_label_attributes()
        # NOTE(review): the meaning of the constructor argument 20 is
        # defined by the compiled runtime (op_svm_classifier_) - confirm.
        if dtype == numpy.float32:
            self.rt_ = RuntimeSVMClassifierFloat(20)
        elif dtype == numpy.float64:
            self.rt_ = RuntimeSVMClassifierDouble(20)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))
        # Iterate over the attributes declared by the actual class (not a
        # hard-coded sibling class) so that the positional order passed to
        # ``init`` and the typing done by _get_typed_attributes both come
        # from the same ``atts`` declaration.
        atts = [self._get_typed_attributes(k)
                for k in self.__class__.atts]
        self.rt_.init(*atts)

    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `svm_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_classifier.cc>`_.
        See class :class:`RuntimeSVMClassifier
        <mlprodict.onnxrt.ops_cpu.op_svm_classifier_.RuntimeSVMClassifier>`.

        :param x: input features
        :return: result of ``_post_process_predicted_label(label, scores)``
        """
        label, scores = self.rt_.compute(x)
        if scores.shape[0] != label.shape[0]:
            # The runtime may return scores as a flat array: reshape it
            # into one row per predicted label.
            scores = scores.reshape(label.shape[0],
                                    scores.shape[0] // label.shape[0])
        return self._post_process_predicted_label(label, scores)

67 

68 

class SVMClassifier(SVMClassifierCommon):
    """
    Runtime for the ONNX operator *SVMClassifier*, running the
    ``numpy.float32`` flavour of the compiled SVM runtime.
    """

    # Attribute name -> default value. The declaration order matters:
    # the typed values are passed positionally to the runtime's ``init``.
    atts = OrderedDict([
        ('classlabels_ints', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('coefficients', numpy.empty(0, dtype=numpy.float32)),
        ('kernel_params', numpy.empty(0, dtype=numpy.float32)),
        ('kernel_type', b'NONE'),
        ('post_transform', b'NONE'),
        ('prob_a', numpy.empty(0, dtype=numpy.float32)),
        ('prob_b', numpy.empty(0, dtype=numpy.float32)),
        ('rho', numpy.empty(0, dtype=numpy.float32)),
        ('support_vectors', numpy.empty(0, dtype=numpy.float32)),
        # ONNX declares vectors_per_class as a list of ints (counts of
        # support vectors per class), not floats.
        ('vectors_per_class', numpy.empty(0, dtype=numpy.int64)),
    ])

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node
        :param desc: optional node description
        :param options: additional options forwarded to the base class
        """
        SVMClassifierCommon.__init__(
            self, numpy.float32, onnx_node, desc=desc,
            expected_attributes=SVMClassifier.atts,
            **options)

90 

91 

class SVMClassifierDouble(SVMClassifierCommon):
    """
    Runtime for the custom operator *SVMClassifierDouble*, the
    ``numpy.float64`` flavour of @see cl SVMClassifier.
    """

    # Attribute name -> default value. The declaration order matters:
    # the typed values are passed positionally to the runtime's ``init``.
    atts = OrderedDict([
        ('classlabels_ints', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('coefficients', numpy.empty(0, dtype=numpy.float64)),
        ('kernel_params', numpy.empty(0, dtype=numpy.float64)),
        ('kernel_type', b'NONE'),
        ('post_transform', b'NONE'),
        ('prob_a', numpy.empty(0, dtype=numpy.float64)),
        ('prob_b', numpy.empty(0, dtype=numpy.float64)),
        ('rho', numpy.empty(0, dtype=numpy.float64)),
        ('support_vectors', numpy.empty(0, dtype=numpy.float64)),
        # vectors_per_class holds integer counts (a list of ints in the
        # ONNX SVMClassifier specification), independent of the float type.
        ('vectors_per_class', numpy.empty(0, dtype=numpy.int64)),
    ])

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node
        :param desc: optional node description
        :param options: additional options forwarded to the base class
        """
        SVMClassifierCommon.__init__(
            self, numpy.float64, onnx_node, desc=desc,
            expected_attributes=SVMClassifierDouble.atts,
            **options)

113 

114 

class SVMClassifierDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator ``SVMClassifierDouble`` introduced by
    this package (see @see cl SVMClassifierDouble). The schema exposes
    the operator's attribute declaration ``SVMClassifierDouble.atts``.
    """

    def __init__(self):
        super().__init__('SVMClassifierDouble')
        # Reuse the runtime class declaration so both stay in sync.
        self.attributes = SVMClassifierDouble.atts