Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief A pipeline which serializes into ONNX steps by steps. 

4""" 

5import numpy 

6from sklearn.base import clone 

7from sklearn.pipeline import Pipeline, _fit_transform_one 

8from sklearn.utils.validation import check_memory 

9from sklearn.utils import _print_elapsed_time 

10from ..onnx_conv import to_onnx 

11from .onnx_transformer import OnnxTransformer 

12 

13 

class OnnxPipeline(Pipeline):
    """
    The pipeline overwrites method *fit*: it trains and converts
    every step into ONNX before training the next step
    in order to minimize discrepancies. By default,
    ONNX uses float and not double, which is the default
    for :epkg:`scikit-learn`. It may introduce discrepancies
    when a non-continuous model (mathematical definition), such
    as a tree ensemble, is part of the pipeline.

    :param steps:
        List of (name, transform) tuples (implementing fit/transform) that are
        chained, in the order in which they are chained, with the last object
        an estimator.
    :param memory: str or object with the joblib.Memory interface, default=None
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Enabling caching triggers a clone of
        the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.
    :param verbose: bool, default=False
        If True, the time elapsed while fitting each step will be printed as it
        is completed.
    :param output_name: string
        requested output name or None to request all and
        have method *transform* to store all of them in a dataframe
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference.
    :param options: see @see fn to_onnx
    :param white_op: see @see fn to_onnx
    :param black_op: see @see fn to_onnx
    :param final_types: see @see fn to_onnx
    :param op_version: ONNX targeted opset

    The class stores the fitted transformers, before they are converted
    into ONNX, in attribute ``raw_steps_``.

    See notebook :ref:`onnxdiscrepenciesrst` to see how it can
    be used to reduce discrepancies after it was converted into
    *ONNX*.
    """

    def __init__(self, steps, *, memory=None, verbose=False,
                 output_name=None, enforce_float32=True,
                 runtime='python', options=None,
                 white_op=None, black_op=None, final_types=None,
                 op_version=None):
        self.output_name = output_name
        self.enforce_float32 = enforce_float32
        self.runtime = runtime
        self.options = options
        self.white_op = white_op
        self.black_op = black_op
        self.final_types = final_types
        self.op_version = op_version
        # These attributes are set before calling the parent constructor.
        # The original note says the constructor calls _validate_steps,
        # which checks the value of black_op (depends on the scikit-learn
        # version — NOTE(review): confirm).
        Pipeline.__init__(
            self, steps, memory=memory, verbose=verbose)

    def fit(self, X, y=None, **fit_params):
        """
        Fits the model: fits all the transforms one after the
        other and transforms the data, then fits the transformed
        data using the final estimator. Every intermediate transformer
        is replaced by its ONNX conversion (see ``_fit``).

        :param X: iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.
        :param y: iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.
        :param fit_params: dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.
        :return: self, Pipeline, this estimator
        """
        # NOTE: relies on underscore-prefixed (private) scikit-learn
        # helpers, which may change between scikit-learn releases.
        fit_params_steps = self._check_fit_params(**fit_params)
        Xt = self._fit(X, y, **fit_params_steps)
        with _print_elapsed_time('OnnxPipeline',
                                 self._log_message(len(self.steps) - 1)):
            if self._final_estimator != 'passthrough':
                fit_params_last_step = fit_params_steps[self.steps[-1][0]]
                self._final_estimator.fit(Xt, y, **fit_params_last_step)
        return self

    def _fit(self, X, y=None, **fit_params_steps):
        """
        Fits every transformer but the final estimator, one after the
        other. Each fitted transformer is stored in ``raw_steps_`` and
        replaced in ``steps`` by its ONNX conversion (see ``_to_onnx``),
        so the next step is trained on the output of the converted model.

        :param X: training data
        :param y: training targets
        :param fit_params_steps: fit parameters, indexed by step name
        :return: the data transformed by all the intermediate steps
        """
        # shallow copy of steps - this should really be steps_
        if getattr(self, 'raw_steps_', None) is not None:
            # A previous training replaced the steps by ONNX models:
            # restart from the original (non-converted) transformers.
            self.steps = list(self.raw_steps_)  # pylint: disable=E0203
            self.raw_steps_ = list(self.raw_steps_)  # pylint: disable=E0203
        else:
            self.steps = list(self.steps)
            self.raw_steps_ = list(self.steps)

        self._validate_steps()
        # Setup the memory
        memory = check_memory(self.memory)
        fit_transform_one_cached = memory.cache(_fit_transform_one)

        for (step_idx,
             name,
             transformer) in self._iter(with_final=False,
                                        filter_passthrough=False):
            if transformer is None or transformer == 'passthrough':
                with _print_elapsed_time('Pipeline',
                                         self._log_message(step_idx)):
                    continue

            # joblib >= 0.12 exposes memory.location; when it is None,
            # caching is disabled and the transformer is not cloned to
            # preserve backward compatibility.
            if hasattr(memory, 'location') and memory.location is None:
                cloned_transformer = transformer
            else:
                cloned_transformer = clone(transformer)

            # Fit or load from cache the current transformer.
            # x_train keeps this step's input: it is needed to convert
            # the fitted transformer into ONNX.
            x_train = X
            X, fitted_transformer = fit_transform_one_cached(
                cloned_transformer, X, y, None,
                message_clsname='Pipeline',
                message=self._log_message(step_idx),
                **fit_params_steps[name])
            # Replace the transformer of the step with the fitted
            # transformer. This is necessary when loading the transformer
            # from the cache.
            self.raw_steps_[step_idx] = (name, fitted_transformer)
            self.steps[step_idx] = (
                name, self._to_onnx(name, fitted_transformer, x_train))
        return X

    def _to_onnx(self, name, fitted_transformer, x_train, rewrite_ops=True,
                 verbose=0):
        """
        Converts a transformer into ONNX.

        :param name: model name, used to select the conversion
            options relevant to this step
        :param fitted_transformer: fitted transformer
        :param x_train: training dataset (a numpy array)
        :param rewrite_ops: use rewritten converters
        :param verbose: display some information
        :return: corresponding @see cl OnnxTransformer, already fitted
        :raises RuntimeError: if *x_train* is not a numpy array or if
            the ONNX graph does not have exactly one output
        """
        if not isinstance(x_train, numpy.ndarray):
            raise RuntimeError(  # pragma: no cover
                "The pipeline only handle numpy arrays not {}.".format(
                    type(x_train)))
        kwargs = {k: getattr(self, k)
                  for k in ('options', 'white_op', 'black_op', 'final_types')}
        # Options are given for the whole pipeline; keep only the ones
        # meant for this step.
        kwargs['options'] = self._preprocess_options(name, kwargs['options'])
        kwargs['target_opset'] = self.op_version
        # Cast to float32 unless float32 is not enforced and the data is
        # already float64 (x_train presumably drives the input type of the
        # ONNX graph inside to_onnx — NOTE(review): confirm).
        if self.enforce_float32 or x_train.dtype != numpy.float64:
            x_train = x_train.astype(numpy.float32)
        onx = to_onnx(fitted_transformer, x_train,
                      rewrite_ops=rewrite_ops, verbose=verbose,
                      **kwargs)
        if len(onx.graph.output) != 1:
            raise RuntimeError(
                "Only one output is allowed in the ONNX graph not %d. "
                "Model=%r" % (len(onx.graph.output), fitted_transformer))
        tr = OnnxTransformer(
            onx.SerializeToString(), output_name=self.output_name,
            enforce_float32=self.enforce_float32, runtime=self.runtime)
        return tr.fit()

    def _preprocess_options(self, name, options):
        """
        Selects the conversion options relevant to one step.

        String keys prefixed by ``<name>__`` are kept with the prefix
        removed, other string keys are dropped, and non-string keys
        are kept unchanged.

        :param name: step name
        :param options: conversion options (a dictionary) or None
        :return: new options, or None if *options* is None
        """
        if options is None:
            return None
        prefix = name + '__'
        new_options = {}
        for k, v in options.items():
            if not isinstance(k, str):
                new_options[k] = v
            elif k.startswith(prefix):
                new_options[k[len(prefix):]] = v
        return new_options