module training.optimizers_partial

Inheritance diagram of onnxcustom.training.optimizers_partial

Short summary

module onnxcustom.training.optimizers_partial

Optimizer with onnxruntime-training forward backward training.

Classes

  • OrtGradientForwardBackwardOptimizer – Implements a simple Stochastic Gradient Descent with onnxruntime-training. It leverages class OrtGradientForwardBackward. …

Properties

  • needs_grad – Returns True if the gradient update needs to retain past gradients.

  • trained_coef_ – Returns the trained coefficients as a dictionary.

Methods

  • __getstate__ – Removes any non-picklable attribute.

  • __init__

  • __setstate__ – Restores any non-picklable attribute.

  • _create_training_session

  • _evaluation

  • _get_att_state

  • _iteration

  • build_onnx_function – Creates the ONNX graphs and InferenceSession objects for the operations applied to OrtValue instances.

  • fit – Trains the model.

  • get_full_state – Returns the trained weights and the inputs.

  • get_state – Returns the trained weights.

  • get_trained_onnx – Returns the trained onnx graph, the initial graph modified by replacing the initializers with the trained …

  • losses – Returns the losses associated with every observation.

  • score – Returns a single score computed over the whole dataset.

  • set_state – Changes the trained weights.

Documentation

Optimizer with onnxruntime-training forward backward training.

class onnxcustom.training.optimizers_partial.OrtGradientForwardBackwardOptimizer(model_onnx, weights_to_train=None, loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGD', device='cpu', warm_start=False, verbose=0, validation_every=0.1, learning_loss='square_error', enable_logging=False, weight_name=None, learning_penalty=None, exc=True)

Bases: BaseEstimator

Implements a simple Stochastic Gradient Descent with onnxruntime-training. It leverages class OrtGradientForwardBackward.

Parameters:
  • model_onnx – onnx graph to train

  • weights_to_train – names of initializers to be optimized; if None, function get_train_initialize() returns the list of float initializers

  • loss_output_name – name of the loss output

  • max_iter – number of training iterations

  • training_optimizer_name – optimizing algorithm

  • batch_size – batch size (see class DataLoader)

  • learning_rate – a name or a learning rate instance or a float, see module onnxcustom.training.sgd_learning_rate

  • device – device as C_OrtDevice or a string representing this device

  • warm_start – when set to True, reuse the solution of the previous call to fit as initialization, otherwise, just erase the previous solution.

  • learning_loss – loss function (see below)

  • verbose – use tqdm to display the training progress

  • validation_every – validation with a test set every validation_every iterations

  • enable_logging – enable logging (mostly for debugging purposes, as it slows down the training)

  • weight_name – if not None, the class assumes it is trained with training weights

  • learning_penalty – weight penalty, None, or instance of BaseLearningPenalty

  • exc – raise exceptions (about convergence for example) or keep them silent as much as possible

learning_rate can be any instance of BaseLearningRate or a nickname accepted by BaseLearningRate.select.

learning_loss can be any instance of BaseLearningLoss or a nickname accepted by BaseLearningLoss.select.
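The class runs the forward and backward passes through onnxruntime-training, but the arithmetic behind the default settings (square error loss, mini-batches of batch_size rows, max_iter iterations) can be sketched in plain Python. This is a hypothetical illustration, not the library's implementation: it assumes a linear model, a constant float learning rate, one iteration meaning one pass over the data, and the made-up function name sgd_square_error.

```python
import random


def sgd_square_error(X, y, max_iter=100, batch_size=10, learning_rate=0.01, seed=0):
    """Hypothetical sketch: mini-batch SGD on a linear model y ~ w.x + b
    with the square error loss (not the onnxcustom implementation)."""
    rng = random.Random(seed)
    n_features = len(X[0])
    w = [0.0] * n_features
    b = 0.0
    indices = list(range(len(X)))
    for _ in range(max_iter):
        # One iteration is assumed to be one shuffled pass over the data.
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            grad_w = [0.0] * n_features
            grad_b = 0.0
            for i in batch:
                pred = sum(wj * xj for wj, xj in zip(w, X[i])) + b
                err = pred - y[i]
                # Gradient of (pred - y)^2 is 2 * err * x for w and 2 * err for b.
                for j in range(n_features):
                    grad_w[j] += 2.0 * err * X[i][j]
                grad_b += 2.0 * err
            # Average the gradient over the batch and take one step.
            scale = learning_rate / len(batch)
            w = [wj - scale * gj for wj, gj in zip(w, grad_w)]
            b -= scale * grad_b
    return w, b
```

On noiseless data generated by y = 3x + 1 with x in [0, 1), for example `sgd_square_error(X, y, max_iter=1000, batch_size=10, learning_rate=0.3)`, the returned weights approach w = [3.0] and b = 1.0.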

__getstate__()

Removes any non-picklable attribute.

__init__(model_onnx, weights_to_train=None, loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGD', device='cpu', warm_start=False, verbose=0, validation_every=0.1, learning_loss='square_error', enable_logging=False, weight_name=None, learning_penalty=None, exc=True)

__setstate__(state)

Restores any non-picklable attribute.
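The pair __getstate__ / __setstate__ follows the standard pickle protocol. A minimal sketch of the pattern, with a hypothetical class Trainer in which a threading.Lock stands in for the onnxruntime session objects that cannot be pickled; how the real class rebuilds its sessions is not described on this page.

```python
import pickle
import threading


class Trainer:
    """Hypothetical class holding one non-picklable attribute."""

    def __init__(self, coef):
        self.coef = coef
        # Stand-in for a session object: pickle raises TypeError on locks.
        self._session = threading.Lock()

    def __getstate__(self):
        # Copy the instance dictionary and drop what pickle cannot serialize.
        state = self.__dict__.copy()
        del state["_session"]
        return state

    def __setstate__(self, state):
        # Restore the picklable attributes, then rebuild the dropped one.
        self.__dict__.update(state)
        self._session = threading.Lock()
```

With both methods defined, `pickle.loads(pickle.dumps(Trainer([1.0, 2.0])))` returns a copy whose coef is preserved and whose _session is a fresh lock.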

_create_training_session(model_onnx, weights_to_train, device)

_evaluation(data_loader, state)

_get_att_state(kind)

_iteration(data_loader, states, n_weights)

build_onnx_function()

Creates the ONNX graphs and InferenceSession objects for the operations applied to OrtValue instances.

fit(X, y, sample_weight=None, X_val=None, y_val=None)

Trains the model.

Parameters:
  • X – features

  • y – expected output

  • sample_weight – training weight or None

  • X_val – features of the evaluation dataset

  • y_val – expected output of the evaluation dataset

Returns:

self
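Returning self follows the scikit-learn convention of BaseEstimator, and the warm_start parameter decides whether fit starts from the previous solution. A minimal sketch of both behaviours with a hypothetical MeanEstimator that fits a single constant by gradient descent on the square error; it is simplified to take only y and is not the onnxcustom code.

```python
class MeanEstimator:
    """Hypothetical estimator: fits a constant c minimizing mean((c - y_i)^2)
    with gradient descent, in the scikit-learn style."""

    def __init__(self, max_iter=1, learning_rate=0.25, warm_start=False):
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.warm_start = warm_start

    def fit(self, y):
        # warm_start=True reuses the previous solution; otherwise erase it.
        if not self.warm_start or not hasattr(self, "coef_"):
            self.coef_ = 0.0
        for _ in range(self.max_iter):
            # Gradient of mean((c - y_i)^2) with respect to c.
            grad = sum(2.0 * (self.coef_ - t) for t in y) / len(y)
            self.coef_ -= self.learning_rate * grad
        # Returning self allows chaining such as est.fit(y).coef_.
        return self
```

With y = [2.0, 6.0] and one iteration per call, two calls to fit give 2.0 then 2.0 when warm_start is False, and 2.0 then 3.0 when it is True.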

get_full_state(kind='weight')

Returns the trained weights and the inputs.

get_state(kind='weight')

Returns the trained weights.

get_trained_onnx(model=None)

Returns the trained onnx graph, the initial graph modified by replacing the initializers with the trained weights.

Parameters:

model – if not None, the weights are replaced in this graph instead of the training graph

Returns:

onnx graph

losses(X, y, sample_weight=None)

Returns the losses associated with every observation.

Parameters:
  • X – features

  • y – expected output

  • sample_weight – training weight or None

Returns:

scores

property needs_grad

Returns True if the gradient update needs to retain past gradients.
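needs_grad separates update rules that use only the current gradient from rules that also depend on earlier ones. SGD with momentum is a classic example of the latter; this page does not say which learning-rate classes make needs_grad return True, so the sketch below (hypothetical functions sgd_step and momentum_step, one common momentum formulation) only illustrates why past gradients would have to be kept.

```python
def sgd_step(w, grad, lr):
    """Plain SGD: the update only uses the current gradient."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def momentum_step(w, grad, velocity, lr, momentum=0.9):
    """SGD with momentum: velocity accumulates past gradients, so it must be
    stored between iterations and passed back in at the next step."""
    velocity = [momentum * vi + gi for vi, gi in zip(velocity, grad)]
    w = [wi - lr * vi for wi, vi in zip(w, velocity)]
    return w, velocity
```

Calling momentum_step twice with the same gradient moves the weight further on the second call than on the first, because the stored velocity carries the previous gradient.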

score(X, y, sample_weight=None)

Returns a single score computed over the whole dataset.

Parameters:
  • X – features

  • y – expected output

  • sample_weight – training weight or None

Returns:

score
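losses returns one value per observation while score reduces them to one number. A minimal sketch of that distinction, assuming the square_error loss and a weighted mean as the reduction; this page does not say how score aggregates (sum or mean, possible sign change), so both function names and the reduction are hypothetical.

```python
def square_error_losses(y_true, y_pred, sample_weight=None):
    """Per-observation square errors, optionally multiplied by a weight."""
    losses = [(p - t) ** 2 for t, p in zip(y_true, y_pred)]
    if sample_weight is not None:
        losses = [sw * loss for sw, loss in zip(sample_weight, losses)]
    return losses


def aggregated_score(y_true, y_pred, sample_weight=None):
    """One number for the whole dataset: here the (weighted) mean loss."""
    losses = square_error_losses(y_true, y_pred, sample_weight)
    total_weight = len(losses) if sample_weight is None else sum(sample_weight)
    return sum(losses) / total_weight
```

For y_true = [1, 2, 3] and y_pred = [1, 2, 5], the losses are [0, 0, 4] and the unweighted mean is 4/3; with weights [1, 1, 2] the weighted losses become [0, 0, 8] and the weighted mean is 2.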

set_state(state, check_trained=True, kind='weight', zero=False)

Changes the trained weights.

property trained_coef_

Returns the trained coefficients as a dictionary.