Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Rewrites some of the converters implemented in 

4:epkg:`sklearn-onnx`. 

5""" 

6import numpy 

7from skl2onnx.operator_converters.support_vector_machines import ( 

8 convert_sklearn_svm_regressor, 

9 convert_sklearn_svm_classifier) 

10from skl2onnx.common.data_types import guess_numpy_type 

11 

12 

13def _op_type_domain_regressor(dtype): 

14 """ 

15 Defines *op_type* and *op_domain* based on `dtype`. 

16 """ 

17 if dtype == numpy.float32: 

18 return 'SVMRegressor', 'ai.onnx.ml', 1 

19 if dtype == numpy.float64: 

20 return 'SVMRegressorDouble', 'mlprodict', 1 

21 raise RuntimeError( # pragma: no cover 

22 "Unsupported dtype {}.".format(dtype)) 

23 

24 

25def _op_type_domain_classifier(dtype): 

26 """ 

27 Defines *op_type* and *op_domain* based on `dtype`. 

28 """ 

29 if dtype == numpy.float32: 

30 return 'SVMClassifier', 'ai.onnx.ml', 1 

31 if dtype == numpy.float64: 

32 return 'SVMClassifierDouble', 'mlprodict', 1 

33 raise RuntimeError( # pragma: no cover 

34 "Unsupported dtype {}.".format(dtype)) 

35 

36 

def new_convert_sklearn_svm_regressor(scope, operator, container):
    """
    Replacement for the :epkg:`sklearn-onnx` SVM regressor converter
    which can emit an operator supporting doubles.

    The numpy type of the operator's first input decides the operator:
    float64 inputs select the double-precision variant returned by
    `_op_type_domain_regressor`, every other input type falls back to
    the float32 operator.

    :param scope: forwarded unchanged to the sklearn-onnx converter
    :param operator: operator being converted; only its first input
        is inspected here
    :param container: forwarded unchanged to the sklearn-onnx converter
    """
    first_input_type = operator.inputs[0].type
    use_double = guess_numpy_type(first_input_type) == numpy.float64
    precision = numpy.float64 if use_double else numpy.float32
    kind, domain, version = _op_type_domain_regressor(precision)
    convert_sklearn_svm_regressor(
        scope, operator, container,
        op_type=kind, op_domain=domain, op_version=version)

50 

51 

def new_convert_sklearn_svm_classifier(scope, operator, container):
    """
    Replacement for the :epkg:`sklearn-onnx` SVM classifier converter
    which can emit an operator supporting doubles.

    The numpy type of the operator's first input decides the operator:
    float64 inputs select the double-precision variant returned by
    `_op_type_domain_classifier`, every other input type falls back to
    the float32 operator.

    :param scope: forwarded unchanged to the sklearn-onnx converter
    :param operator: operator being converted; only its first input
        is inspected here
    :param container: forwarded unchanged to the sklearn-onnx converter
    """
    first_input_type = operator.inputs[0].type
    use_double = guess_numpy_type(first_input_type) == numpy.float64
    precision = numpy.float64 if use_double else numpy.float32
    kind, domain, version = _op_type_domain_classifier(precision)
    convert_sklearn_svm_classifier(
        scope, operator, container,
        op_type=kind, op_domain=domain, op_version=version)