module training.sgd_learning_penalty#
Short summary#
module onnxcustom.training.sgd_learning_penalty
Helper for onnxruntime-training.
Classes#
class | truncated documentation
---|---
BaseLearningPenalty | Class handling the penalty on the coefficients for class OrtGradientForwardBackwardOptimizer.
ElasticLearningPenalty | Implements an L1 or L2 regularization on weights.
NoLearningPenalty | No regularization.
Static Methods#
staticmethod | truncated documentation
---|---
select | Returns an instance of the given class, initialized with kwargs.
select | Returns an instance of the given class, initialized with kwargs.
select | Returns an instance of the given class, initialized with kwargs.
Methods#
method | truncated documentation
---|---
__init__ |
__init__ |
__init__ |
_call_iobinding |
penalty_loss | Returns the received loss. Updates the loss in place.
penalty_loss | Computes the penalty associated with every weight and adds the penalties to the loss.
penalty_loss | Returns the received loss. Updates the loss in place.
update_weights | Updates the weight in place and returns it.
update_weights | Updates the weight in place and returns it.
Documentation#
- class onnxcustom.training.sgd_learning_penalty.BaseLearningPenalty#
Bases: BaseLearningOnnx
Class handling the penalty on the coefficients for class OrtGradientForwardBackwardOptimizer.
- __init__()#
- _call_iobinding(sess, bind)#
- penalty_loss(device, loss, *weights)#
Returns the received loss. Updates the loss in place.
- Parameters:
device – device where the training takes place
loss – loss without penalty
weights – any weights to be penalized
- Returns:
loss
- static select(class_name, **kwargs)#
Returns an instance of the given class, initialized with kwargs.
- Parameters:
class_name – an instance of BaseLearningPenalty or a string among the following class names (see below)
- Returns:
instance of BaseLearningPenalty
Possible values for class_name:
None or 'penalty': see L1L2PenaltyLearning
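The dispatch described for select can be illustrated with a minimal, self-contained sketch. It does not import onnxcustom: PenaltySketch and ElasticPenaltySketch are hypothetical stand-ins, and the mapping from names to classes is an assumption based only on the values listed above.

```python
class PenaltySketch:
    """Hypothetical stand-in for BaseLearningPenalty (not the library class)."""

    @staticmethod
    def select(class_name, **kwargs):
        # An instance that already exists is returned unchanged.
        if isinstance(class_name, PenaltySketch):
            return class_name
        # Assumed mapping, built only from the values listed in the docstring.
        registry = {None: ElasticPenaltySketch, "penalty": ElasticPenaltySketch}
        if class_name not in registry:
            raise ValueError(f"Unexpected class name {class_name!r}.")
        # kwargs are forwarded to the constructor of the selected class.
        return registry[class_name](**kwargs)


class ElasticPenaltySketch(PenaltySketch):
    """Hypothetical stand-in that only stores the two coefficients."""

    def __init__(self, l1=0.5, l2=0.5):
        self.l1 = l1
        self.l2 = l2


penalty = PenaltySketch.select("penalty", l1=0.1, l2=0.9)
```

Accepting either a ready-made instance or a name lets a caller pass a preconfigured penalty or rely on the defaults.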
- update_weights(device, statei)#
Updates the weight in place and returns it.
- Parameters:
device – device where the training takes place
statei – weight to update
- Returns:
weight
- class onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty(l1=0.5, l2=0.5)#
Bases: BaseLearningPenalty
Implements an L1 or L2 regularization on weights.
- __init__(l1=0.5, l2=0.5)#
- build_onnx_function(opset, device, n_tensors)#
This class computes a function represented as an ONNX graph, and this method builds that graph. It also creates the InferenceSession objects that evaluate it.
- Parameters:
opset – opset to use
device – C_OrtDevice
args – additional arguments
- penalty_loss(device, *inputs)#
Computes the penalty associated with every weight and adds the penalties to the loss.
- Parameters:
device – device where the training takes place
inputs – loss without penalty and weights
- Returns:
loss + penalties
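As a rough illustration of what "loss + penalties" may amount to, here is a NumPy sketch of a standard elastic penalty. It assumes the penalty for each weight tensor is l1 * sum(|w|) + l2 * sum(w^2); this page does not state the exact formula, and the class itself evaluates an ONNX graph rather than NumPy code. The function name elastic_penalty_loss is hypothetical.

```python
import numpy as np


def elastic_penalty_loss(loss, *weights, l1=0.5, l2=0.5):
    """Adds an assumed elastic penalty of every weight tensor to the loss."""
    penalty = 0.0
    for w in weights:
        # L1 term: sum of absolute values; L2 term: sum of squares.
        penalty += l1 * np.abs(w).sum() + l2 * np.square(w).sum()
    return loss + penalty


w = np.array([1.0, -2.0])
total = elastic_penalty_loss(1.0, w)  # 1.0 + 0.5 * 3.0 + 0.5 * 5.0
```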
- update_weights(n_bind, device, statei)#
Updates the weight in place and returns it.
- Parameters:
device – device where the training takes place
statei – weight to update
- Returns:
weight
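One plausible reading of this update is a gradient step on the penalty term alone. The NumPy sketch below assumes the penalty l1 * sum(|w|) + l2 * sum(w^2), whose gradient is l1 * sign(w) + 2 * l2 * w. The formula, the step size lr, and the function name are assumptions, not taken from this page.

```python
import numpy as np


def elastic_update_weight(w, l1=0.5, l2=0.5, lr=1.0):
    """Applies an assumed penalty gradient step to w in place and returns w."""
    # Gradient of l1 * sum(|w|) + l2 * sum(w ** 2) with respect to w.
    grad = l1 * np.sign(w) + 2.0 * l2 * w
    # In-place subtraction keeps the same array object, as the docstring suggests.
    w -= lr * grad
    return w


weight = np.array([1.0, -2.0, 0.0])
updated = elastic_update_weight(weight, l1=0.1, l2=0.1)
```

The L1 part moves every nonzero coefficient toward zero by a constant amount, while the L2 part shrinks each coefficient proportionally to its value.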
- class onnxcustom.training.sgd_learning_penalty.NoLearningPenalty#
Bases: BaseLearningPenalty
No regularization.
- __init__()#
- build_onnx_function(opset, device, n_tensors)#
This class computes a function represented as an ONNX graph, and this method builds that graph. It also creates the InferenceSession objects that evaluate it.
- Parameters:
opset – opset to use
device – C_OrtDevice
args – additional arguments
- penalty_loss(device, loss, *weights)#
Returns the received loss. Updates the loss in place.
- Parameters:
device – device where the training takes place
loss – loss without penalty
weights – any weights to be penalized
- Returns:
loss
- update_weights(n_bind, device, statei)#
Updates the weight in place and returns it.
- Parameters:
device – device where the training takes place
statei – weight to update
- Returns:
weight