module training.optimizers
Short summary
module onnxcustom.training.optimizers
Optimizer with onnxruntime-training.
Classes
class | truncated documentation |
---|---|
OrtGradientOptimizer | Implements a simple Stochastic Gradient Descent with onnxruntime-training. |
Methods
method | truncated documentation |
---|---|
_bind_input_ortvalue | Binds C_OrtValue to the structure used by InferenceSession to run inference. |
_create_training_session | Creates an instance of TrainingSession. |
fit | Trains the model. |
get_state | Returns the trained weights. |
get_trained_onnx | Returns the trained onnx graph, the initial graph modified by replacing the initializers with the trained … |
set_state | Changes the trained weights. |
Documentation
Optimizer with onnxruntime-training.
- class onnxcustom.training.optimizers.OrtGradientOptimizer(model_onnx, weights_to_train, loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGD', device='cpu', warm_start=False, verbose=0, validation_every=0.1, saved_gradient=None, sample_weight_name='weight')
Bases: BaseEstimator
Implements a simple Stochastic Gradient Descent with onnxruntime-training.
- Parameters:
model_onnx – onnx graph to train
weights_to_train – names of initializers to be optimized
loss_output_name – name of the loss output
max_iter – number of training iterations
training_optimizer_name – optimizing algorithm
batch_size – batch size (see class DataLoader)
learning_rate – a name, a learning rate instance, or a float; see module onnxcustom.training.sgd_learning_rate
device – device as C_OrtDevice or a string representing this device
warm_start – when set to True, reuses the solution of the previous call to fit as initialization; otherwise, the previous solution is erased.
verbose – use tqdm to display the training progress
validation_every – validation with a test set every validation_every iterations
saved_gradient – if not None, a filename into which the optimizer saves the gradient
sample_weight_name – name of the sample weight input
Once initialized, the class creates the attribute train_session_, which holds an instance of the Python wrapper for TrainingSession.
See the example "Train a scikit-learn neural network with onnxruntime-training on GPU".
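The parameters max_iter, batch_size, learning_rate and sample_weight describe a standard mini-batch SGD loop. The following is a minimal pure-Python sketch of such a loop, not this class's implementation: it fits a linear model by computing the gradient of a weighted squared loss by hand, whereas the class relies on onnxruntime-training to compute gradients from the ONNX graph. It assumes, as in scikit-learn's neural network estimators, that one iteration is one pass over the training data and that learning_rate is a constant float. The helpers iterate_batches and sgd_linear_regression are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple

# Hypothetical helpers illustrating mini-batch SGD; not part of onnxcustom.
Batch = Tuple[List[Sequence[float]], List[float], List[float]]


def iterate_batches(X: Sequence[Sequence[float]], y: Sequence[float],
                    sample_weight: Optional[Sequence[float]] = None,
                    batch_size: int = 10,
                    rng: Optional[random.Random] = None) -> Iterator[Batch]:
    """Yields shuffled mini-batches (X_b, y_b, w_b); missing weights default to 1."""
    rng = rng if rng is not None else random.Random(0)
    weights = list(sample_weight) if sample_weight is not None else [1.0] * len(X)
    order = list(range(len(X)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [X[i] for i in idx], [y[i] for i in idx], [weights[i] for i in idx]


def sgd_linear_regression(X, y, sample_weight=None, max_iter=100,
                          batch_size=10, learning_rate=0.01, seed=0):
    """Minimizes the weighted mean squared error of X @ coef + intercept."""
    rng = random.Random(seed)
    n_features = len(X[0])
    coef = [0.0] * n_features
    intercept = 0.0
    for _ in range(max_iter):  # assumption: one iteration = one pass over the data
        for Xb, yb, wb in iterate_batches(X, y, sample_weight, batch_size, rng):
            total_w = sum(wb)
            if total_w == 0:  # a batch of zero-weight samples carries no gradient
                continue
            grad_coef = [0.0] * n_features
            grad_intercept = 0.0
            for xi, yi, wi in zip(Xb, yb, wb):
                err = sum(c * v for c, v in zip(coef, xi)) + intercept - yi
                for j, v in enumerate(xi):
                    grad_coef[j] += 2.0 * wi * err * v / total_w
                grad_intercept += 2.0 * wi * err / total_w
            # plain SGD update: weight <- weight - learning_rate * gradient
            coef = [c - learning_rate * g for c, g in zip(coef, grad_coef)]
            intercept -= learning_rate * grad_intercept
    return coef, intercept


X = [[i / 20] for i in range(20)]
y = [3.0 * x[0] + 1.0 for x in X]
coef, intercept = sgd_linear_regression(X, y, max_iter=500, batch_size=5, learning_rate=0.1)
print(coef, intercept)  # both close to the true values [3.0] and 1.0
```

A sample with weight 0 contributes nothing to the gradient, which is how sample weights let some rows be ignored or down-weighted without removing them from the dataset.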
- __init__(model_onnx, weights_to_train, loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate='SGD', device='cpu', warm_start=False, verbose=0, validation_every=0.1, saved_gradient=None, sample_weight_name='weight')
- _bind_input_ortvalue(name, bind, c_ortvalue)
Binds C_OrtValue to the structure used by InferenceSession to run inference.
- Parameters:
name – input name (str)
bind – Python binding structure the input is bound to
c_ortvalue – C structure for OrtValue (C_OrtValue); it can also be a numpy array
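Because c_ortvalue may be either an OrtValue or a numpy array, binding requires a dispatch on the value's type. The following is a minimal sketch of such a dispatch, not the method's actual code. The method names bind_ortvalue_input and bind_cpu_input mirror those of onnxruntime's Python IOBinding class, but StubOrtValue and RecordingBinding are hypothetical stand-ins so that the example runs without onnxruntime. Detecting arrays through the __array_interface__ attribute, which numpy arrays provide, is an assumption of this sketch.

```python
from typing import Any, List, Tuple


class StubOrtValue:
    """Hypothetical stand-in for C_OrtValue."""


class RecordingBinding:
    """Hypothetical stand-in for a binding object: it only records which call was made."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str]] = []

    def bind_ortvalue_input(self, name: str, ortvalue: Any) -> None:
        self.calls.append(("ortvalue", name))

    def bind_cpu_input(self, name: str, array: Any) -> None:
        self.calls.append(("cpu", name))


def bind_input(name: str, bind: RecordingBinding, value: Any) -> None:
    """Binds an OrtValue as is; binds anything array-like as a CPU buffer."""
    if isinstance(value, StubOrtValue):
        bind.bind_ortvalue_input(name, value)  # the OrtValue is passed through unchanged
    elif hasattr(value, "__array_interface__"):  # numpy arrays expose this attribute
        bind.bind_cpu_input(name, value)
    else:
        raise TypeError(f"Unexpected type {type(value)!r} for input {name!r}.")


class StubArray:
    __array_interface__ = {}  # mimics the attribute numpy arrays expose


bind = RecordingBinding()
bind_input("X", bind, StubOrtValue())
bind_input("y", bind, StubArray())
print(bind.calls)  # [('ortvalue', 'X'), ('cpu', 'y')]
```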
- _create_training_session(training_onnx, weights_to_train, loss_output_name='loss', training_optimizer_name='SGDOptimizer', device='cpu')
Creates an instance of TrainingSession.
- Parameters:
training_onnx – an ONNX graph with a loss function
weights_to_train – list of initializer names to optimize
loss_output_name – output name for the loss
training_optimizer_name – optimizer name
device – one C_OrtDevice or a string
- Returns:
an instance of TrainingSession
- _evaluation(data_loader, bind)
- _iteration(data_loader, ort_lr, bind, use_numpy, sample_weight)
- fit(X, y, sample_weight=None, X_val=None, y_val=None, use_numpy=False)
Trains the model.
- Parameters:
X – features
y – expected output
sample_weight – sample weight if any
X_val – features of the evaluation dataset
y_val – expected output of the evaluation dataset
use_numpy – if True, uses a slower iterator based on numpy; otherwise, uses an iterator that minimizes copies
- Returns:
self
- get_state()
Returns the trained weights.
- get_trained_onnx(model=None)
Returns the trained onnx graph, that is, the initial graph in which the initializers are replaced with the trained weights. If model is not specified, the method uses the model given to the constructor. This graph outputs the loss, not the predictions. Parameter model can be set to the graph as it was before the loss was added; the returned graph then produces the predictions.
- Parameters:
model – replaces the weights in a graph other than the training graph
- Returns:
onnx graph
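Replacing initializers by name is the central step of this method. The following is a minimal sketch of that step. It uses a simplified, hypothetical Initializer record instead of onnx's TensorProto and assumes the trained weights arrive as a mapping from initializer name to values. With the onnx package, each replacement tensor could be built with onnx.numpy_helper.from_array. Raising an error on unknown names is a choice made for this sketch.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Sequence


@dataclass(frozen=True)
class Initializer:
    """Hypothetical, simplified stand-in for an ONNX initializer."""
    name: str
    values: tuple


def replace_initializers(initializers: List[Initializer],
                         state: Dict[str, Sequence[float]]) -> List[Initializer]:
    """Returns a new list; initializers named in state receive the trained values."""
    known = {init.name for init in initializers}
    missing = sorted(set(state) - known)
    if missing:  # this sketch refuses weights that do not exist in the target graph
        raise KeyError(f"weights not found in the graph: {missing}")
    return [replace(init, values=tuple(state[init.name])) if init.name in state else init
            for init in initializers]


graph_inits = [Initializer("coefs", (0.0, 0.0)), Initializer("intercept", (0.0,)),
               Initializer("shape", (-1, 1))]
trained = replace_initializers(graph_inits, {"coefs": [3.0, -1.0], "intercept": [0.5]})
print(trained[0].values, trained[2].values)  # (3.0, -1.0) (-1, 1)
```

Initializers that are not trained (here the hypothetical "shape") are kept untouched. The same function applies to a graph without the loss, provided that graph uses the same initializer names, which is the situation the model parameter covers.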
- set_state(state)
Changes the trained weights.
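Together, get_state, set_state and the warm_start parameter make it possible to save weights and resume training from them. The following is a minimal sketch of that contract on a hypothetical ToyOptimizer. It assumes the state is a mapping from initializer name to a list of values. Its fit method only adds 1 to every weight, as a placeholder for real training steps.

```python
import copy
from typing import Dict, List

State = Dict[str, List[float]]


class ToyOptimizer:
    """Hypothetical estimator mimicking get_state / set_state / warm_start."""

    def __init__(self, initial_state: State, warm_start: bool = False) -> None:
        self.initial_state = copy.deepcopy(initial_state)
        self.warm_start = warm_start

    def fit(self, n_steps: int = 1) -> "ToyOptimizer":
        # restart from the initial weights unless warm_start is set and weights exist
        if not (self.warm_start and hasattr(self, "state_")):
            self.state_ = copy.deepcopy(self.initial_state)
        for _ in range(n_steps):  # placeholder "training": add 1 to every weight
            self.state_ = {k: [v + 1.0 for v in vals] for k, vals in self.state_.items()}
        return self

    def get_state(self) -> State:
        return copy.deepcopy(self.state_)  # a copy, so callers cannot alter the weights

    def set_state(self, state: State) -> None:
        self.state_ = copy.deepcopy(state)


opt = ToyOptimizer({"coefs": [0.0, 0.0]}, warm_start=True)
opt.fit(2)
saved = opt.get_state()  # {'coefs': [2.0, 2.0]}
opt.fit(1)               # warm start: continues from the current weights
print(opt.get_state())   # {'coefs': [3.0, 3.0]}
opt.set_state(saved)
print(opt.get_state())   # {'coefs': [2.0, 2.0]}
```

With warm_start=False, each call to fit starts again from the initial weights, so only the last call determines the final state.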