Coverage for onnxcustom/training/optimizers_partial.py: 98%
283 statements
1"""
2@file
3@brief Optimizer with :epkg:`onnxruntime-training` forward backward training.
4"""
5import logging
6import warnings
7import numpy
8from onnxruntime import InferenceSession, SessionOptions
9from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
10 OrtValue as C_OrtValue)
11from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype
12from ..utils.onnxruntime_helper import (
13 device_to_providers, numpy_to_ort_value)
14from ..utils.onnx_function import function_onnx_graph
15from ..utils.print_helper import str_ortvalue
16from ..utils.orttraining_helper import get_train_initializer
17from .ortgradient import OrtGradientForwardBackward
18from ._base_estimator import BaseEstimator
19from .sgd_learning_loss import BaseLearningLoss
20from .sgd_learning_penalty import BaseLearningPenalty
21from .data_loader import OrtDataLoader
22from .excs import ConvergenceError, ConvergenceWarning


class OrtGradientForwardBackwardOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`. It leverages class
    @see cl OrtGradientForwardBackward.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of initializers to be optimized,
        if None, function @see fn get_train_initializer returns
        the list of float initializers
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a name or a learning rate instance or a float,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuses the solution of the previous
        call to fit as initialization, otherwise, just erases the previous
        solution
    :param learning_loss: loss function (see below)
    :param verbose: use :epkg:`tqdm` to display the training progress
    :param validation_every: validation with a test set every
        *validation_every* iterations
    :param enable_logging: enable logging (mostly for debugging purposes
        as it slows down the training)
    :param weight_name: if not None, the class assumes it is trained
        with training weights
    :param learning_penalty: weight penalty, None, or instance of
        @see cl BaseLearningPenalty
    :param exc: raise exceptions (about convergence for example)
        or keep them silent as much as possible

    *learning_rate* can be any instance of @see cl BaseLearningRate or
    a nickname in the following list as specified in
    :meth:`BaseLearningRate.select
    <onnxcustom.training.sgd_learning_rate.BaseLearningRate.select>`.

    *learning_loss* can be any instance of @see cl BaseLearningLoss or
    a nickname in the following list as specified in
    :meth:`BaseLearningLoss.select
    <onnxcustom.training.sgd_learning_loss.BaseLearningLoss.select>`.
    """

    def __init__(self, model_onnx, weights_to_train=None,
                 loss_output_name='loss', max_iter=100,
                 training_optimizer_name='SGDOptimizer',
                 batch_size=10, learning_rate='SGD',
                 device='cpu', warm_start=False, verbose=0,
                 validation_every=0.1, learning_loss="square_error",
                 enable_logging=False, weight_name=None,
                 learning_penalty=None, exc=True):
        if weights_to_train is None:
            weights_to_train = list(get_train_initializer(model_onnx))
        BaseEstimator.__init__(self, model_onnx, learning_rate, device)
        self.batch_size = batch_size
        self.weights_to_train = weights_to_train
        self.loss_output_name = loss_output_name
        self.training_optimizer_name = training_optimizer_name
        self.verbose = verbose
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.learning_loss = BaseLearningLoss.select(learning_loss)
        self.learning_penalty = BaseLearningPenalty.select(learning_penalty)
        self.enable_logging = enable_logging
        self.weight_name = weight_name
        self.exc = exc
        if validation_every < 1:
            self.validation_every = int(self.max_iter * validation_every)
        else:
            self.validation_every = validation_every  # pragma: no cover
        self.build_onnx_function()

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return self.learning_rate.needs_grad

    def __getstate__(self):
        "Removes any non-picklable attribute."
        state = BaseEstimator.__getstate__(self)
        for att in ['train_state_', 'train_grad_state_']:
            if hasattr(self, att):
                # Saves the state matching *att*, not only the trained
                # weights, so that __setstate__ restores the full list
                # (gradient state included).
                kind = 'weight' if att == 'train_state_' else 'grad'
                train_state = []
                for v in self.get_full_state(kind=kind):
                    if v is None:
                        train_state.append(v)
                    else:
                        train_state.append(v.numpy())
                state[att[:-1]] = train_state
        return state

    def __setstate__(self, state):
        "Restores any non-picklable attribute."
        popped = {}
        for att in ['train_state', 'train_grad_state']:
            if att in state:
                popped[att] = state.pop(att)
        BaseEstimator.__setstate__(self, state)
        for k, v in popped.items():
            if k == 'train_state':
                self.set_state(v, check_trained=False, kind='weight')
            elif k == 'train_grad_state':
                self.set_state(v, check_trained=False, kind='grad')
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected key state {k!r}.")
        self.build_onnx_function()
        return self

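    # __getstate__/__setstate__ convert every OrtValue of the training
    # state into a numpy array and back, so a trained optimizer can be
    # pickled. A sketch, assuming ``trainer`` was already fitted:
    #
    #   import pickle
    #   restored = pickle.loads(pickle.dumps(trainer))
    #   # restored.trained_coef_ holds the same weights as trainer
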
    def _get_att_state(self, kind):
        if kind == 'weight':
            return 'train_state_'
        if kind == 'grad':
            return 'train_grad_state_'
        raise ValueError(  # pragma: no cover
            f"Unexpected kind={kind!r}.")

    def get_full_state(self, kind='weight'):
        """
        Returns the trained weights and the inputs.
        """
        if isinstance(kind, list):
            return [self.get_full_state(kind=k) for k in kind]
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return getattr(self, att)

    def get_state(self, kind='weight'):
        """
        Returns the trained weights.
        """
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError("Method fit must be called before.")
        if getattr(self, att, None) is None:
            raise RuntimeError(  # pragma: no cover
                f"No attribute {att!r} available (None).")
        if self.weights_to_train is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected self.weights_to_train (None).")
        value = getattr(self, att)
        n = len(value) - len(self.weights_to_train)
        return value[n:]

    @property
    def trained_coef_(self):
        """
        Returns the trained coefficients as a dictionary.
        """
        return dict(zip(self.weights_to_train, self.get_state()))

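    # For example, after training a linear model whose initializers are
    # named ``coef`` and ``intercept`` (hypothetical names),
    # ``trainer.trained_coef_`` would return
    # ``{'coef': <OrtValue>, 'intercept': <OrtValue>}``; calling
    # ``.numpy()`` on a value copies it back to a numpy array.
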
    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
        """
        state = dict(zip(self.weights_to_train, self.get_state()))
        return self._get_trained_onnx(state, model=model)

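    # A sketch of exporting the trained graph for plain inference,
    # assuming ``trainer`` was fitted and the original graph has one
    # float input named 'X' (hypothetical input name):
    #
    #   from onnxruntime import InferenceSession
    #   trained = trainer.get_trained_onnx()
    #   sess = InferenceSession(trained.SerializeToString(),
    #                           providers=['CPUExecutionProvider'])
    #   pred = sess.run(None, {'X': X})[0]
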
    def set_state(self, state, check_trained=True, kind='weight', zero=False):
        """
        Changes the trained weights.
        """
        if check_trained and not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        state_ = []
        state_numpy_ = []
        for i, v in enumerate(state):
            if v is None:
                state_.append(None)
                state_numpy_.append(None)
            elif isinstance(v, numpy.ndarray):
                if zero:
                    v = numpy.zeros(v.shape, dtype=v.dtype)
                ortvalue = numpy_to_ort_value(v, self.device)
                state_.append(ortvalue)
                # The numpy container must be retained as the ortvalue
                # just borrows the pointer.
                state_numpy_.append(v)
            elif isinstance(v, C_OrtValue):
                if zero:
                    # run_with_ort_values returns a list of OrtValue,
                    # only the first output is needed
                    v = self.zero_sess_.run_with_ort_values(
                        ['Y'], {'X': v})[0]
                state_.append(v)
                state_numpy_.append(None)
            else:
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(v)!r} for state {i!r}.")
        att = self._get_att_state(kind)
        setattr(self, att, state_)
        setattr(self, att + "numpy_", state_numpy_)

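    # set_state accepts numpy arrays as well as OrtValues, which allows
    # restarting from an external solution. A sketch, assuming
    # ``trainer`` was fitted and the state layout matches
    # get_full_state():
    #
    #   state = trainer.get_full_state()
    #   trainer.set_state(state)                          # no change
    #   trainer.set_state(state, kind='grad', zero=True)  # reset gradients
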
    def build_onnx_function(self):
        """
        Creates ONNX graph and *InferenceSession* related to
        any operations applying on *OrtValue*.
        """
        opset = get_onnx_opset(self.model_onnx)
        so = SessionOptions()
        so.log_severity_level = 4

        n = len(self.weights_to_train)

        # loss_grad
        self.learning_loss.build_onnx_function(
            opset, self.device, self.weight_name)

        # weight update
        self.learning_rate.build_onnx_function(opset, self.device, n)

        # regularization
        self.learning_penalty.build_onnx_function(opset, self.device, n)

        # zero
        self.zero_onnx_ = function_onnx_graph("zero")
        self.zero_sess_ = InferenceSession(
            self.zero_onnx_.SerializeToString(), so,
            providers=device_to_providers(self.device))

        # logging
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None

    def fit(self, X, y, sample_weight=None,
            X_val=None, y_val=None):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :return: self
        """
        if self.training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only the SGDOptimizer is implemented not %r."
                "" % self.training_optimizer_name)
        logger = self._logger

        session_function = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            device=self.device)
        self.train_session_ = session_function[0]
        self.train_function_ = session_function[1]

        self.input_names_ = self.train_session_.cls_type_._grad_input_names
        self.output_names_ = self.train_session_.cls_type_._bw_fetches_names
        weights_to_train = self.train_session_.weights_to_train

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "input_names=%r", self.input_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "output_names=%r", self.output_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "weights_to_train=%r", self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "device=%r|%r",
                self.device.device_id(), self.device.device_type())
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "warm_start=%r", self.warm_start)

        # The state attributes are named train_state_ and
        # train_grad_state_ (see _get_att_state), otherwise the state
        # would be re-initialized on every call to fit and warm_start
        # could never reuse the previous solution.
        if not hasattr(self, 'train_state_'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_])
        if self.needs_grad and not hasattr(self, 'train_grad_state_'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_],
                kind='grad', zero=True)
        if not self.warm_start:
            state = self.get_full_state()
            if len(state) != len(self.input_names_):
                raise RuntimeError(  # pragma: no cover
                    f"Length mismatch {len(state)!r} != "
                    f"{len(self.input_names_)!r}.")
            new_state = []
            for iv, v in enumerate(state):
                if v is None:
                    new_state.append(v)
                else:
                    if not isinstance(v, C_OrtValue):
                        raise RuntimeError(  # pragma: no cover
                            "Unexpected type %r (state[%d])." % (
                                type(v), iv))
                    dtype = proto_type_to_dtype(
                        v.proto_type()
                        if hasattr(v, 'proto_type')
                        else v.data_type())
                    if len(v.shape()) > 0:
                        new_state.append(
                            numpy.random.randn(*v.shape()).astype(dtype))
                    else:
                        new_state.append(
                            numpy.random.randn(1).astype(dtype))
            self.set_state(new_state)
            if self.needs_grad:
                self.set_state(new_state, kind='grad', zero=True)

        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        self.train_losses_ = []
        val_losses = []
        kinds = ['weight', 'grad'] if self.needs_grad else ['weight']
        for it in loop:
            loss = self._iteration(
                data_loader, self.get_full_state(kind=kinds),
                len(weights_to_train))
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g" % (  # pylint: disable=E1101,E1307
                        loss, lr))  # pylint: disable=E1101,E1307
            if logger is not None:
                logger.info(
                    "[OrtGradientForwardBackwardOptimizer.fit] "
                    "lr value=%r", lr)

            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % self.validation_every == 0):
                val_losses.append(
                    self._evaluation(data_loader_val, self.get_full_state()))
        self.validation_losses_ = (
            None if data_loader_val is None else val_losses)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "end loss=%r", self.train_losses_[-1])
        return self

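    # A sketch of fitting with a validation set (hypothetical arrays
    # X_val, y_val). A *validation_every* below 1 is a fraction of
    # *max_iter*: with max_iter=100 and validation_every=0.1, the test
    # set is evaluated every 10 iterations and the results are stored
    # in validation_losses_:
    #
    #   trainer = OrtGradientForwardBackwardOptimizer(
    #       model_onnx, max_iter=100, validation_every=0.1)
    #   trainer.fit(X, y, X_val=X_val, y_val=y_val)
    #   print(trainer.train_losses_[-1], trainer.validation_losses_)
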
    def _iteration(self, data_loader, states, n_weights):
        actual_losses = []
        bs = data_loader.batch_size
        logger = self._logger
        if len(states) == 1:
            state = states[0]
            grad = None
        else:
            state, grad = states

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration begin learning_rate=%r",
                self.learning_rate)

        prediction_cache = None
        prediction_cache_shape = None
        backward_outputs_cache = None
        for ib, ito in enumerate(data_loader.iter_ortvalue()):
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "batch %d", ib)

            ortx_shape = tuple(ortx.shape())
            same_shape = (
                prediction_cache_shape is not None and
                ortx_shape == prediction_cache_shape)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "forward")

            # forward
            # The cached forward outputs can only be reused when the
            # batch shape did not change; otherwise the cache is reset
            # (this mirrors the backward cache below).
            if not same_shape:
                prediction_cache = None
                prediction_cache_shape = None
            prediction = self.train_function_.forward(
                state, training=True,
                forward_outputs_cache=prediction_cache)
            prediction_cache = prediction
            prediction_cache_shape = ortx_shape

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss types=%r,%r",
                    orty.data_type(), prediction[0].data_type())

            # loss
            loss, loss_gradient = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0], weight=ortw)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g has_weight=%r",
                    loss.numpy(), ortw is not None)

            n = len(state) - n_weights
            loss = self.learning_penalty.penalty_loss(
                self.device, loss, *state[n:])

            cpu_loss = loss.numpy()

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "cpu_loss=%r", cpu_loss)

            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                if self.exc:
                    raise ConvergenceError(
                        "Loss is nan, learning_rate=%r, "
                        "the gradient descent has failed "
                        "(past losses=%r)." % (
                            self.learning_rate,
                            [float(v) for v in (
                                actual_losses if len(actual_losses) < 5
                                else actual_losses[-5:])]))
                warnings.warn(  # pragma: no cover
                    "Loss is nan, learning_rate=%r, "
                    "the gradient descent has failed "
                    "(past losses=%r)." % (
                        self.learning_rate,
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])]),
                    ConvergenceWarning)
                if numpy.isinf(cpu_loss):  # pragma: no cover
                    cpu_loss = numpy.nan

            # backward
            if not same_shape:
                backward_outputs_cache = None
            gradient = self.train_function_.backward(
                [loss_gradient],
                backward_outputs_cache=backward_outputs_cache)
            backward_outputs_cache = gradient

            if len(gradient) != len(state):
                raise RuntimeError(  # pragma: no cover
                    "gradient and state should have the same length but "
                    "%r != %r." % (len(gradient), len(state)))

            n = len(state) - n_weights

            for i in range(n, len(state)):
                self.learning_penalty.update_weights(
                    i - n, self.device, state[i])
                self.learning_rate.update_weights(
                    i - n, self.device, state[i],
                    gradient[i], bs,
                    None if grad is None else grad[i])

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g n_weights=%d", cpu_loss, n)
                for i in range(n, len(state)):
                    logger.debug(
                        "[OrtGradientForwardBackwardOptimizer._iteration] "
                        "state[%i]=%s", i, str_ortvalue(state[i]))

            actual_losses.append(cpu_loss / bs)

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration end")

        return numpy.array(actual_losses).mean()

    def _evaluation(self, data_loader, state):
        logger = self._logger
        actual_losses = []
        for ib, (ortx, orty) in enumerate(data_loader.iter_ortvalue()):
            state[0] = ortx

            if logger is not None:
                logger.debug(  # pragma: no cover
                    "[OrtGradientForwardBackwardOptimizer._evaluation] "
                    "batch %d", ib)

            prediction = self.train_function_.forward(state, training=False)
            loss, _ = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0])
            cpu_loss = loss.numpy()
            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                if self.exc:  # pragma: no cover
                    raise ConvergenceError(
                        "Loss is nan, "
                        "the evaluation has failed "
                        "(past losses=%r)." %
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])])
                warnings.warn(
                    "Loss is nan, learning_rate=%r, "
                    "the evaluation has failed "
                    "(past losses=%r)." % (
                        self.learning_rate,
                        [float(v) for v in (
                            actual_losses if len(actual_losses) < 5
                            else actual_losses[-5:])]),
                    ConvergenceWarning)
                if numpy.isinf(cpu_loss):
                    cpu_loss = numpy.nan
            actual_losses.append(cpu_loss)

        return numpy.array(actual_losses).sum() / len(data_loader)

    def score(self, X, y, sample_weight=None):
        """
        Returns the score for the whole dataset,
        the negative mean of the losses.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: score
        """
        scores = self.losses(X, y, sample_weight=sample_weight)
        return -scores.sum() / X.shape[0]

    def losses(self, X, y, sample_weight=None):
        """
        Returns the losses associated with every observation.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: scores
        """
        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device)

        state = self.get_full_state()
        scores = numpy.empty((X.shape[0], ), dtype=X.dtype)
        pos = 0
        for ito in data_loader.iter_ortvalue():
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx
            prediction = self.train_function_.forward(state, training=False)
            score = self.learning_loss.loss_scores(
                self.device, orty, prediction[0], ortw)
            np_score = score.numpy()
            # data copy could be avoided by giving a pointer to
            # loss score or if we could create an OrtValue from a
            # pointer
            end = pos + np_score.shape[0]
            if end <= scores.shape[0]:
                scores[pos: end] = np_score.ravel()
            else:
                # the last batch may be longer than the remaining
                # positions, only the tail of the batch fits in
                scores[pos: end] = np_score.ravel()[end - scores.shape[0]:]
            pos += np_score.shape[0]
        return scores

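    # score is the negative mean of the per-observation losses, so a
    # higher score means a better fit, mirroring scikit-learn's
    # convention. A sketch, assuming ``trainer`` was fitted:
    #
    #   per_obs = trainer.losses(X, y)   # shape: (X.shape[0], )
    #   s = trainer.score(X, y)
    #   assert abs(s + per_obs.sum() / X.shape[0]) < 1e-6
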
    def _create_training_session(
            self, model_onnx, weights_to_train, device):
        forback = OrtGradientForwardBackward(
            model_onnx, weights_to_train=weights_to_train,
            debug=False, enable_logging=False,
            providers=device_to_providers(device))
        inst = forback.new_instance()
        return (forback, inst)

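
# The training session built by the internal helper
# _create_training_session is a pair (OrtGradientForwardBackward,
# instance); the instance exposes the forward/backward protocol that
# _iteration relies on. A minimal sketch of that protocol, assuming a
# fitted ``trainer`` and hypothetical OrtValues ``batch_x`` (input) and
# ``loss_grad`` (gradient of the loss w.r.t. the prediction):
#
#   forback, inst = trainer._create_training_session(
#       trainer.model_onnx, trainer.weights_to_train, trainer.device)
#   state = trainer.get_full_state()
#   state[0] = batch_x
#   prediction = inst.forward(state, training=True)
#   gradient = inst.backward([loss_grad])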