Training API

Options and Parameters

TrainingParameters

class onnxruntime.TrainingParameters(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingParameters)

Configuration information for training.

property allreduce_post_accumulation
property attn_dropout_recompute
property data_parallel_size
property deepspeed_zero_stage
property enable_adasum
property enable_grad_norm_clip
property gelu_recompute
property gradient_accumulation_steps
property horizontal_parallel_size
property immutable_weights
property loss_output_name
property loss_scale
property lr_params_feed_name
property model_after_graph_transforms_path
property model_with_gradient_graph_path
property model_with_training_graph_path
property num_pipeline_micro_batches
property number_recompute_layers
property optimizer_attributes_map
property optimizer_int_attributes_map
property pipeline_cut_info_string
property pipeline_parallel_size
property propagate_cast_ops_allow
property propagate_cast_ops_level
property set_gradients_as_graph_outputs
set_optimizer_initial_state(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingParameters, arg0: Dict[str, Dict[str, object]]) -> None
property sliced_axes
property sliced_schema
property sliced_tensor_names
property training_optimizer_name
property transformer_layer_recompute
property use_fp16_moments
property use_memory_efficient_gradient
property use_mixed_precision
property weights_not_to_train
property weights_to_train
property world_rank
property world_size

Hidden API

GraphInfo

class onnxruntime.capi._pybind_state.GraphInfo(self: onnxruntime.capi.onnxruntime_pybind11_state.GraphInfo)

Bases: pybind11_object

The information of split graphs for frontend.

property cached_node_arg_names
property frontier_node_arg_map
property initializer_grad_names_to_train
property initializer_names
property initializer_names_to_train
property module_output_gradient_name
property module_output_indices_requires_save_for_backward
property output_grad_indices_non_differentiable
property output_grad_indices_require_full_shape
property user_input_grad_names
property user_input_names
property user_output_names

GradientNodeAttributeDefinition

class onnxruntime.capi._pybind_state.GradientNodeAttributeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeAttributeDefinition)

Bases: pybind11_object

Attribute definition for gradient graph nodes.

property dtype
property is_tensor
property name
property value_json

GradientNodeDefinition

class onnxruntime.capi._pybind_state.GradientNodeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition)

Bases: pybind11_object

Definition for gradient graph nodes.

property attributes
property domain
property inputs
property op_type
property outputs

GraphTransformerConfiguration

class onnxruntime.capi._pybind_state.GraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.GraphTransformerConfiguration)

Bases: pybind11_object

Graph transformer configuration.

property propagate_cast_ops_config

OrtModuleGraphBuilder

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder)

Bases: pybind11_object

build(*args, **kwargs)

Overloaded function.

  1. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None

  2. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: List[List[int]]) -> None

get_forward_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
get_gradient_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
get_graph_info(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> onnxruntime.capi.onnxruntime_pybind11_state.GraphInfo
initialize(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: bytes, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None

OrtModuleGraphBuilderConfiguration

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilderConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration)

Bases: pybind11_object

Configuration information for module graph builder.

property build_gradient_graph
property enable_caching
property graph_transformer_config
property initializer_names
property initializer_names_to_train
property input_names_require_grad
property loglevel
property use_memory_efficient_gradient

OrtValueCache

class onnxruntime.capi._pybind_state.OrtValueCache(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache)

Bases: pybind11_object

clear(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
count(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> int
insert(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None
keys(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> list
remove(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> None

OrtValueVector

class onnxruntime.capi._pybind_state.OrtValueVector(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector)

Bases: pybind11_object

bool_tensor_indices(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector) -> List[int]

Returns the indices of every boolean tensor in this vector of OrtValue. In case of a boolean tensor, method to_dlpacks returns a uint8 tensor instead of a boolean tensor. If torch consumes the dlpack structure, .to(torch.bool) must be applied to the torch tensor to get a boolean tensor.
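The post-processing described above can be sketched as follows. This is a minimal illustration that uses plain Python lists as stand-ins for tensors, so it runs without torch or onnxruntime. The helper name restore_bool_tensors and its cast callback are hypothetical; with torch, the callback would be lambda t: t.to(torch.bool), as the description states.

```python
from typing import Any, Callable, Iterable, List, Sequence


def restore_bool_tensors(
    tensors: Sequence[Any],
    bool_indices: Iterable[int],
    cast: Callable[[Any], Any],
) -> List[Any]:
    """Apply `cast` only at the positions reported as boolean tensors."""
    wanted = set(bool_indices)
    return [cast(t) if i in wanted else t for i, t in enumerate(tensors)]


# Stand-in data: position 1 plays the role of a boolean tensor that came
# back from to_dlpacks as uint8 values (0 or 1).
converted = [[0.5, 1.5], [1, 0, 1], [7, 8]]
restored = restore_bool_tensors(
    converted,
    bool_indices=[1],  # assumed result of bool_tensor_indices()
    cast=lambda t: [bool(v) for v in t],  # stand-in for t.to(torch.bool)
)
print(restored)
```

Only the listed positions are touched, so genuine uint8 tensors elsewhere in the vector keep their type.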

dlpack_at(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg0: int) -> object
element_type_at(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, idx: int) -> int

Returns an integer equal to the ONNX proto type of the tensor at position idx. This integer is one of the types defined by ONNX TensorProto_DataType (such as onnx.TensorProto.FLOAT). Raises an exception in any other case.
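As an illustration, the returned integer can be decoded with a small lookup table. The sketch below lists only a subset of the ONNX TensorProto data type values and does not import onnx; the helper name element_type_name is hypothetical.

```python
# Subset of ONNX TensorProto.DataType values, as defined by the ONNX spec.
ONNX_ELEMENT_TYPES = {
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
}


def element_type_name(code: int) -> str:
    """Translate an integer element type into a readable name."""
    try:
        return ONNX_ELEMENT_TYPES[code]
    except KeyError:
        raise ValueError(f"unknown or unlisted element type: {code}") from None


# The codes below stand in for values returned by element_type_at(idx).
print(element_type_name(1))
print(element_type_name(9))
```

When the onnx package is installed, comparing against onnx.TensorProto.FLOAT and similar constants avoids maintaining such a table by hand.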

push_back(*args, **kwargs)

Overloaded function.

  1. push_back(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None

  2. push_back(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, dlpack_tensor: object, is_bool_tensor: bool = False) -> None

Add a new OrtValue once ownership has been transferred from the DLPack structure.

push_back_batch(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg0: List[object], arg1: List[int], arg2: List[object], arg3: List[List[int]], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice]) -> None

Add a batch of OrtValues by wrapping PyTorch tensors.

reserve(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg0: int) -> None
shrink_to_fit(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector) -> None
to_dlpacks(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, to_tensor: object) -> list

Converts all OrtValue into tensors through the DLPack protocol. The method creates a DLPack structure for every tensor, then calls the Python function to_tensor to create a new object consuming the DLPack structure, or returns a list of capsules if this function is None.

Parameters:

to_tensor – this function takes a capsule holding a pointer to a DLPack structure and returns a new tensor which becomes the new owner of the data. It takes one Python object and returns a new Python object, matching the signature of torch.utils.dlpack.from_dlpack. If None, the method returns a capsule for every new DLPack structure.

Returns:

a list containing the new tensors, or the new capsules if to_tensor is None

This method is used to replace tuple(torch._C._from_dlpack(ov.to_dlpack()) for ov in ort_values) by a faster instruction tuple(ort_values.to_dlpacks(torch._C._from_dlpack)). The loop is difficult to parallelize as it goes through the GIL many times, and it creates many tensors acquiring ownership of existing OrtValue. This method saves one object creation and a C++ allocation for every transferred tensor.
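The two return modes of to_dlpacks can be modelled in pure Python. The sketch below is only a model of the documented contract, not the C++ implementation: FakeCapsule and to_dlpacks_model are hypothetical names, and ordinary Python objects stand in for OrtValue and DLPack structures.

```python
from typing import Any, Callable, List, Optional


class FakeCapsule:
    """Hypothetical stand-in for a capsule holding a DLPack structure."""

    def __init__(self, payload: Any) -> None:
        self.payload = payload


def to_dlpacks_model(
    values: List[Any],
    to_tensor: Optional[Callable[[FakeCapsule], Any]],
) -> List[Any]:
    """Model of the contract: one capsule per value, optionally converted."""
    capsules = [FakeCapsule(v) for v in values]
    if to_tensor is None:
        return capsules  # the caller receives the capsules themselves
    return [to_tensor(c) for c in capsules]  # the caller receives new objects


raw = to_dlpacks_model([1, 2, 3], None)
tensors = to_dlpacks_model([1, 2, 3], lambda c: ("tensor", c.payload))
print(type(raw[0]).__name__, tensors)
```

In the real method the whole conversion happens in a single call, which is where the saving described above comes from.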

PartialGraphExecutionState

class onnxruntime.capi._pybind_state.PartialGraphExecutionState(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState)

Bases: pybind11_object

PropagateCastOpsConfiguration

class onnxruntime.capi._pybind_state.PropagateCastOpsConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.PropagateCastOpsConfiguration)

Bases: pybind11_object

Propagate cast ops configuration.

property allow
property level
property strategy

TrainingConfigurationResult

class onnxruntime.capi._pybind_state.TrainingConfigurationResult(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingConfigurationResult)

Bases: pybind11_object

Configuration result for training.

property loss_scale_input_name

TrainingGraphTransformerConfiguration

class onnxruntime.capi._pybind_state.TrainingGraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration)

Bases: GraphTransformerConfiguration

Training Graph transformer configuration.

property attn_dropout_recompute
property enable_gelu_approximation
property gelu_recompute
property number_recompute_layers
property propagate_cast_ops_config
property transformer_layer_recompute

Functions

onnxruntime.capi._pybind_state.register_aten_op_executor(arg0: str, arg1: str) -> None
onnxruntime.capi._pybind_state.register_backward_runner(arg0: object) -> None
onnxruntime.capi._pybind_state.register_forward_runner(arg0: object) -> None
onnxruntime.capi._pybind_state.register_torch_autograd_function(arg0: str, arg1: object) -> None
onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) -> None
onnxruntime.capi._pybind_state.unregister_python_functions() -> None

TrainingSession

class onnxruntime.TrainingSession(path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None)

Bases: InferenceSession

Parameters:
  • path_or_bytes – filename or serialized ONNX or ORT format model in a byte string

  • sess_options – session options

  • providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.

  • provider_options – Optional sequence of options dicts corresponding to the providers listed in ‘providers’.

The model type will be inferred unless explicitly set in the SessionOptions. To explicitly set:

import onnxruntime

so = onnxruntime.SessionOptions()
# so.add_session_config_entry('session.load_model_format', 'ONNX') or
so.add_session_config_entry('session.load_model_format', 'ORT')

A file extension of ‘.ort’ will be inferred as an ORT format model. All other filenames are assumed to be ONNX format models.
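The inference rule above can be sketched in a few lines. The function infer_model_format is a hypothetical helper written for illustration, not part of onnxruntime; comparing the extension without regard to case is an assumption of this sketch.

```python
from pathlib import Path
from typing import Optional


def infer_model_format(filename: str, explicit: Optional[str] = None) -> str:
    """Return "ORT" or "ONNX" following the rule described above."""
    if explicit is not None:
        # An explicit setting takes priority over the file extension.
        if explicit not in ("ONNX", "ORT"):
            raise ValueError(f"unsupported model format: {explicit}")
        return explicit
    # Assumption: the extension is compared case-insensitively.
    return "ORT" if Path(filename).suffix.lower() == ".ort" else "ONNX"


print(infer_model_format("model.ort"))
print(infer_model_format("model.onnx"))
print(infer_model_format("model.bin", explicit="ORT"))
```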

‘providers’ can contain either names or names and options. When any options are given in ‘providers’, ‘provider_options’ should not be used.

The list of providers is ordered by precedence. For example, ['CUDAExecutionProvider', 'CPUExecutionProvider'] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
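The two accepted shapes of 'providers' can be illustrated with a small normalizer. This is a sketch of the rule stated above, not onnxruntime's own validation code: normalize_providers is a hypothetical helper, and the requirement that 'provider_options' match 'providers' in length is an assumption of the sketch.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

ProviderSpec = Union[str, Tuple[str, Dict[str, Any]]]


def normalize_providers(
    providers: Sequence[ProviderSpec],
    provider_options: Optional[Sequence[Dict[str, Any]]] = None,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Split providers into parallel lists of names and option dicts."""
    names: List[str] = []
    options: List[Dict[str, Any]] = []
    inline_options_seen = False
    for spec in providers:
        if isinstance(spec, str):
            names.append(spec)
            options.append({})
        else:
            name, opts = spec  # a (provider name, options dict) tuple
            names.append(name)
            options.append(dict(opts))
            inline_options_seen = True
    if provider_options is not None:
        if inline_options_seen:
            raise ValueError(
                "options were given in 'providers'; do not also pass "
                "'provider_options'"
            )
        if len(provider_options) != len(names):
            raise ValueError("'provider_options' must match 'providers' in length")
        options = [dict(o) for o in provider_options]
    return names, options


names, opts = normalize_providers(
    [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
)
print(names, opts)
```

Keeping names and options as parallel lists preserves the precedence order of the original sequence.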

disable_fallback()

Disable session.run() fallback mechanism.

enable_fallback()

Enable session.Run() fallback mechanism. If session.Run() fails due to an internal Execution Provider failure, reset the Execution Providers enabled for this session. If GPU is enabled, fall back to CUDAExecutionProvider; otherwise fall back to CPUExecutionProvider.
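The fallback decision described above can be sketched as a retry wrapper. Everything here is illustrative: run_with_fallback, fallback_providers, and the run_once callback are hypothetical names, and using RuntimeError to signal an Execution Provider failure is an assumption of this sketch rather than a statement about onnxruntime's exception types.

```python
from typing import Any, Callable, List, Optional


def fallback_providers(gpu_enabled: bool) -> List[str]:
    """Pick the providers to reset to, following the rule described above."""
    if gpu_enabled:
        return ["CUDAExecutionProvider"]
    return ["CPUExecutionProvider"]


def run_with_fallback(
    run_once: Callable[[Optional[List[str]]], Any],
    gpu_enabled: bool,
    fallback_enabled: bool = True,
) -> Any:
    """Call run_once; on failure, retry once with the fallback providers."""
    try:
        return run_once(None)  # None means "keep the current providers"
    except RuntimeError:
        if not fallback_enabled:
            raise
        return run_once(fallback_providers(gpu_enabled))


calls: List[Optional[List[str]]] = []


def flaky_run(providers: Optional[List[str]]) -> str:
    """Simulated run that fails until the providers are reset."""
    calls.append(providers)
    if providers is None:
        raise RuntimeError("simulated Execution Provider failure")
    return f"ran on {providers[0]}"


result = run_with_fallback(flaky_run, gpu_enabled=False)
print(result)
```

With fallback_enabled=False the original error propagates, which mirrors the effect of disable_fallback().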

end_profiling()

End profiling and return results in a file.

The results are stored in a file if the option onnxruntime.SessionOptions.enable_profiling is set.

get_inputs()

Return the inputs metadata as a list of onnxruntime.NodeArg.

get_modelmeta()

Return the metadata. See onnxruntime.ModelMetadata.

get_outputs()

Return the outputs metadata as a list of onnxruntime.NodeArg.

get_overridable_initializers()

Return the inputs (including initializers) metadata as a list of onnxruntime.NodeArg.

get_profiling_start_time_ns()

Return the profiling start time in nanoseconds. The value is comparable to time.monotonic_ns() (Python 3.7 and later). On some platforms, this timer may not be as precise as nanoseconds. For instance, on Windows and macOS, the precision will be ~100ns.
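Because the value is comparable to time.monotonic_ns(), the time elapsed since profiling started can be computed by subtraction. In the sketch below no session is created; start_ns is a stand-in for the value that get_profiling_start_time_ns() would return, and the helper name is hypothetical.

```python
import time


def elapsed_since_profiling_start_ms(start_ns: int) -> float:
    """Milliseconds elapsed since a monotonic start timestamp in ns."""
    return (time.monotonic_ns() - start_ns) / 1_000_000


# Stand-in for sess.get_profiling_start_time_ns(); an assumption made so
# that the example runs without a model or a session.
start_ns = time.monotonic_ns()
elapsed_ms = elapsed_since_profiling_start_ms(start_ns)
print(f"{elapsed_ms:.3f} ms since profiling started")
```

Monotonic timestamps are only meaningful as differences within one process, so they should not be compared with wall-clock time.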

get_provider_options()

Return registered execution providers’ configurations.

get_providers()

Return list of registered execution providers.

get_session_options()

Return the session options. See onnxruntime.SessionOptions.

io_binding()

Return an onnxruntime.IOBinding object.

run(output_names, input_feed, run_options=None)

Compute the predictions.

Parameters:
  • output_names – name of the outputs

  • input_feed – dictionary { input_name: input_value }

  • run_options – See onnxruntime.RunOptions.

Returns:

list of results, every result is either a numpy array, a sparse tensor, a list or a dictionary.

sess.run([output_name], {input_name: x})

run_with_iobinding(iobinding, run_options=None)

Compute the predictions.

Parameters:
  • iobinding – the iobinding object that has graph inputs/outputs bound.

  • run_options – See onnxruntime.RunOptions.

run_with_ort_values(output_names, input_dict_ort_values, run_options=None)

Compute the predictions.

Parameters:
  • output_names – name of the outputs

  • input_dict_ort_values – dictionary { input_name: input_ort_value }. See the OrtValue class for how to create an OrtValue from a numpy array or SparseTensor.

  • run_options – See onnxruntime.RunOptions.

Returns:

an array of OrtValue

sess.run_with_ort_values([output_name], {input_name: x})

run_with_ortvaluevector(run_options, feed_names, feeds, fetch_names, fetches, fetch_devices)

Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.

Parameters:
  • run_options – See onnxruntime.RunOptions.

  • feed_names – list of input names.

  • feeds – list of input OrtValue.

  • fetch_names – list of output names.

  • fetches – list of output OrtValue.

  • fetch_devices – list of output devices.

set_providers(providers=None, provider_options=None)

Register the input list of execution providers. The underlying session is re-created.

Parameters:
  • providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.

  • provider_options – Optional sequence of options dicts corresponding to the providers listed in ‘providers’.

‘providers’ can contain either names or names and options. When any options are given in ‘providers’, ‘provider_options’ should not be used.

The list of providers is ordered by precedence. For example, ['CUDAExecutionProvider', 'CPUExecutionProvider'] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
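Precedence-ordered assignment can be modelled as a first-match search over the provider list. The sketch below is a simplified model, not onnxruntime's graph partitioner: assign_nodes is a hypothetical helper, and the operator names and capability table are invented purely for illustration.

```python
from typing import Callable, Dict, Sequence


def assign_nodes(
    nodes: Sequence[str],
    providers: Sequence[str],
    can_run: Callable[[str, str], bool],
) -> Dict[str, str]:
    """Assign each node to the first provider, in order, that can run it."""
    assignment: Dict[str, str] = {}
    for node in nodes:
        for provider in providers:
            if can_run(provider, node):
                assignment[node] = provider
                break
        else:
            raise RuntimeError(f"no provider can execute node {node!r}")
    return assignment


# Hypothetical capability table; "OpA", "OpB" and "OpC" are made-up names.
SUPPORTED = {
    "CUDAExecutionProvider": {"OpA", "OpB"},
    "CPUExecutionProvider": {"OpA", "OpB", "OpC"},
}

assignment = assign_nodes(
    ["OpA", "OpB", "OpC"],
    ["CUDAExecutionProvider", "CPUExecutionProvider"],
    lambda provider, node: node in SUPPORTED[provider],
)
print(assignment)
```

Reversing the provider list in this model sends every node to CPUExecutionProvider, which shows why the order of the list matters.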