module training.ortgradient
Short summary
module onnxcustom.training.ortgradient
Gradient computation with onnxruntime-training, using forward and backward steps.
Classes
class | summary |
---|---|
OrtGradientForwardBackward | Implements the forward and backward mechanism, assuming the function to train is defined by an ONNX graph. |
OrtGradientForwardBackwardFunction | Ancestor for a class implementing forward and backward and dynamically created by OrtGradientForwardBackward. |
Properties
property | summary |
---|---|
saved_tensors | Returns the tensors saved during the forward step. |
Static Methods
staticmethod | summary |
---|---|
_repr_helper_ | Used to improve logging messages. |
_select_initializer_names | Selects all initializers with float type. |
device_name | Returns the device name of a device. |
input_to_ort | Converts a list of tensors into an OrtValueVector. |
save_onnx_graph | Saves the ONNX graphs stored in this class. |
Methods
method | summary |
---|---|
__getstate__ | Removes any non-picklable attribute. |
__repr__ | Usual string representation. |
__setstate__ | Restores any non-picklable attribute. |
_create_onnx_graphs | Creates the forward and backward ONNX graphs. |
backward | Implements the backward function. The function returns an OrtValueVector. |
forward | Implements the forward function. |
get_initializer | Returns an initializer as a numpy array. |
new_instance | Creates an instance of class self.cls_type_, which implements methods forward and backward. |
save_for_backward | Saves inputs during the forward step. The list inputs is copied (shallow copy, no deep copy). |
Documentation
- class onnxcustom.training.ortgradient.OrtGradientForwardBackward(onnx_model, weights_to_train=None, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0, enable_logging=False, debug=False)
Bases: object
Implements the forward and backward mechanism, assuming the function to train is defined by an ONNX graph.
- Parameters:
onnx_model – onnx model
weights_to_train – names of the weights to train; if None, all initializers of float type are included in the list
input_names – input names or None for all
output_names – output names or None for all
class_name – name to give the class dynamically created
sess_options – see SessionOptions
providers – see InferenceSession
provider_options – see InferenceSession
run_options – see RunOptions
graph_builder_config – see OrtModuleGraphBuilderConfiguration
device_index – CUDA device index (0 for cuda:0, 1 for cuda:1, …), 0 by default
enable_logging – enables logging while setting up the class
debug – runs extra verifications while training
Note
The current implementation of onnxruntime requires the weights to train to appear in alphabetical order. The constructor checks that this condition holds.
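The ordering condition from the note can be expressed as a small stand-alone check. This is a sketch, not the constructor's code: is_alphabetical and check_weights_order are hypothetical helpers, and plain Python sorted order is assumed to match the order onnxruntime expects.

```python
def is_alphabetical(names):
    """Returns True if the names are already in sorted (alphabetical) order."""
    return list(names) == sorted(names)


def check_weights_order(weights_to_train):
    """Raises ValueError if the weights to train are not sorted.

    Hypothetical helper illustrating the condition stated in the note;
    it is not part of onnxcustom.
    """
    if not is_alphabetical(weights_to_train):
        raise ValueError(
            "Weights to train must be in alphabetical order, got %r, "
            "expected %r." % (list(weights_to_train), sorted(weights_to_train)))
    return True


# Sorted names pass, unsorted names raise an exception.
check_weights_order(["bias", "coef"])
try:
    check_weights_order(["coef", "bias"])
except ValueError as e:
    print(e)
```

sorted compares Unicode code points, so an uppercase name such as W sorts before a lowercase one such as bias.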
Warning
This class does not consider subgraphs.
- __getstate__()
Removes any non-picklable attribute.
- __init__(onnx_model, weights_to_train=None, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0, enable_logging=False, debug=False)
- __repr__()
Usual string representation.
- __setstate__(state)
Restores any non-picklable attribute.
- _create_onnx_graphs()
Creates the forward and backward ONNX graphs. The new class has the following attributes:
__doc__: doc string
__module__: module name (this file)
_run_options: see RunOptions
_sess: InferenceSession with the original graph
_sess_eval: InferenceSession on the graph with weights as inputs
_training_agent: TrainingAgent
_cache: OrtValueCache
_logger: logger
_input_names: input names
_debug: debug mode flag
_grad_input_names: gradient input names
_output_names: output names
_weights_to_train: names of the weights to train
Training attributes:
_bw_fetches_names: bw_fetches_names
_fw_outputs_device_info: fw_outputs_device_info
_bw_outputs_device_info: bw_outputs_device_info
_fw_no_grad_output_device_info: fw_no_grad_output_device_info
_graph_info: graph_info
Additional attributes added if keep_model is True:
_trained_onnx: ONNX graph for the gradient
_optimized_pre_grad_model: evaluation ONNX graph taking weights as inputs
_graph_builder: OrtModuleGraphBuilder
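A class whose attributes are only known at runtime can be built with Python's built-in type(name, bases, namespace). The sketch below assumes that is how the new class is created; ForwardBackwardBase stands in for OrtGradientForwardBackwardFunction, create_function_class is hypothetical, and only a few of the attributes listed above are filled in.

```python
class ForwardBackwardBase:
    """Stand-in for the ancestor class; the real one relies on onnxruntime."""

    def forward(self, inputs, training=False):
        raise NotImplementedError("provided by the generated class")

    def backward(self, grad_outputs):
        raise NotImplementedError("provided by the generated class")


def create_function_class(class_name, input_names, output_names,
                          weights_to_train, debug=False):
    """Builds a new subclass of ForwardBackwardBase at runtime (hypothetical)."""
    namespace = {
        "__doc__": "Forward and backward for %s." % class_name,
        "__module__": __name__,
        "_input_names": list(input_names),
        "_output_names": list(output_names),
        "_weights_to_train": list(weights_to_train),
        "_debug": debug,
    }
    # type(name, bases, namespace) returns a brand-new class object.
    return type(class_name, (ForwardBackwardBase,), namespace)


cls = create_function_class("LinearRegressionFunction", ["X"], ["variable"],
                            ["coef", "intercept"])
instance = cls()
print(cls.__name__, cls._weights_to_train, isinstance(instance, ForwardBackwardBase))
```

In this sketch the values are class attributes, so every instance created from the generated class shares them without copying.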
- _init_next()
- static _provider_name_to_device_type(provider_name)
- static _repr_helper_(obj, indent=0)
Used to improve logging messages.
- static _select_initializer_names(onnx_model)
Selects all initializers with float type.
- Parameters:
onnx_model – ONNX graph
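The selection can be sketched without onnx installed by modelling an initializer as a (name, data_type) pair, the two TensorProto fields it needs. The codes 1 (FLOAT) and 7 (INT64) come from the ONNX TensorProto.DataType enumeration. Which float types the real method keeps (only FLOAT, or also DOUBLE and FLOAT16) is an assumption here, hence the float_types parameter of this hypothetical select_initializer_names.

```python
from collections import namedtuple

# ONNX TensorProto.DataType codes used below.
FLOAT, INT64 = 1, 7

# Minimal stand-in for an ONNX initializer: only name and data_type matter here.
Initializer = namedtuple("Initializer", ["name", "data_type"])


def select_initializer_names(initializers, float_types=(FLOAT,)):
    """Returns the names of the initializers whose element type is in float_types."""
    return [init.name for init in initializers if init.data_type in float_types]


inits = [Initializer("coef", FLOAT),
         Initializer("shape", INT64),
         Initializer("intercept", FLOAT)]
print(select_initializer_names(inits))  # ['coef', 'intercept']
```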
- get_initializer(name, exc=True)
Returns an initializer as a numpy array.
- Parameters:
name – initializer name
exc – if True, raises an exception when the initializer is not found; otherwise returns None
- Returns:
the initializer as a C_OrtValue
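The exc flag switches between failing loudly and returning None. A minimal sketch of that contract, with a plain dictionary standing in for the model's initializers; the function and the exception type are hypothetical.

```python
def get_initializer(initializers, name, exc=True):
    """Returns initializers[name], or handles a missing name depending on exc.

    initializers: dict mapping names to values, a stand-in for the model
    exc: True raises an exception if name is missing, False returns None
    """
    if name in initializers:
        return initializers[name]
    if exc:
        raise KeyError("Unable to find initializer %r." % name)
    return None


weights = {"coef": [0.5, -1.0], "intercept": [0.1]}
print(get_initializer(weights, "coef"))                # [0.5, -1.0]
print(get_initializer(weights, "missing", exc=False))  # None
```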
- new_instance()
Creates an instance of class self.cls_type_, which implements methods forward and backward.
- class onnxcustom.training.ortgradient.OrtGradientForwardBackwardFunction
Bases: object
Ancestor for a class implementing forward and backward and dynamically created by OrtGradientForwardBackward.
Attributes stored in forward method:
saved_tensors_: list of tensors to save during forward and to retrieve during backward
state_: current weights stored in PartialGraphExecutionState
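The forward and backward contract can be illustrated on a model simple enough to differentiate by hand, y = w * x + b, using Python lists instead of OrtValueVector. This is a toy analogue, not how the generated class computes gradients (it delegates to onnxruntime-training); the class name ToyLinearFunction and the order of the returned gradients are assumptions.

```python
class ToyLinearFunction:
    """Pure-Python analogue of a forward/backward function for y = w * x + b."""

    def __init__(self, w, b):
        self.w, self.b = w, b
        self.saved_tensors_ = []

    def forward(self, inputs, training=False):
        x = inputs[0]
        if training:
            # Keep what backward needs, like saved_tensors_ in the real class.
            self.saved_tensors_ = [x]
        return [[self.w * xi + self.b for xi in x]]

    def backward(self, grad_outputs):
        (x,) = self.saved_tensors_
        gy = grad_outputs[0]
        grad_x = [g * self.w for g in gy]             # dL/dx_i = gy_i * w
        grad_w = sum(g * xi for g, xi in zip(gy, x))  # dL/dw = sum gy_i * x_i
        grad_b = sum(gy)                              # dL/db = sum gy_i
        return [grad_x, grad_w, grad_b]


f = ToyLinearFunction(2.0, 1.0)
print(f.forward([[1.0, 3.0]], training=True))  # [[3.0, 7.0]]
print(f.backward([[1.0, 1.0]]))                # [[2.0, 2.0], 4.0, 2.0]
```

Calling forward with training=False skips saving the inputs, which is why backward only makes sense after a training forward step.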
- __init__()
- backward(grad_outputs, backward_outputs_cache=None)
Implements the backward function. The function returns an OrtValueVector.
- static device_name(device)
Returns the device name of a device.
- Parameters:
device – OrtDevice
- Returns:
string
- forward(inputs, training=False, forward_outputs_cache=None)
Implements the forward function.
- Parameters:
inputs – inputs
training – True to run the forward step for training, False for inference only
- Returns:
output as OrtValueVector
- static input_to_ort(tensors, devices, debug)
Converts a list of tensors into an OrtValueVector.
- save_for_backward(inputs)
Saves inputs during the forward step. The list inputs is copied (shallow copy, no deep copy).
- Parameters:
inputs – list of tensors to save.
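The difference between a shallow copy and a deep copy matters when tensors are modified in place. A sketch with lists standing in for tensors; SavedTensorsHolder is hypothetical and only mimics the documented behaviour of save_for_backward and saved_tensors.

```python
class SavedTensorsHolder:
    """Hypothetical minimal holder reproducing save_for_backward semantics."""

    def __init__(self):
        self.saved_tensors_ = None

    def save_for_backward(self, inputs):
        # Shallow copy: a new list that refers to the same tensor objects.
        self.saved_tensors_ = list(inputs)

    @property
    def saved_tensors(self):
        return self.saved_tensors_


holder = SavedTensorsHolder()
t1, t2 = [1.0, 2.0], [3.0]
inputs = [t1, t2]
holder.save_for_backward(inputs)
inputs.append([4.0])  # the saved list keeps its two elements
t1[0] = 10.0          # visible through saved_tensors: same object
print(len(holder.saved_tensors), holder.saved_tensors[0])  # 2 [10.0, 2.0]
```

Appending to inputs after the call leaves the saved list unchanged, whereas in-place changes to a saved tensor are visible through saved_tensors, since both lists refer to the same objects.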
- classmethod save_onnx_graph(folder, prefix=None, suffix=None)
Saves the ONNX graphs stored in this class.
- property saved_tensors
Returns the tensors saved during the forward step.