module training.ortgradient

Inheritance diagram of onnxcustom.training.ortgradient

Short summary

module onnxcustom.training.ortgradient

Gradient with onnxruntime-training forward backward.

source on GitHub

Classes

  • OrtGradientForwardBackward – Implements forward backward mechanism assuming the function to train is defined by an ONNX graph.

  • OrtGradientForwardBackwardFunction – Ancestor for a class implementing forward and backward and dynamically created by OrtGradientForwardBackward. …

Properties

  • saved_tensors – Returns the tensors saved during the forward step.

Static Methods

  • _provider_name_to_device_type

  • _repr_helper_ – Used to improve logging messages.

  • _select_initializer_names – Selects all initializers with float type.

  • device_name – Returns the device name of a device.

  • input_to_ort – Converts a list of tensors into an OrtValueVector.

  • save_onnx_graph – Saves the ONNX graphs stored in this class.

Methods

  • __getstate__ – Removes any non-picklable attribute.

  • __init__

  • __repr__ – Returns the usual string representation.

  • __setstate__ – Restores any non-picklable attribute.

  • _create_onnx_graphs – Creates the forward and backward ONNX graphs.

  • _init_next

  • backward – Implements the backward function. It returns an OrtValueVector.

  • forward – Implements the forward function.

  • get_initializer – Returns an initializer as a C_OrtValue.

  • new_instance – Creates an instance of class self.cls_type_, which implements methods forward and backward.

  • save_for_backward – Saves inputs during the forward step. The list inputs is copied (shallow copy, no deep copy).

Documentation


class onnxcustom.training.ortgradient.OrtGradientForwardBackward(onnx_model, weights_to_train=None, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0, enable_logging=False, debug=False)

Bases: object

Implements forward backward mechanism assuming the function to train is defined by an ONNX graph.

Parameters:
  • onnx_model – onnx model

  • weights_to_train – names of the weights to train; if None, all initializers of float type are included in the list

  • input_names – input names or None for all

  • output_names – output names or None for all

  • class_name – name to give the class dynamically created

  • sess_options – see SessionOptions

  • providers – see InferenceSession

  • provider_options – see InferenceSession

  • run_options – see RunOptions

  • graph_builder_config – see OrtModuleGraphBuilderConfiguration

  • device_index – used for CUDA (0 for cuda:0, 1 for cuda:1, …), 0 by default

  • enable_logging – enables logging while setting up the class

  • debug – to run extra verification while training

Note

The current implementation of onnxruntime requires the weights to train to appear in alphabetical order. The constructor checks that this condition holds.
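The ordering constraint can be illustrated in plain Python. The helper check_weights_order below is hypothetical and not part of onnxcustom; it is a minimal sketch of the kind of check described, assuming that "alphabetical order" means Python's default string ordering.

```python
def check_weights_order(weights_to_train):
    """Hypothetical helper: raises ValueError if the names are not sorted.

    Assumes "alphabetical order" means Python's default string comparison.
    """
    expected = sorted(weights_to_train)
    if list(weights_to_train) != expected:
        raise ValueError(
            "Weights to train must appear in alphabetical order, "
            "got %r, expected %r." % (list(weights_to_train), expected))
    return True


# Sorted names pass the check.
check_weights_order(["bias", "coef"])

# Unsorted names are rejected.
try:
    check_weights_order(["coef", "bias"])
except ValueError as e:
    print(e)
```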

Warning

This class does not consider subgraphs.


__getstate__()

Removes any non-picklable attribute.

__init__(onnx_model, weights_to_train=None, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0, enable_logging=False, debug=False)

__repr__()

Returns the usual string representation.

__setstate__(state)

Restores any non-picklable attribute.
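The pickling pattern behind __getstate__ and __setstate__ can be sketched with the standard library alone. This is not the library's implementation: the class Trainer is hypothetical, and its threading.Lock attribute stands in for a non-picklable object such as an InferenceSession.

```python
import pickle
import threading


class Trainer:
    """Hypothetical class holding one attribute that pickle cannot handle.

    A threading.Lock stands in for a non-picklable object such as an
    inference session.
    """

    def __init__(self, model_bytes):
        self.model_bytes = model_bytes
        self._lock = threading.Lock()  # not picklable

    def __getstate__(self):
        # Copy the instance dictionary and drop the non-picklable attribute.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state):
        # Restore the picklable attributes, then rebuild the dropped one.
        self.__dict__.update(state)
        self._lock = threading.Lock()


restored = pickle.loads(pickle.dumps(Trainer(b"onnx-bytes")))
print(restored.model_bytes)  # b'onnx-bytes'
```

Rebuilding the dropped attribute in __setstate__ keeps the unpickled object usable, at the cost of recreating whatever that attribute held.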

_create_onnx_graphs()

Creates the forward and backward ONNX graphs. The new class has the following attributes:

  • __doc__: doc string

  • __module__: module name (this file)

  • _run_options: see RunOptions

  • _sess: InferenceSession with the original graph

  • _sess_eval: InferenceSession on the graph with weights as inputs

  • _training_agent: TrainingAgent

  • _cache: OrtValueCache

  • _logger: logger

  • _input_names: input names

  • _debug: use debug mode

  • _grad_input_names: gradient input names

  • _output_names: output names

  • _weights_to_train: names of the weights to train

Training attributes:

  • _bw_fetches_names: bw_fetches_names

  • _fw_outputs_device_info: fw_outputs_device_info

  • _bw_outputs_device_info: bw_outputs_device_info

  • _fw_no_grad_output_device_info: fw_no_grad_output_device_info

  • _graph_info: graph_info

Additional attributes added if keep_model is True:

  • _trained_onnx: ONNX graph for the gradient

  • _optimized_pre_grad_model: evaluation ONNX graph taking weights as inputs

  • _graph_builder: OrtModuleGraphBuilder


_init_next()

static _provider_name_to_device_type(provider_name)

static _repr_helper_(obj, indent=0)

Used to improve logging messages.

static _select_initializer_names(onnx_model)

Selects all initializers with float type.

Parameters:

onnx_model – ONNX graph
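A minimal sketch of filtering initializers by element type, assuming that "float type" covers float16, float32 and float64 (the docstring does not say which types count). The Initializer tuple is a hypothetical stand-in for ONNX initializers; only the integer codes come from the ONNX TensorProto.DataType enumeration.

```python
from collections import namedtuple

# Hypothetical stand-in for ONNX initializers: only the two fields used here.
# The integer codes follow the ONNX TensorProto.DataType enumeration.
Initializer = namedtuple("Initializer", ["name", "data_type"])

FLOAT, INT64, FLOAT16, DOUBLE = 1, 7, 10, 11

# Assumption: "float type" covers float16, float32 and float64.
FLOAT_TYPES = {FLOAT, FLOAT16, DOUBLE}


def select_initializer_names(initializers):
    """Returns the names of initializers whose element type is a float."""
    return [init.name for init in initializers
            if init.data_type in FLOAT_TYPES]


inits = [Initializer("coef", FLOAT), Initializer("shape", INT64),
         Initializer("bias", DOUBLE)]
print(select_initializer_names(inits))  # ['coef', 'bias']
```

This is the list used as the default for weights_to_train when it is None; integer initializers such as shapes are left out because they cannot be trained.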


get_initializer(name, exc=True)

Returns an initializer as a C_OrtValue.

Parameters:
  • name – initializer name

  • exc – if True, raises an exception when the initializer is not found; otherwise returns None

Returns:

the initializer as a C_OrtValue
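The role of the exc flag can be sketched as a plain lookup. The function and the Initializer tuple are hypothetical; the conversion to a C_OrtValue is omitted, and the exception type (KeyError here) is an assumption, not necessarily what the real method raises.

```python
from collections import namedtuple

# Hypothetical stand-in for an ONNX initializer; the real method converts
# the found tensor into a C_OrtValue, which is omitted here.
Initializer = namedtuple("Initializer", ["name", "value"])


def get_initializer(initializers, name, exc=True):
    """Looks up an initializer by name.

    Raises KeyError when it is missing and exc is True, returns None otherwise.
    """
    for init in initializers:
        if init.name == name:
            return init.value
    if exc:
        raise KeyError("Unable to find initializer %r." % name)
    return None


inits = [Initializer("coef", [0.5, 1.5]), Initializer("bias", [0.1])]
print(get_initializer(inits, "bias"))                # [0.1]
print(get_initializer(inits, "missing", exc=False))  # None
```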


new_instance()

Creates an instance of class self.cls_type_, which implements methods forward and backward.
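A class created at runtime, as self.cls_type_ is, can be built in Python with type(name, bases, namespace). The sketch below is an illustration under assumptions: ForwardBackwardAncestor plays the role of OrtGradientForwardBackwardFunction, make_class and the class name are hypothetical, and the namespace only mimics a few of the attributes listed for _create_onnx_graphs with placeholder values.

```python
class ForwardBackwardAncestor:
    """Hypothetical ancestor playing the role of
    OrtGradientForwardBackwardFunction."""

    def forward(self, inputs, training=False):
        raise NotImplementedError

    def backward(self, grad_outputs):
        raise NotImplementedError


def make_class(class_name, input_names, weights_to_train):
    """Creates a subclass at runtime with type(name, bases, namespace).

    The namespace keys mimic a few attributes listed for the generated
    class; their values here are placeholders.
    """
    namespace = {
        "__doc__": "Dynamically created class %s." % class_name,
        "__module__": __name__,
        "_input_names": list(input_names),
        "_weights_to_train": list(weights_to_train),
    }
    return type(class_name, (ForwardBackwardAncestor,), namespace)


cls_type = make_class("LinearRegressionGradient", ["X"], ["bias", "coef"])
instance = cls_type()  # what new_instance would return in this sketch
print(type(instance).__name__)                        # LinearRegressionGradient
print(isinstance(instance, ForwardBackwardAncestor))  # True
```

Storing configuration as class attributes means every instance created this way shares the same graphs and settings, while per-call data (such as saved tensors) stays on the instance.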


class onnxcustom.training.ortgradient.OrtGradientForwardBackwardFunction

Bases: object

Ancestor for a class implementing forward and backward and dynamically created by OrtGradientForwardBackward.

Attributes stored in forward method:

  • saved_tensors_: list of tensors to save during forward and to retrieve during backward


__init__()

backward(grad_outputs, backward_outputs_cache=None)

Implements the backward function. It returns an OrtValueVector.


static device_name(device)

Returns the device name of a device.

Parameters:

device – OrtDevice

Returns:

string
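Neither device_name nor _provider_name_to_device_type is documented beyond a one-line summary. The sketch below only illustrates the idea of turning an execution provider name and device_index into a device label without onnxruntime. The function device_label and the strings it returns are hypothetical and make no claim to match what the real methods return; only the provider names CPUExecutionProvider and CUDAExecutionProvider are real onnxruntime names.

```python
# Hypothetical mapping from onnxruntime execution provider names to a short
# device label; the labels chosen here are illustrative only.
_PROVIDER_TO_DEVICE = {
    "CPUExecutionProvider": "cpu",
    "CUDAExecutionProvider": "cuda",
}


def device_label(provider_name, device_index=0):
    """Returns "cpu" or "cuda:<device_index>" for a provider name."""
    try:
        device_type = _PROVIDER_TO_DEVICE[provider_name]
    except KeyError:
        raise ValueError(
            "Unexpected provider name %r." % provider_name) from None
    if device_type == "cuda":
        # device_index selects the GPU, as in cuda:0, cuda:1, ...
        return "cuda:%d" % device_index
    return device_type


print(device_label("CPUExecutionProvider"))      # cpu
print(device_label("CUDAExecutionProvider", 1))  # cuda:1
```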


forward(inputs, training=False, forward_outputs_cache=None)

Implements the forward function.

Parameters:
  • inputs – inputs

  • training – False for inference only, True to run the forward step for training as well

Returns:

output as OrtValueVector
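The contract between forward and backward can be shown on a tiny function, y = w * x + b, computed by hand. This is a sketch only: LinearFunction is hypothetical, plain Python lists replace OrtValueVector, and returning the input gradient first followed by the weight gradients is an assumption made for the illustration.

```python
class LinearFunction:
    """Hypothetical forward/backward pair for y = w * x + b on floats.

    Plain Python lists replace OrtValueVector in this sketch.
    """

    def __init__(self, w, b):
        self.w = w
        self.b = b
        self.saved_tensors_ = []

    def forward(self, inputs, training=False):
        (x,) = inputs
        if training:
            # Only training needs x later, to compute dL/dw.
            self.saved_tensors_ = list(inputs)
        return [self.w * x + self.b]

    def backward(self, grad_outputs):
        (grad_y,) = grad_outputs
        (x,) = self.saved_tensors_
        # Chain rule: dL/dx = dL/dy * w, dL/dw = dL/dy * x, dL/db = dL/dy.
        return [grad_y * self.w, grad_y * x, grad_y]


fn = LinearFunction(w=2.0, b=1.0)
print(fn.forward([3.0], training=True))  # [7.0]
print(fn.backward([1.0]))                # [2.0, 3.0, 1.0]
```

Calling forward with training=False skips saving the inputs, which is why backward can only follow a forward step run in training mode.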


static input_to_ort(tensors, devices, debug)

Converts a list of tensors into an OrtValueVector.

save_for_backward(inputs)

Saves inputs during the forward step. The list inputs is copied (shallow copy, no deep copy).

Parameters:

inputs – list of tensors to save.
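The difference between a shallow and a deep copy matters here: the saved list is new, but the tensors in it are shared with the caller. A minimal sketch, assuming a hypothetical Saver class and Python lists as mutable stand-ins for tensors:

```python
class Saver:
    """Hypothetical holder mimicking save_for_backward / saved_tensors."""

    def __init__(self):
        self.saved_tensors_ = []

    def save_for_backward(self, inputs):
        # Simple (shallow) copy: a new list pointing to the same tensors.
        self.saved_tensors_ = list(inputs)

    @property
    def saved_tensors(self):
        return self.saved_tensors_


tensor = [1.0, 2.0]         # mutable stand-in for a tensor
inputs = [tensor]
saver = Saver()
saver.save_for_backward(inputs)

inputs.append([3.0])        # changing the list does not affect the copy
tensor[0] = 10.0            # changing a tensor in place does
print(saver.saved_tensors)  # [[10.0, 2.0]]
```

A shallow copy avoids duplicating tensor data between the forward and backward steps, so tensors passed to forward should not be modified in place before backward runs.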


classmethod save_onnx_graph(folder, prefix=None, suffix=None)

Saves the ONNX graphs stored in this class.


property saved_tensors

Returns the tensors saved during the forward step.
