Coverage for onnxcustom/training/optimizers_partial.py: 98%

283 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1""" 

2@file 

3@brief Optimizer with :epkg:`onnxruntime-training` forward backward training. 

4""" 

5import logging 

6import warnings 

7import numpy 

8from onnxruntime import InferenceSession, SessionOptions 

9from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 

10 OrtValue as C_OrtValue) 

11from ..utils.onnx_helper import get_onnx_opset, proto_type_to_dtype 

12from ..utils.onnxruntime_helper import ( 

13 device_to_providers, numpy_to_ort_value) 

14from ..utils.onnx_function import function_onnx_graph 

15from ..utils.print_helper import str_ortvalue 

16from ..utils.orttraining_helper import get_train_initializer 

17from .ortgradient import OrtGradientForwardBackward 

18from ._base_estimator import BaseEstimator 

19from .sgd_learning_loss import BaseLearningLoss 

20from .sgd_learning_penalty import BaseLearningPenalty 

21from .data_loader import OrtDataLoader 

22from .excs import ConvergenceError, ConvergenceWarning 

23 

24 

class OrtGradientForwardBackwardOptimizer(BaseEstimator):
    """
    Implements a simple :epkg:`Stochastic Gradient Descent`
    with :epkg:`onnxruntime-training`. It leverages class
    @see cl OrtGradientForwardBackward.

    :param model_onnx: onnx graph to train
    :param weights_to_train: names of initializers to be optimized,
        if None, function @see fn get_train_initializer returns
        the list of float initializers
    :param loss_output_name: name of the loss output
    :param max_iter: number of training iterations
    :param training_optimizer_name: optimizing algorithm, method *fit*
        only supports ``'SGDOptimizer'``
    :param batch_size: batch size (see class *DataLoader*)
    :param learning_rate: a name or a learning rate instance or a float,
        see module :mod:`onnxcustom.training.sgd_learning_rate`
    :param device: device as :epkg:`C_OrtDevice` or a string
        representing this device
    :param warm_start: when set to True, reuse the solution of the previous
        call to fit as initialization, otherwise, just erase the previous
        solution.
    :param learning_loss: loss function (see below)
    :param verbose: use :epkg:`tqdm` to display the training progress,
        a value greater than 1 also displays the loss and the learning rate
    :param validation_every: validation with a test set every
        *validation_every* iterations, a value lower than 1 is
        interpreted as a fraction of *max_iter*
    :param enable_logging: enable logging (mostly for debugging purpose
        as it slows down the training)
    :param weight_name: if not None, the class assumes it is trained
        with training weight
    :param learning_penalty: weight penalty, None, or instance of
        @see cl BaseLearningPenalty
    :param exc: raise exceptions (about convergence for example)
        or keep them silent as much as possible

    *learning_rate* can be any instance of @see cl BaseLearningRate or
    a nickname accepted by :meth:`BaseLearningRate.select
    <onnxcustom.training.sgd_learning_rate.BaseLearningRate.select>`.

    *learning_loss* can be any instance of @see cl BaseLearningLoss or
    a nickname accepted by :meth:`BaseLearningLoss.select
    <onnxcustom.training.sgd_learning_loss.BaseLearningLoss.select>`.

    After *fit*, the full state is a list aligned with *input_names_*:
    ``state[0]`` receives every batch, and the last
    ``len(weights_to_train)`` elements are the trained weights.
    """

    def __init__(self, model_onnx, weights_to_train=None,
                 loss_output_name='loss', max_iter=100,
                 training_optimizer_name='SGDOptimizer',
                 batch_size=10, learning_rate='SGD',
                 device='cpu', warm_start=False, verbose=0,
                 validation_every=0.1, learning_loss="square_error",
                 enable_logging=False, weight_name=None,
                 learning_penalty=None, exc=True):
        if weights_to_train is None:
            weights_to_train = list(get_train_initializer(model_onnx))
        BaseEstimator.__init__(self, model_onnx, learning_rate, device)
        self.batch_size = batch_size
        self.weights_to_train = weights_to_train
        self.loss_output_name = loss_output_name
        self.training_optimizer_name = training_optimizer_name
        self.verbose = verbose
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.learning_loss = BaseLearningLoss.select(learning_loss)
        self.learning_penalty = BaseLearningPenalty.select(learning_penalty)
        self.enable_logging = enable_logging
        self.weight_name = weight_name
        self.exc = exc
        if validation_every < 1:
            # A fraction of max_iter converted into a number of iterations,
            # it may be 0 for small values of max_iter (handled in fit).
            self.validation_every = int(self.max_iter * validation_every)
        else:
            self.validation_every = validation_every  # pragma: no cover
        self.build_onnx_function()

    @property
    def needs_grad(self):
        """
        Returns True if the gradient update needs to retain
        past gradients.
        """
        return self.learning_rate.needs_grad

    def __getstate__(self):
        "Removes any non picklable attribute."
        state = BaseEstimator.__getstate__(self)
        for kind in ('weight', 'grad'):
            att = self._get_att_state(kind)
            if not hasattr(self, att):
                continue
            # OrtValue cannot be pickled, the trained part of the state
            # of this kind is stored as numpy arrays under the attribute
            # name without its trailing underscore.
            state[att[:-1]] = [
                None if v is None else v.numpy()
                for v in self.get_state(kind=kind)]
        return state

    def __setstate__(self, state):
        "Restores any non picklable attribute."
        kinds = {'train_state': 'weight', 'train_grad_state': 'grad'}
        popped = {att: state.pop(att) for att in kinds if att in state}
        BaseEstimator.__setstate__(self, state)
        for k, v in popped.items():
            self.set_state(v, check_trained=False, kind=kinds[k])
        self.build_onnx_function()
        return self

    def _get_att_state(self, kind):
        """
        Returns the attribute name storing the state *kind*
        (``'weight'`` or ``'grad'``).
        """
        if kind == 'weight':
            return 'train_state_'
        if kind == 'grad':
            return 'train_grad_state_'
        raise ValueError(  # pragma: no cover
            f"Unexpected kind={kind!r}.")

    def _is_state_reusable(self, kind):
        """
        Tells if the state *kind* left by a previous call to *fit*
        (or restored by *__setstate__*) can be reused by the current
        training session, meaning it exists and has one element per
        input of the training session (*input_names_*).
        """
        previous = getattr(self, self._get_att_state(kind), None)
        return (previous is not None and
                len(previous) == len(self.input_names_))

    def get_full_state(self, kind='weight'):
        """
        Returns the trained weights and the inputs.

        :param kind: ``'weight'``, ``'grad'`` or a list of them,
            a list of states is returned in the last case
        :raises AttributeError: if the state was never set
        """
        if isinstance(kind, list):
            return [self.get_full_state(kind=k) for k in kind]
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        return getattr(self, att)

    def get_state(self, kind='weight'):
        """
        Returns the trained weights, the last ``len(weights_to_train)``
        elements of the full state.

        :param kind: ``'weight'`` or ``'grad'``
        :raises AttributeError: if the state was never set
        """
        att = self._get_att_state(kind)
        if not hasattr(self, att):
            raise AttributeError("Method fit must be called before.")
        if getattr(self, att, None) is None:
            raise RuntimeError(  # pragma: no cover
                f"No attribute {att!r} available (None).")
        if self.weights_to_train is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected self.weights_to_train (None).")
        value = getattr(self, att)
        n = len(value) - len(self.weights_to_train)
        return value[n:]

    @property
    def trained_coef_(self):
        """
        Returns the trained coefficients as a dictionary
        ``{weight name: value}``.
        """
        return dict(zip(self.weights_to_train, self.get_state()))

    def get_trained_onnx(self, model=None):
        """
        Returns the trained onnx graph, the initial graph
        modified by replacing the initializers with the trained
        weights.

        :param model: replace the weights in another graph
            than the training graph
        :return: onnx graph
        """
        state = dict(zip(self.weights_to_train, self.get_state()))
        return self._get_trained_onnx(state, model=model)

    def set_state(self, state, check_trained=True, kind='weight', zero=False):
        """
        Changes the trained weights.

        :param state: list of None, numpy arrays or OrtValue
        :param check_trained: raises an exception if *fit* was not called
        :param kind: ``'weight'`` or ``'grad'``
        :param zero: replaces every value by zeros of the same shape
        """
        if check_trained and not hasattr(self, 'train_session_'):
            raise AttributeError(  # pragma: no cover
                "Method fit must be called before.")
        state_ = []
        state_numpy_ = []
        for i, v in enumerate(state):
            if v is None:
                state_.append(None)
                state_numpy_.append(None)
            elif isinstance(v, numpy.ndarray):
                if zero:
                    v = numpy.zeros(v.shape, dtype=v.dtype)
                ortvalue = numpy_to_ort_value(v, self.device)
                state_.append(ortvalue)
                # The numpy container must be retained as the ortvalue
                # just borrows the pointer.
                state_numpy_.append(v)
            elif isinstance(v, C_OrtValue):
                if zero:
                    # NOTE(review): onnxruntime documents
                    # run_with_ort_values as returning the list of the
                    # requested outputs, this branch stores that list and
                    # not a single OrtValue - confirm whether [0] is needed.
                    v = self.zero_sess_.run_with_ort_values(['Y'], {'X': v})
                state_.append(v)
                state_numpy_.append(None)
            else:
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(v)!r} for state {i!r}.")
        att = self._get_att_state(kind)
        setattr(self, att, state_)
        setattr(self, att + "numpy_", state_numpy_)

    def build_onnx_function(self):
        """
        Creates ONNX graph and *InferenceSession* related to
        any operations applying on *OrtValue*.
        """
        opset = get_onnx_opset(self.model_onnx)
        so = SessionOptions()
        so.log_severity_level = 4

        n = len(self.weights_to_train)

        # loss_grad
        self.learning_loss.build_onnx_function(
            opset, self.device, self.weight_name)

        # weight update
        self.learning_rate.build_onnx_function(opset, self.device, n)

        # regularization
        self.learning_penalty.build_onnx_function(opset, self.device, n)

        # zero
        self.zero_onnx_ = function_onnx_graph("zero")
        self.zero_sess_ = InferenceSession(
            self.zero_onnx_.SerializeToString(), so,
            providers=device_to_providers(self.device))

        # logging
        if self.enable_logging:
            self._logger = logging.getLogger("onnxcustom")
        else:
            self._logger = None

    def fit(self, X, y, sample_weight=None,
            X_val=None, y_val=None):
        """
        Trains the model.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :param X_val: evaluation dataset
        :param y_val: evaluation dataset
        :return: self
        :raises NotImplementedError: if *training_optimizer_name* is not
            ``'SGDOptimizer'``
        :raises ConvergenceError: if the loss becomes nan or infinite
            and *exc* is True
        """
        if self.training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                f"Only the SGDOptimizer is implemented not "
                f"{self.training_optimizer_name!r}.")
        logger = self._logger

        session_function = self._create_training_session(
            self.model_onnx, self.weights_to_train,
            device=self.device)
        self.train_session_ = session_function[0]
        self.train_function_ = session_function[1]

        self.input_names_ = self.train_session_.cls_type_._grad_input_names
        self.output_names_ = self.train_session_.cls_type_._bw_fetches_names
        weights_to_train = self.train_session_.weights_to_train

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "input_names=%r", self.input_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "output_names=%r", self.output_names_)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "weights_to_train=%r", self.weights_to_train)
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "device=%r|%r",
                self.device.device_id(), self.device.device_type())
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "warm_start=%r", self.warm_start)

        # set_state stores the states in train_state_ / train_grad_state_,
        # those are kept from a previous call to fit (warm start) only if
        # they are compatible with the new training session, otherwise
        # the states are initialized from the model initializers.
        if not self._is_state_reusable('weight'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_])
        if self.needs_grad and not self._is_state_reusable('grad'):
            self.set_state([
                self.train_session_.get_initializer(name, exc=False)
                for name in self.input_names_],
                kind='grad', zero=True)
        if not self.warm_start:
            state = self.get_full_state()
            if len(state) != len(self.input_names_):
                raise RuntimeError(  # pragma: no cover
                    f"Length mismatch {len(state)!r} != {len(self.input_names_)!r}.")
            new_state = []
            for iv, v in enumerate(state):
                if v is None:
                    new_state.append(v)
                    continue
                if not isinstance(v, C_OrtValue):
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected type %r (state[%d])." % (
                            type(v), iv))
                dtype = proto_type_to_dtype(
                    v.proto_type()
                    if hasattr(v, 'proto_type')
                    else v.data_type())
                # random initialization with the same shape,
                # a scalar becomes an array of one element
                shape = v.shape()
                if len(shape) > 0:
                    new_state.append(
                        numpy.random.randn(*shape).astype(dtype))
                else:
                    new_state.append(
                        numpy.random.randn(1).astype(dtype))
            self.set_state(new_state)
            if self.needs_grad:
                self.set_state(new_state, kind='grad', zero=True)

        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device)
        if X_val is not None:
            data_loader_val = OrtDataLoader(
                X_val, y_val, batch_size=X_val.shape[0], device=self.device,
                random_iter=False)
        else:
            data_loader_val = None

        self.learning_rate.init_learning_rate()

        if self.verbose > 0:  # pragma: no cover
            from tqdm import tqdm  # pylint: disable=C0415
            loop = tqdm(range(self.max_iter))
        else:
            loop = range(self.max_iter)

        # self.validation_every is 0 when max_iter * validation_every < 1
        # (see __init__), the modulo below requires a positive value.
        validation_every = max(1, self.validation_every)

        self.train_losses_ = []
        val_losses = []
        kinds = ['weight', 'grad'] if self.needs_grad else ['weight']
        for it in loop:
            loss = self._iteration(
                data_loader, self.get_full_state(kind=kinds),
                len(weights_to_train))
            lr = self.learning_rate.update_learning_rate(it).value
            if self.verbose > 1:  # pragma: no cover
                loop.set_description(
                    "loss=%1.3g lr=%1.3g" % (  # pylint: disable=E1101,E1307
                        loss, lr))  # pylint: disable=E1101,E1307
            if logger is not None:
                logger.info(
                    "[OrtGradientForwardBackwardOptimizer.fit] "
                    "lr value=%r", lr)

            self.train_losses_.append(loss)
            if (data_loader_val is not None and
                    (it + 1) % validation_every == 0):
                val_losses.append(
                    self._evaluation(data_loader_val, self.get_full_state()))
        self.validation_losses_ = (
            None if data_loader_val is None else val_losses)

        if logger is not None:
            logger.info(
                "[OrtGradientForwardBackwardOptimizer.fit] "
                "end loss=%r", self.train_losses_[-1])
        return self

    @staticmethod
    def _last_losses(losses):
        "Returns the last five losses as floats (used in error messages)."
        return [float(v) for v in losses[-5:]]

    def _iteration(self, data_loader, states, n_weights):
        """
        Goes once through all the batches of *data_loader* and updates
        the weights after every batch.

        :param data_loader: training data loader
        :param states: list of one or two full states, the weights and,
            if the learning rate needs it, the past gradients
        :param n_weights: number of trained weights, they are the last
            elements of every state
        :return: mean of the batch losses, each divided by the batch size
        :raises ConvergenceError: if a loss is nan or infinite
            and *exc* is True
        """
        actual_losses = []
        bs = data_loader.batch_size
        logger = self._logger
        if len(states) == 1:
            state = states[0]
            grad = None
        else:
            state, grad = states

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration begin learning_rate=%r",
                self.learning_rate)

        # first index of the trained weights in the state, the length of
        # the state does not change in the loop (only state[0] is replaced)
        n = len(state) - n_weights

        prediction_cache = None
        prediction_cache_shape = None
        backward_outputs_cache = None
        for ib, ito in enumerate(data_loader.iter_ortvalue()):
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "batch %d", ib)

            ortx_shape = tuple(ortx.shape())
            same_shape = (
                prediction_cache_shape is not None and
                ortx_shape == prediction_cache_shape)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] forward")

            # forward, the outputs of the previous batch are reused as
            # buffers only when the batch shape is unchanged
            # (same rule as the backward cache below)
            if not same_shape:
                prediction_cache = None
            prediction = self.train_function_.forward(
                state, training=True,
                forward_outputs_cache=prediction_cache)
            prediction_cache = prediction
            prediction_cache_shape = ortx_shape

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss types=%r,%r",
                    orty.data_type(), prediction[0].data_type())

            # loss
            loss, loss_gradient = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0], weight=ortw)

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g has_weight=%r",
                    loss.numpy(), ortw is not None)

            loss = self.learning_penalty.penalty_loss(
                self.device, loss, *state[n:])

            cpu_loss = loss.numpy()

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "cpu_loss=%r", cpu_loss)

            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                msg = (
                    "Loss is nan, learning_rate=%r, "
                    "the gradient descent has failed "
                    "(past losses=%r)." % (
                        self.learning_rate,
                        self._last_losses(actual_losses)))
                if self.exc:
                    raise ConvergenceError(msg)
                warnings.warn(msg, ConvergenceWarning)  # pragma: no cover
                if numpy.isinf(cpu_loss):  # pragma: no cover
                    cpu_loss = numpy.nan

            # backward
            if not same_shape:
                backward_outputs_cache = None
            gradient = self.train_function_.backward(
                [loss_gradient], backward_outputs_cache=backward_outputs_cache)
            backward_outputs_cache = gradient

            if len(gradient) != len(state):
                raise RuntimeError(  # pragma: no cover
                    "gradient and state should have the same length but "
                    "%r != %r." % (len(gradient), len(state)))

            # the penalty and the learning rate update the weights in place
            for i in range(n, len(state)):
                self.learning_penalty.update_weights(
                    i - n, self.device, state[i])
                self.learning_rate.update_weights(
                    i - n, self.device, state[i],
                    gradient[i], bs,
                    None if grad is None else grad[i])

            if logger is not None:
                logger.debug(
                    "[OrtGradientForwardBackwardOptimizer._iteration] "
                    "loss=%g n_weights=%d", cpu_loss, n)
                for i in range(n, len(state)):
                    logger.debug(
                        "[OrtGradientForwardBackwardOptimizer._iteration] "
                        "state[%i]=%s", i, str_ortvalue(state[i]))

            actual_losses.append(cpu_loss / bs)

        if logger is not None:
            logger.debug(
                "[OrtGradientForwardBackwardOptimizer._iteration] "
                "iteration end")

        return numpy.array(actual_losses).mean()

    def _evaluation(self, data_loader, state):
        """
        Computes the loss on a validation dataset with the current weights.

        :param data_loader: validation data loader (no sample weight)
        :param state: full weight state, ``state[0]`` is replaced
            by every batch
        :return: sum of the batch losses divided by ``len(data_loader)``
        :raises ConvergenceError: if a loss is nan or infinite
            and *exc* is True
        """
        logger = self._logger
        actual_losses = []
        for ib, (ortx, orty) in enumerate(data_loader.iter_ortvalue()):
            state[0] = ortx

            if logger is not None:
                logger.debug(  # pragma: no cover
                    "[OrtGradientForwardBackwardOptimizer._evaluation] "
                    "batch %d", ib)

            prediction = self.train_function_.forward(state, training=False)
            loss, _ = self.learning_loss.loss_gradient(
                self.device, orty, prediction[0])
            cpu_loss = loss.numpy()
            if numpy.isinf(cpu_loss) or numpy.isnan(cpu_loss):
                last = self._last_losses(actual_losses)
                if self.exc:  # pragma: no cover
                    raise ConvergenceError(
                        "Loss is nan, "
                        "the evaluation has failed "
                        "(past losses=%r)." % last)
                warnings.warn(
                    "Loss is nan, learning_rate=%r, "
                    "the gradient descent has failed "
                    "(past losses=%r)." % (self.learning_rate, last),
                    ConvergenceWarning)
                if numpy.isinf(cpu_loss):
                    cpu_loss = numpy.nan
            actual_losses.append(cpu_loss)

        return numpy.array(actual_losses).sum() / len(data_loader)

    def score(self, X, y, sample_weight=None):
        """
        Return the whole score associated.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: score, the opposite of the mean of the losses
            returned by method *losses*
        """
        scores = self.losses(X, y, sample_weight=sample_weight)
        return -scores.sum() / X.shape[0]

    def losses(self, X, y, sample_weight=None):
        """
        Returns the losses associated to every observation.

        :param X: features
        :param y: expected output
        :param sample_weight: training weight or None
        :return: scores, one per row of *X*
        """
        # Batches must follow the order of X (random_iter=False, as for
        # the validation loader in fit) because every score is written
        # at the position of its observation.
        data_loader = OrtDataLoader(
            X, y, sample_weight, batch_size=self.batch_size,
            device=self.device, random_iter=False)

        state = self.get_full_state()
        scores = numpy.empty((X.shape[0], ), dtype=X.dtype)
        pos = 0
        for ito in data_loader.iter_ortvalue():
            if len(ito) == 2:
                (ortx, orty) = ito
                ortw = None
            else:
                (ortx, orty, ortw) = ito
            state[0] = ortx
            prediction = self.train_function_.forward(state, training=False)
            score = self.learning_loss.loss_scores(
                self.device, orty, prediction[0], ortw)
            np_score = score.numpy()
            # data copy could be avoided by giving a pointer to
            # loss score or if we could create an OrtValue from a
            # pointer.
            end = pos + np_score.shape[0]
            if end <= scores.shape[0]:
                scores[pos: end] = np_score.ravel()
            else:
                # assumes the last batch is aligned on the end of X and
                # overlaps the previous one, only its tail is new
                scores[pos: end] = np_score.ravel()[end - scores.shape[0]:]
            pos += np_score.shape[0]
        return scores

    def _create_training_session(
            self, model_onnx, weights_to_train, device):
        """
        Creates the forward/backward training class and one instance of it.

        :return: tuple *(OrtGradientForwardBackward, instance)*
        """
        forback = OrtGradientForwardBackward(
            model_onnx, weights_to_train=weights_to_train,
            debug=False, enable_logging=False,
            providers=device_to_providers(device))
        inst = forback.new_instance()
        return (forback, inst)