module training.grad_helper#
Short summary#
module onnxcustom.training.grad_helper
ONNX and gradient.
Classes#
class | truncated documentation
---|---
`DerivativeOptions` | Options defining how to build the onnx graph of the gradients.
Functions#
function | truncated documentation
---|---
`_default_inputs` | Guesses default inputs (float ones) if not specified.
`_onnx_derivative_fw` | Implements a gradient based on class OrtModuleGraphBuilder.
`_onnx_derivative_loss` | Implements a gradient based on class PyGradientGraphBuilder.
`onnx_derivative` | Builds the gradient for an onnx graph.
Documentation#
ONNX and gradient.
- class onnxcustom.training.grad_helper.DerivativeOptions(value)#
Bases: IntFlag
Options defining how to build the onnx graph of the gradients.
- Zero: default option, all options are disabled
- KeepYieldOp: keeps the operator YieldOp in the graph, see onnx_derivative
- KeepOutputs: keeps the outputs of the original graph
- FillGrad: does not add an output to specify the gradient of the outputs but assumes it is filled with ones
- Loss: the function assumes a loss was added to the graph
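Since DerivativeOptions derives from IntFlag, the options can be combined with the bitwise-or operator. The sketch below re-creates the flags locally for illustration only; the names come from the documentation above, but the numeric values are assumptions, not the library's actual values.

```python
from enum import IntFlag


class DerivativeOptions(IntFlag):
    # Illustrative re-creation; numeric values are assumptions.
    Zero = 0          # default, all options disabled
    KeepYieldOp = 1   # keep the YieldOp operators in the graph
    KeepOutputs = 2   # keep the outputs of the original graph
    FillGrad = 4      # assume the output gradient is filled with ones
    Loss = 8          # a loss was added to the graph


# Flags combine with |, and membership is tested with `in`.
opts = DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad
assert DerivativeOptions.KeepOutputs in opts
assert DerivativeOptions.KeepYieldOp not in opts
```

The same pattern applies to the real class: passing `DerivativeOptions.KeepOutputs | DerivativeOptions.FillGrad` to onnx_derivative enables both behaviours at once.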
- onnxcustom.training.grad_helper._default_inputs(onx)#
Guesses default inputs (float ones) if not specified.
- onnxcustom.training.grad_helper._onnx_derivative_fw(onx, weights, inputs, options)#
Implements a gradient based on class OrtModuleGraphBuilder.
- onnxcustom.training.grad_helper._onnx_derivative_loss(onx, weights, inputs, options, loss, label, path_name)#
Implements a gradient based on class PyGradientGraphBuilder.
- onnxcustom.training.grad_helper.onnx_derivative(onx, weights=None, inputs=None, options=DerivativeOptions.Zero, loss=None, label=None, path_name=None)#
Builds the gradient for an onnx graph.
- Parameters:
  - onx – onnx graph
  - weights – compute the gradient against these weights, None for all real weights
  - inputs – compute the gradient against these inputs, None for all real inputs
  - options – options of type DerivativeOptions
  - loss – loss output in case a loss was added to the graph; options must be equal to DerivativeOptions.Loss
  - label – if loss is specified, the label must be specified as well
  - path_name – if options equals DerivativeOptions.Loss, the gradient is saved to that path
- Returns:
  onnx graph
The function calls OrtModuleGraphBuilderConfiguration from onnxruntime-training. The resulting graph is meant to be used with OrtGradientForwardBackward and includes the operator YieldOp. At this stage the graph contains one YieldOp per output (figure not shown). These operators stand for the outputs of the initial graph and must be replaced by the gradients of these outputs to compute the gradients of the weights and the inputs (figure of the transformed graph not shown). The user can still compute the outputs.
The input gradient can be filled with a constant matrix of ones with the expected shape.
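To illustrate the last point, the seed gradient for an output is simply a constant matrix of ones with that output's shape. A small numpy sketch (not tied to the library's API):

```python
import numpy as np

# Seed gradient for an output of shape (3, 2): a constant matrix of
# ones, as the sentence above describes for the input gradient.
output_shape = (3, 2)
grad_y = np.ones(output_shape, dtype=np.float32)
```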