module training.sgd_learning_penalty#

Inheritance diagram of onnxcustom.training.sgd_learning_penalty

Short summary#

module onnxcustom.training.sgd_learning_penalty

Helper for onnxruntime-training.

source on GitHub

Classes#

• BaseLearningPenalty – Class handling the penalty on the coefficients for class OrtGradientForwardBackwardOptimizer.

• ElasticLearningPenalty – Implements an L1 or L2 regularization on weights.

• NoLearningPenalty – No regularization.

Static Methods#

• select – Returns an instance of a given class initialized with kwargs. Defined on BaseLearningPenalty and inherited by ElasticLearningPenalty and NoLearningPenalty.

Methods#

• __init__ (BaseLearningPenalty, ElasticLearningPenalty, NoLearningPenalty)

• _call_iobinding (BaseLearningPenalty, ElasticLearningPenalty, NoLearningPenalty)

• build_onnx_function (ElasticLearningPenalty, NoLearningPenalty) – Builds the ONNX function computing the penalty and the InferenceSession running it.

• penalty_loss (BaseLearningPenalty, NoLearningPenalty) – Returns the received loss. Updates the loss in place.

• penalty_loss (ElasticLearningPenalty) – Computes the penalty associated with every weight and adds it to the loss.

• update_weights (BaseLearningPenalty, ElasticLearningPenalty, NoLearningPenalty) – Updates the weights in place and returns them.

Documentation#

Helper for onnxruntime-training.

source on GitHub

class onnxcustom.training.sgd_learning_penalty.BaseLearningPenalty#

Bases: BaseLearningOnnx

Class handling the penalty on the coefficients for class OrtGradientForwardBackwardOptimizer.

source on GitHub
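The penalty classes below are meant to be plugged into OrtGradientForwardBackwardOptimizer. The following sketch shows one way the wiring could look; the learning_penalty keyword and the optimizers_partial module path are assumptions drawn from this class's description, not verified signatures.

    import numpy
    from onnx import TensorProto
    from onnx.helper import (
        make_graph, make_model, make_node, make_opsetid, make_tensor_value_info)
    from onnx.numpy_helper import from_array
    from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty
    from onnxcustom.training.optimizers_partial import (
        OrtGradientForwardBackwardOptimizer)

    # Toy linear model Y = X @ coef + intercept, built by hand so the sketch
    # is self-contained.
    coef = from_array(numpy.random.randn(2, 1).astype(numpy.float32), name='coef')
    intercept = from_array(numpy.zeros((1, 1), dtype=numpy.float32), name='intercept')
    X = make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, 1])
    nodes = [make_node('MatMul', ['X', 'coef'], ['XC']),
             make_node('Add', ['XC', 'intercept'], ['Y'])]
    graph = make_graph(nodes, 'lr', [X], [Y], [coef, intercept])
    onnx_model = make_model(graph, opset_imports=[make_opsetid('', 15)])

    # The learning_penalty argument is an assumption based on this class's
    # description; the optimizer then penalizes the trained coefficients.
    train_session = OrtGradientForwardBackwardOptimizer(
        onnx_model, ['coef', 'intercept'],
        learning_penalty=ElasticLearningPenalty(l1=0.001, l2=0.0001))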

__init__()#
_call_iobinding(sess, bind)#
penalty_loss(device, loss, *weights)#

Returns the received loss. Updates the loss in place.

Parameters:
  • device – device where the training takes place

  • loss – loss without penalty

  • weights – any weights to be penalized

Returns:

loss

source on GitHub

static select(class_name, **kwargs)#

Returns an instance of a given class initialized with kwargs.

Parameters:
  • class_name – an instance of BaseLearningPenalty or a string among the class names listed below

Returns:

instance of BaseLearningPenalty

Possible values for class_name:
  • None or ‘penalty’: see L1L2PenaltyLearning

source on GitHub
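As a minimal sketch, a penalty can be selected by name; that ‘penalty’ resolves to ElasticLearningPenalty and that the keyword arguments are forwarded to its constructor are assumptions based on the class list of this module.

    from onnxcustom.training.sgd_learning_penalty import BaseLearningPenalty

    # 'penalty' is assumed to map to the elastic (L1/L2) penalty class;
    # l1 and l2 are passed through to the selected class's constructor.
    penalty = BaseLearningPenalty.select('penalty', l1=0.001, l2=0.0001)
    print(type(penalty).__name__)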

update_weights(device, statei)#

Updates the weights in place and returns them.

Parameters:
  • device – device where the training takes place

  • statei – weights to update

Returns:

the updated weights

source on GitHub

class onnxcustom.training.sgd_learning_penalty.ElasticLearningPenalty(l1=0.5, l2=0.5)#

Bases: BaseLearningPenalty

Implements an L1 or L2 regularization on weights.

source on GitHub

__init__(l1=0.5, l2=0.5)#
build_onnx_function(opset, device, n_tensors)#

This class computes a function represented as an ONNX graph. This method builds that graph and creates the InferenceSession which evaluates it.

Parameters:
  • opset – opset to use

  • device – C_OrtDevice

  • n_tensors – number of tensors the function handles

source on GitHub
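A sketch of building the function by hand, assuming C_OrtDevice is the low-level device class exposed by onnxruntime and that n_tensors counts the weight tensors the penalty will be applied to:

    from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice
    from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty

    # CPU device expressed with onnxruntime's low-level OrtDevice class.
    device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)

    penalty = ElasticLearningPenalty(l1=0.01, l2=0.01)
    # opset 15 is an arbitrary choice; n_tensors=1 assumes a single weight tensor.
    penalty.build_onnx_function(15, device, 1)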

penalty_loss(device, *inputs)#

Computes the penalty associated with every weight and adds it to the loss.

Parameters:
  • device – device where the training takes place

  • inputs – loss without penalty and weights

Returns:

loss + penalties

source on GitHub
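Mathematically, the elastic penalty adds the usual L1 and L2 terms to the loss. The NumPy sketch below shows the equivalent computation (penalty_loss itself runs an ONNX graph on OrtValue inputs; exact scaling conventions in that graph may differ):

    import numpy

    def elastic_penalty_loss(loss, weights, l1=0.5, l2=0.5):
        # loss' = loss + l1 * sum(|w|) + l2 * sum(w ** 2), summed over all tensors
        penalty = sum(l1 * numpy.abs(w).sum() + l2 * (w ** 2).sum()
                      for w in weights)
        return loss + penalty

    w = numpy.array([0.5, -1.0, 2.0], dtype=numpy.float32)
    print(elastic_penalty_loss(0.1, [w], l1=0.01, l2=0.01))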

update_weights(n_bind, device, statei)#

Updates the weights in place and returns them.

Parameters:
  • device – device where the training takes place

  • statei – weights to update

Returns:

the updated weights

source on GitHub
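For the weight update, one plausible NumPy rendering is a shrinkage step that subtracts the gradient of the elastic penalty from each weight tensor; this formulation is an assumption, and the exact update implemented by the underlying ONNX graph may differ.

    import numpy

    def elastic_update_weights(w, l1=0.5, l2=0.5):
        # Subtract the gradient of l1 * |w| + l2 * w ** 2, i.e.
        # l1 * sign(w) + 2 * l2 * w, which shrinks the weights.
        return w - l1 * numpy.sign(w) - 2.0 * l2 * w

    w = numpy.array([0.5, -1.0, 2.0], dtype=numpy.float32)
    print(elastic_update_weights(w, l1=0.01, l2=0.01))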

class onnxcustom.training.sgd_learning_penalty.NoLearningPenalty#

Bases: BaseLearningPenalty

No regularization.

source on GitHub

__init__()#
build_onnx_function(opset, device, n_tensors)#

This class computes a function represented as an ONNX graph. This method builds that graph and creates the InferenceSession which evaluates it.

Parameters:
  • opset – opset to use

  • device – C_OrtDevice

  • n_tensors – number of tensors the function handles

source on GitHub

penalty_loss(device, loss, *weights)#

Returns the received loss. Updates the loss in place.

Parameters:
  • device – device where the training takes place

  • loss – loss without penalty

  • weights – any weights to be penalized

Returns:

loss

source on GitHub
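Since no regularization is applied, the loss is expected to come back untouched. A small sketch, assuming the no-op penalty can be called without building an ONNX function first and that onnxruntime exposes OrtValue/OrtDevice as below:

    import numpy
    from onnxruntime.capi._pybind_state import (
        OrtDevice as C_OrtDevice, OrtValue as C_OrtValue)
    from onnxcustom.training.sgd_learning_penalty import NoLearningPenalty

    device = C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    loss = C_OrtValue.ortvalue_from_numpy(
        numpy.array([0.5], dtype=numpy.float32), device)
    w = C_OrtValue.ortvalue_from_numpy(
        numpy.random.randn(4).astype(numpy.float32), device)

    penalty = NoLearningPenalty()
    # Expected to return the received loss unchanged.
    new_loss = penalty.penalty_loss(device, loss, w)
    print(new_loss.numpy())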

update_weights(n_bind, device, statei)#

Updates the weights in place and returns them.

Parameters:
  • device – device where the training takes place

  • statei – weights to update

Returns:

the updated weights

source on GitHub