module training.grad_helper#

Inheritance diagram of onnxcustom.training.grad_helper

Short summary#

module onnxcustom.training.grad_helper

ONNX and gradient.

source on GitHub

Classes#

  • DerivativeOptions: Options defining how to build the onnx graph of the gradients.

Functions#

  • _default_inputs: Guesses default inputs (float ones) if not specified.

  • _onnx_derivative_fw: Implements a gradient based on class OrtModuleGraphBuilder.

  • _onnx_derivative_loss: Implements a gradient based on class PyGradientGraphBuilder.

  • onnx_derivative: Builds the gradient for an onnx graph.

Documentation#

ONNX and gradient.

source on GitHub

class onnxcustom.training.grad_helper.DerivativeOptions(value)#

Bases: IntFlag

Options defining how to build the onnx graph of the gradients.

  • Zero: default option, all options are disabled

  • KeepYieldOp: keeps the operator YieldOp in the graph, see onnx_derivative

  • KeepOutputs: keeps the output of the original graph

  • FillGrad: does not add any output to specify the gradient of the output but assumes it is one

  • Loss: the function assumes the loss was added to the graph

source on GitHub
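Since DerivativeOptions derives from IntFlag, the options above can be combined with the bitwise | operator. A minimal sketch using a stand-in class with the same member names (the real class lives in onnxcustom.training.grad_helper; the integer values below are illustrative, not taken from the library):

```python
from enum import IntFlag

class DerivativeOptions(IntFlag):
    # Stand-in for onnxcustom.training.grad_helper.DerivativeOptions,
    # reproduced only to show how IntFlag members combine.
    Zero = 0          # default, all options disabled
    KeepYieldOp = 1   # keep operator YieldOp in the graph
    KeepOutputs = 2   # keep the outputs of the original graph
    FillGrad = 4      # assume the output gradient is filled with ones
    Loss = 8          # assume a loss was added to the graph

# Combine two options with the bitwise OR operator.
opts = DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad

# Membership is tested per flag.
print(DerivativeOptions.FillGrad in opts)   # True
print(DerivativeOptions.Loss in opts)       # False
```

This is the usual IntFlag pattern: Zero is the neutral element, and any subset of the remaining flags can be passed as a single options value.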

onnxcustom.training.grad_helper._default_inputs(onx)#

Guesses default inputs (float ones) if not specified.

onnxcustom.training.grad_helper._onnx_derivative_fw(onx, weights, inputs, options)#

Implements a gradient based on class OrtModuleGraphBuilder.

source on GitHub

onnxcustom.training.grad_helper._onnx_derivative_loss(onx, weights, inputs, options, loss, label, path_name)#

Implements a gradient based on class PyGradientGraphBuilder.

source on GitHub

onnxcustom.training.grad_helper.onnx_derivative(onx, weights=None, inputs=None, options=DerivativeOptions.Zero, loss=None, label=None, path_name=None)#

Builds the gradient for an onnx graph.

Parameters:
  • onx – onnx graph

  • weights – gradient against those weights, None for all real weights

  • inputs – gradient against inputs, None for all real inputs

  • options – options of type DerivativeOptions

  • loss – loss output in case a loss was added in the graph, options must be equal to DerivativeOptions.Loss

  • label – if loss is specified, then the label must be specified as well

  • path_name – if options equal to DerivativeOptions.Loss, the gradient is saved to that path

Returns:

onnx graph

The function calls OrtModuleGraphBuilderConfiguration from onnxruntime-training. The resulting graph is meant to be used with OrtGradientForwardBackward and includes the operator YieldOp. That is why the graph looks this way:

These operators stand for the outputs of the initial graph and must be replaced by the gradients of those outputs to compute the gradients of the weights and the inputs. Once they are replaced, the graph looks this way:

The user can still compute the outputs.

That input gradient can be filled with a constant matrix of ones with the expected shape.
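As a sketch of that last point: when FillGrad is not set, the caller supplies the output gradient as an input of the new graph, typically a matrix of ones with the output's shape. A real caller would use numpy.ones; ones_like_shape below is a hypothetical stdlib-only stand-in written only to make the shape logic explicit:

```python
def ones_like_shape(shape):
    """Build a nested list filled with 1.0 for the given shape,
    mimicking numpy.ones(shape) without requiring numpy."""
    if not shape:
        return 1.0
    return [ones_like_shape(shape[1:]) for _ in range(shape[0])]

# Expected output shape of the original graph, e.g. a (2, 3) matrix.
grad_output = ones_like_shape((2, 3))
print(grad_output)  # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```

With DerivativeOptions.FillGrad, onnx_derivative assumes this ones-filled gradient instead of adding the extra input, so no such matrix has to be passed at runtime.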

source on GitHub