Training API
Options and Parameters
TrainingParameters
- class onnxruntime.TrainingParameters(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingParameters)
Configuration information for training.
- property allreduce_post_accumulation
- property attn_dropout_recompute
- property data_parallel_size
- property deepspeed_zero_stage
- property enable_adasum
- property enable_grad_norm_clip
- property gelu_recompute
- property gradient_accumulation_steps
- property horizontal_parallel_size
- property immutable_weights
- property loss_output_name
- property loss_scale
- property lr_params_feed_name
- property model_after_graph_transforms_path
- property model_with_gradient_graph_path
- property model_with_training_graph_path
- property num_pipeline_micro_batches
- property number_recompute_layers
- property optimizer_attributes_map
- property optimizer_int_attributes_map
- property pipeline_cut_info_string
- property pipeline_parallel_size
- property propagate_cast_ops_allow
- property propagate_cast_ops_level
- property set_gradients_as_graph_outputs
- set_optimizer_initial_state(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingParameters, arg0: Dict[str, Dict[str, object]]) None
- property sliced_axes
- property sliced_schema
- property sliced_tensor_names
- property training_optimizer_name
- property transformer_layer_recompute
- property use_fp16_moments
- property use_memory_efficient_gradient
- property use_mixed_precision
- property weights_not_to_train
- property weights_to_train
- property world_rank
- property world_size
Functions
- onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) None
TrainingSession
- class onnxruntime.TrainingSession(path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None)[source]
Bases: InferenceSession
- Parameters:
path_or_bytes – filename or serialized ONNX or ORT format model in a byte string
parameters – TrainingParameters object that configures training
sess_options – session options
providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.
provider_options – Optional sequence of options dicts corresponding to the providers listed in ‘providers’.
The model type will be inferred unless explicitly set in the SessionOptions. To set it explicitly:
so = onnxruntime.SessionOptions()
so.add_session_config_entry('session.load_model_format', 'ONNX')  # or 'ORT'
A file extension of ‘.ort’ will be inferred as an ORT format model. All other filenames are assumed to be ONNX format models.
‘providers’ can contain either names or names and options. When any options are given in ‘providers’, ‘provider_options’ should not be used.
The list of providers is ordered by precedence. For example [‘CUDAExecutionProvider’, ‘CPUExecutionProvider’] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
- disable_fallback()
Disable session.run() fallback mechanism.
- enable_fallback()
Enable session.run() fallback mechanism. If session.run() fails due to an internal execution provider failure, reset the execution providers enabled for this session. If GPU is enabled, fall back to CUDAExecutionProvider; otherwise fall back to CPUExecutionProvider.
- end_profiling()
End profiling and return the name of the file containing the results. The results are written to the file configured via onnxruntime.SessionOptions.enable_profiling().
- get_inputs()
Return the inputs metadata as a list of onnxruntime.NodeArg.
- get_modelmeta()
Return the metadata. See onnxruntime.ModelMetadata.
- get_outputs()
Return the outputs metadata as a list of onnxruntime.NodeArg.
- get_overridable_initializers()
Return the inputs (including initializers) metadata as a list of onnxruntime.NodeArg.
- get_profiling_start_time_ns()
Return the nanoseconds of profiling's start time, comparable to time.monotonic_ns() after Python 3.3. On some platforms this timer may not have nanosecond precision; for instance, on Windows and macOS the precision is ~100 ns.
- get_provider_options()
Return registered execution providers’ configurations.
- get_providers()
Return list of registered execution providers.
- get_session_options()
Return the session options. See onnxruntime.SessionOptions.
- io_binding()
Return an onnxruntime.IOBinding object.
- run(output_names, input_feed, run_options=None)
Compute the predictions.
- Parameters:
output_names – names of the outputs
input_feed – dictionary { input_name: input_value }
run_options – See onnxruntime.RunOptions.
- Returns:
list of results, every result is either a numpy array, a sparse tensor, a list or a dictionary.
sess.run([output_name], {input_name: x})
- run_with_iobinding(iobinding, run_options=None)
Compute the predictions.
- Parameters:
iobinding – the IOBinding object with graph inputs/outputs bound.
run_options – See onnxruntime.RunOptions.
- run_with_ort_values(output_names, input_dict_ort_values, run_options=None)
Compute the predictions.
- Parameters:
output_names – names of the outputs
input_dict_ort_values – dictionary { input_name: input_ort_value }. See the OrtValue class for how to create an OrtValue from a numpy array or a SparseTensor.
run_options – See onnxruntime.RunOptions.
- Returns:
an array of OrtValue
sess.run_with_ort_values([output_name], {input_name: x})
- run_with_ortvaluevector(run_options, feed_names, feeds, fetch_names, fetches, fetch_devices)
Compute the predictions similar to other run_*() methods but with minimal C++/Python conversion overhead.
- Parameters:
run_options – See onnxruntime.RunOptions.
feed_names – list of input names.
feeds – list of input OrtValue.
fetch_names – list of output names.
fetches – list of output OrtValue.
fetch_devices – list of output devices.
- set_providers(providers=None, provider_options=None)
Register the input list of execution providers. The underlying session is re-created.
- Parameters:
providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.
provider_options – Optional sequence of options dicts corresponding to the providers listed in ‘providers’.
‘providers’ can contain either names or names and options. When any options are given in ‘providers’, ‘provider_options’ should not be used.
The list of providers is ordered by precedence. For example [‘CUDAExecutionProvider’, ‘CPUExecutionProvider’] means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.