Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Implementation of a dummy score using 

4:epkg:`cdist`. 

5""" 

6import numpy 

7from onnx import onnx_pb as onnx_proto 

8from scipy.spatial.distance import cdist 

9 

10 

def score_cdist_sum(expected_values, predictions,
                    metric='sqeuclidean', p=None):
    """
    Computes, for every row of *expected_values*, the sum of its
    distances to all rows of *predictions*. It has no particular
    purpose except the one of converting a scorer into ONNX.

    @param      expected_values     expected_values
    @param      predictions         predictions
    @param      metric              see function :epkg:`cdist`
    @param      p                   see function :epkg:`cdist`,
                                    only forwarded when not None
    @return                         sum of the pairwise distances,
                                    one value per row of *expected_values*
    """
    # *p* only matters for some metrics (e.g. minkowski), so it is
    # forwarded to cdist only when the caller explicitly set it.
    cdist_options = {'metric': metric}
    if p is not None:
        cdist_options['p'] = p
    pairwise = cdist(expected_values, predictions, **cdist_options)
    # Collapse the (n_expected, n_predictions) matrix row by row.
    return numpy.sum(pairwise, axis=1)

30 

31 

32def convert_score_cdist_sum(scope, operator, container): 

33 """ 

34 Converts function @see fn score_cdist_sum into :epkg:`ONNX`. 

35 """ 

36 op = operator.raw_operator 

37 if op._fct != score_cdist_sum: # pylint: disable=W0143 

38 raise RuntimeError( # pragma: no cover 

39 "The wrong converter was called {} != {}.".format( 

40 op._fct, score_cdist_sum)) 

41 

42 from skl2onnx.algebra.complex_functions import onnx_cdist 

43 from skl2onnx.algebra.onnx_ops import OnnxReduceSumApi11 # pylint: disable=E0611 

44 from skl2onnx.common.data_types import guess_numpy_type 

45 

46 X = operator.inputs[0] 

47 Y = operator.inputs[1] 

48 out = operator.outputs 

49 opv = container.target_opset 

50 dtype = guess_numpy_type(operator.inputs[0].type) 

51 out = operator.outputs 

52 

53 options = container.get_options(score_cdist_sum, dict(cdist=None)) 

54 

55 kwargs = op.kwargs 

56 

57 if options.get('cdist', None) == 'single-node': 

58 attrs = kwargs 

59 cdist_name = scope.get_unique_variable_name('cdist') 

60 container.add_node('CDist', [X.full_name, Y.full_name], cdist_name, 

61 op_domain='mlprodict', name=scope.get_unique_operator_name('CDist'), 

62 **attrs) 

63 if container.target_opset < 13: 

64 container.add_node('ReduceSum', [cdist_name], out[0].full_name, 

65 axes=[1], keepdims=0, 

66 name=scope.get_unique_operator_name('ReduceSum')) 

67 else: 

68 axis_name = scope.get_unique_variable_name('axis') 

69 container.add_initializer( 

70 axis_name, onnx_proto.TensorProto.INT64, [1], [1]) # pylint: disable=E1101 

71 container.add_node( 

72 'ReduceSum', [cdist_name, axis_name], 

73 out[0].full_name, keepdims=0, 

74 name=scope.get_unique_operator_name('ReduceSum')) 

75 else: 

76 metric = kwargs['metric'] 

77 if metric == 'minkowski': 

78 dists = onnx_cdist(X, Y, dtype=dtype, op_version=opv, 

79 metric=metric, p=kwargs.get('p', 2)) 

80 else: 

81 dists = onnx_cdist(X, Y, dtype=dtype, op_version=opv, 

82 metric=kwargs['metric']) 

83 

84 res = OnnxReduceSumApi11(dists, axes=[1], keepdims=0, 

85 output_names=[out[0].full_name], 

86 op_version=opv) 

87 res.add_to(scope, container)