ORTTrainer

Schema

class onnxruntime.capi.ort_trainer.IODescription(name, shape, dtype=None, num_classes=None)
class onnxruntime.capi.ort_trainer.ModelDescription(inputs, outputs)

ORTTrainerOptions

class onnxruntime.training.ORTTrainerOptions(options={})

Settings used by the ONNX Runtime training backend.

The parameters are organized hierarchically into semantic groups that encompass related features, such as distributed training, to simplify configuration.

The input dict is validated during instantiation to ensure that only supported parameters and values are passed in. Invalid input raises a ValueError with details about the problem.

Parameters:
  • options (dict) – contains all training options

  • _validate (bool, default is True) – for internal use only
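The validation behavior described above, together with the dotted attribute access used later (for example opts.mixed_precision.enabled), can be pictured with a small standalone sketch. This is not ORTTrainerOptions itself: the helper make_options, the two option groups it knows about, and the minimum-value table are hypothetical, chosen only to mirror batch.gradient_accumulation_steps and device.mem_limit from the schema below.

```python
from types import SimpleNamespace

# Hypothetical subset of the option schema: group -> {option: default}.
_DEFAULTS = {
    "batch": {"gradient_accumulation_steps": 1},
    "device": {"id": "cuda", "mem_limit": 0},
}

# Hypothetical lower bounds, mirroring the 'min' rules in the schema.
_MINIMUMS = {
    ("batch", "gradient_accumulation_steps"): 1,
    ("device", "mem_limit"): 0,
}


def make_options(options=None):
    """Merge user options over defaults, validate them, and allow dotted access."""
    options = options or {}
    unknown_groups = set(options) - set(_DEFAULTS)
    if unknown_groups:
        raise ValueError(f"unknown option group(s): {sorted(unknown_groups)}")

    merged = {}
    for group, defaults in _DEFAULTS.items():
        user_values = options.get(group, {})
        unknown = set(user_values) - set(defaults)
        if unknown:
            raise ValueError(f"unknown option(s) in '{group}': {sorted(unknown)}")
        merged[group] = {**defaults, **user_values}

    for (group, key), minimum in _MINIMUMS.items():
        value = merged[group][key]
        if value < minimum:
            raise ValueError(f"{group}.{key} must be >= {minimum}, got {value}")

    # Nested namespaces provide the opts.group.option access style.
    return SimpleNamespace(**{g: SimpleNamespace(**v) for g, v in merged.items()})


opts = make_options({"batch": {"gradient_accumulation_steps": 4}})
print(opts.batch.gradient_accumulation_steps, opts.device.id)  # 4 cuda
```

Rejecting unknown keys as well as out-of-range values matches the stated goal of ensuring that only supported parameters and values are accepted.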

Supported schema for kwargs:

schema = {
            'batch' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'gradient_accumulation_steps' : {
                        'type' : 'integer',
                        'min' : 1,
                        'default' : 1
                    }
                },
            },
            'device' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'id' : {
                        'type' : 'string',
                        'default' : 'cuda'
                    },
                    'mem_limit' : {
                        'type' : 'integer',
                        'min' : 0,
                        'default' : 0
                    }
                }
            },
            'distributed': {
                'type': 'dict',
                'default': {},
                'required': False,
                'schema': {
                    'world_rank': {
                        'type': 'integer',
                        'min': 0,
                        'default': 0
                    },
                    'world_size': {
                        'type': 'integer',
                        'min': 1,
                        'default': 1
                    },
                    'local_rank': {
                        'type': 'integer',
                        'min': 0,
                        'default': 0
                    },
                    'data_parallel_size': {
                        'type': 'integer',
                        'min': 1,
                        'default': 1
                    },
                    'horizontal_parallel_size': {
                        'type': 'integer',
                        'min': 1,
                        'default': 1
                    },
                    'pipeline_parallel' : {
                        'type': 'dict',
                        'default': {},
                        'required': False,
                        'schema': {
                            'pipeline_parallel_size': {
                                'type': 'integer',
                                'min': 1,
                                'default': 1
                            },
                            'num_pipeline_micro_batches': {
                                'type': 'integer',
                                'min': 1,
                                'default': 1
                            },
                            'pipeline_cut_info_string': {
                                'type': 'string',
                                'default': ''
                            },
                            'sliced_schema': {
                                'type': 'dict',
                                'default': {},
                                'keysrules': {'type': 'string'},
                                'valuesrules': {
                                    'type': 'list',
                                    'schema': {'type': 'integer'}
                                }
                            },
                            'sliced_axes': {
                                'type': 'dict',
                                'default': {},
                                'keysrules': {'type': 'string'},
                                'valuesrules': {'type': 'integer'}
                            },
                            'sliced_tensor_names': {
                                'type': 'list',
                                'schema': {'type': 'string'},
                                'default': []
                            }
                        }
                    },
                    'allreduce_post_accumulation': {
                        'type': 'boolean',
                        'default': False
                    },
                    'deepspeed_zero_optimization': {
                        'type': 'dict',
                        'default': {},
                        'required': False,
                        'schema': {
                            'stage': {
                                'type': 'integer',
                                'min': 0,
                                'max': 1,
                                'default': 0
                            },
                        }
                    },
                    'enable_adasum': {
                        'type': 'boolean',
                        'default': False
                    }
                }
            },
            'lr_scheduler' : {
                'type' : 'optim.lr_scheduler',
                'nullable' : True,
                'default' : None
            },
            'mixed_precision' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'enabled' : {
                        'type' : 'boolean',
                        'default' : False
                    },
                    'loss_scaler' : {
                        'type' : 'amp.loss_scaler',
                        'nullable' : True,
                        'default' : None
                    }
                }
            },
            'graph_transformer': {
                'type': 'dict',
                'required': False,
                'default': {},
                'schema': {
                    'attn_dropout_recompute': {
                        'type': 'boolean',
                        'default': False
                    },
                    'gelu_recompute': {
                        'type': 'boolean',
                        'default': False
                    },
                    'transformer_layer_recompute': {
                        'type': 'boolean',
                        'default': False
                    },
                    'number_recompute_layers': {
                        'type': 'integer',
                        'min': 0,
                        'default': 0
                    },
                    'propagate_cast_ops_config': {
                        'type': 'dict',
                        'required': False,
                        'default': {},
                        'schema': {
                            'propagate_cast_ops_strategy': {
                                'type': 'onnxruntime.training.PropagateCastOpsStrategy',
                                'default': PropagateCastOpsStrategy.FLOOD_FILL
                            },
                            'propagate_cast_ops_level': {
                                'type': 'integer',
                                'default': 1
                            },
                            'propagate_cast_ops_allow': {
                                'type': 'list',
                                'schema': {'type': 'string'},
                                'default': []
                            }
                        }
                    },
                    'allow_layer_norm_mod_precision': {
                        'type': 'boolean',
                        'default': False
                    }
                }
            },
            'utils' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'frozen_weights' : {
                        'type' : 'list',
                        'default' : []
                    },
                    'grad_norm_clip' : {
                        'type' : 'boolean',
                        'default' : True
                    },
                    'memory_efficient_gradient' : {
                        'type' : 'boolean',
                        'default' : False
                    },
                    'run_symbolic_shape_infer' : {
                        'type' : 'boolean',
                        'default' : False
                    }
                }
            },
            'debug' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'deterministic_compute' : {
                        'type' : 'boolean',
                        'default' : False
                    },
                    'check_model_export' : {
                        'type' : 'boolean',
                        'default' : False
                    },
                    'graph_save_paths' : {
                        'type' : 'dict',
                        'default': {},
                        'required': False,
                        'schema': {
                            'model_after_graph_transforms_path': {
                                'type': 'string',
                                'default': ''
                            },
                            'model_with_gradient_graph_path': {
                                'type': 'string',
                                'default': ''
                            },
                            'model_with_training_graph_path': {
                                'type': 'string',
                                'default': ''
                            },
                            'model_with_training_graph_after_optimization_path': {
                                'type': 'string',
                                'default': ''
                            }
                        }
                    }
                }
            },
            '_internal_use' : {
                'type' : 'dict',
                'required': False,
                'default' : {},
                'schema' : {
                    'enable_internal_postprocess' : {
                        'type' : 'boolean',
                        'default' : True
                    },
                    'extra_postprocess' : {
                        'type' : 'callable',
                        'nullable' : True,
                        'default' : None
                    },
                    'onnx_opset_version': {
                        'type': 'integer',
                        'min' : 12,
                        'max' : 14,
                        'default': 14
                    },
                    'enable_onnx_contrib_ops' : {
                        'type' : 'boolean',
                        'default' : True
                    }
                }
            },
            'provider_options': {
                'type': 'dict',
                'default': {},
                'required': False,
                'schema': {}
            },
            'session_options': {
                'type': 'SessionOptions',
                'nullable': True,
                'default': None
            }
        }
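Because every option group and every leaf option in the schema carries a 'default', the effective configuration for an empty options dict can be derived mechanically. The sketch below does this for a schema of the shape shown above. schema_defaults is a hypothetical helper, not part of onnxruntime, and it ignores validation rules such as 'min', 'max', and 'nullable'. The rule names used in the schema ('keysrules', 'valuesrules', 'nullable') are the same ones used by the Cerberus validation library.

```python
def schema_defaults(schema):
    """Return the default options implied by a nested schema mapping."""
    defaults = {}
    for key, rule in schema.items():
        if rule.get("type") == "dict" and rule.get("schema"):
            # Nested option group: recurse so each leaf contributes its default.
            defaults[key] = schema_defaults(rule["schema"])
        else:
            # Leaf option, or a free-form dict such as 'sliced_axes': use its default.
            defaults[key] = rule.get("default")
    return defaults


# A small excerpt of the schema above, used for demonstration.
excerpt = {
    'batch': {
        'type': 'dict',
        'required': False,
        'default': {},
        'schema': {
            'gradient_accumulation_steps': {'type': 'integer', 'min': 1, 'default': 1}
        },
    },
    'lr_scheduler': {'type': 'optim.lr_scheduler', 'nullable': True, 'default': None},
    'provider_options': {'type': 'dict', 'default': {}, 'required': False, 'schema': {}},
}

print(schema_defaults(excerpt))
# {'batch': {'gradient_accumulation_steps': 1}, 'lr_scheduler': None, 'provider_options': {}}
```

Groups with an empty 'schema' (such as provider_options) fall back to their own 'default' rather than being recursed into, which keeps free-form dicts empty by default.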
Keyword Arguments:
  • batch (dict) – batch related settings

  • batch.gradient_accumulation_steps (int, default is 1) – number of steps over which gradients are accumulated before collective gradient reduction is performed

  • device (dict) – compute device related settings

  • device.id (string, default is 'cuda') – device on which to run training

  • device.mem_limit (int, default is 0) – maximum memory size (in bytes) used by device.id

  • distributed (dict) – distributed training options.

  • distributed.world_rank (int, default is 0) – rank ID used for data/horizontal parallelism

  • distributed.world_size (int, default is 1) – number of ranks participating in parallelism

  • distributed.data_parallel_size (int, default is 1) – number of ranks participating in data parallelism

  • distributed.horizontal_parallel_size (int, default is 1) – number of ranks participating in horizontal parallelism

  • distributed.pipeline_parallel (dict) – options relevant only to pipeline parallelism.

  • distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1) – number of ranks participating in pipeline parallelism

  • distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1) – number of micro-batches. The input batch is divided into micro-batches, and the graph is run on each of them.

  • distributed.pipeline_parallel.pipeline_cut_info_string (string, default is '') – string of cut IDs used to partition the model for pipeline parallelism.

  • distributed.allreduce_post_accumulation (bool, default is False) – True enables overlapping AllReduce with computation, while False postpones AllReduce until all gradients are ready

  • distributed.deepspeed_zero_optimization (dict) – DeepSpeed ZeRO options.

  • distributed.deepspeed_zero_optimization.stage (int, default is 0) – selects which stage of DeepSpeed ZeRO to use. Stage 0 means disabled; the schema accepts the values 0 and 1.

  • distributed.enable_adasum (bool, default is False) – enable Adasum algorithm for AllReduce

  • lr_scheduler (optim._LRScheduler, default is None) – specifies the learning rate scheduler

  • mixed_precision (dict) – mixed precision training options

  • mixed_precision.enabled (bool, default is False) – enable mixed precision (fp16)

  • mixed_precision.loss_scaler (amp.LossScaler, default is None) – specifies the loss scaler to be used for fp16. If not specified, DynamicLossScaler is used with default values. Users can also instantiate DynamicLossScaler and override its parameters. Lastly, a completely new implementation can be provided by extending the LossScaler class from scratch.

  • graph_transformer (dict) – graph transformer related configurations

  • graph_transformer.attn_dropout_recompute (bool, default is False) – enable recomputing attention dropout to save memory

  • graph_transformer.gelu_recompute (bool, default is False) – enable recomputing Gelu activation output to save memory

  • graph_transformer.transformer_layer_recompute (bool, default is False) – enable layer-wise recomputation of transformer layers to save memory

  • graph_transformer.number_recompute_layers (int, default is 0) – number of layers to which transformer_layer_recompute is applied; by default, recompute is applied to all layers except the last one

  • graph_transformer.propagate_cast_ops_config (dict) – cast propagation options

  • graph_transformer.propagate_cast_ops_config.strategy (PropagateCastOpsStrategy, default is FLOOD_FILL) – the cast propagation optimization strategy: NONE, INSERT_AND_REDUCE, or FLOOD_FILL. NONE does not perform any cast propagation transformation on the graph, although other optimizations may still change cast operations locally; for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFusion optimization could interchange Transpose and Cast if a Cast node exists between Transpose and MatMul. INSERT_AND_REDUCE inserts and reduces cast operations around the nodes with allowed opcodes. FLOOD_FILL expands float16 regions in the graph using the allowed opcodes and, unlike INSERT_AND_REDUCE, does not touch opcodes outside the expanded float16 regions.

  • graph_transformer.propagate_cast_ops_config.level (int, default is 1) – Cast operations are moved only if propagate_cast_ops_level is non-negative. If it is positive, a predetermined list of opcodes considered safe to move before/after cast operations is used; otherwise, propagate_cast_ops_allow is used.

  • graph_transformer.propagate_cast_ops_config.allow (list of str, default is []) – list of opcodes considered safe to move before/after cast operations when propagate_cast_ops_level is zero.

  • utils (dict) – miscellaneous options

  • utils.frozen_weights (list of str, default is []) – list of model parameter names to skip during training (their weights don’t change)

  • utils.grad_norm_clip (bool, default is True) – enables gradient norm clipping for ‘AdamOptimizer’ and ‘LambOptimizer’

  • utils.memory_efficient_gradient (bool, default is False) – enables use of the memory-aware gradient builder.

  • utils.run_symbolic_shape_infer (bool, default is False) – runs symbolic shape inference on the model

  • debug (dict) – debug options

  • debug.deterministic_compute (bool, default is False) – forces compute to be deterministic across runs

  • debug.check_model_export (bool, default is False) – compares PyTorch model outputs with ONNX model outputs in inference before the first train step to ensure successful model export

  • debug.graph_save_paths (dict) – paths used for dumping ONNX graphs for debugging purposes

  • debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") – path to export the ONNX graph after training-related graph transforms have been applied. No output when it is empty.

  • debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") – path to export the ONNX graph with the gradient graph added. No output when it is empty.

  • debug.graph_save_paths.model_with_training_graph_path (str, default is "") – path to export the training ONNX graph with forward, gradient and optimizer nodes. No output when it is empty.

  • debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") – path to export the optimized training ONNX graph. No output when it is empty.

  • _internal_use (dict) – internal options, possibly undocumented, that might be removed without notice

  • _internal_use.enable_internal_postprocess (bool, default is True) – enable internal post-processing of the ONNX model

  • _internal_use.extra_postprocess (callable, default is None) – a functor that post-processes the ONNX model and returns a new ONNX model. It does not override _internal_use.enable_internal_postprocess but complements it.

  • _internal_use.onnx_opset_version (int, default is 14) – ONNX opset version used during model exporting.

  • _internal_use.enable_onnx_contrib_ops (bool, default is True) – enable PyTorch to export nodes as contrib ops in ONNX. This flag may be removed anytime in the future.

  • session_options (onnxruntime.SessionOptions, default is None) – the SessionOptions instance that TrainingSession will use.

  • provider_options (dict) – provider options for customized execution providers (EPs): a dict mapping each EP name to a dict of key-value pairs, for example {'EP1': {'key1': 'val1'}, ...}
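The mixed_precision.loss_scaler option above falls back to DynamicLossScaler when left unset. The general idea of dynamic loss scaling is sketched below in plain Python: when gradients overflow, the scale is reduced and the step is skipped; after a run of steps with finite gradients, the scale is increased again. The class SimpleDynamicLossScaler, its parameter names, and its default values are assumptions for illustration only; consult amp.DynamicLossScaler for its actual interface and behavior.

```python
import math


class SimpleDynamicLossScaler:
    """Hypothetical dynamic loss scaler (not ORT's amp.DynamicLossScaler).

    Halves the scale when gradients overflow and doubles it after
    `up_scale_window` consecutive steps with finite gradients.
    """

    def __init__(self, loss_scale=float(1 << 16), up_scale_window=2000,
                 min_loss_scale=1.0, max_loss_scale=float(1 << 24)):
        self.loss_scale = loss_scale
        self.up_scale_window = up_scale_window
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._good_steps = 0

    def update(self, grads):
        """Adjust the scale; return True if the optimizer step should be applied."""
        all_finite = all(math.isfinite(g) for g in grads)
        if not all_finite:
            # Overflow in fp16: shrink the scale and skip this step.
            self.loss_scale = max(self.loss_scale / 2.0, self.min_loss_scale)
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.up_scale_window:
            # Long run of finite gradients: try a larger scale.
            self.loss_scale = min(self.loss_scale * 2.0, self.max_loss_scale)
            self._good_steps = 0
        return True


scaler = SimpleDynamicLossScaler(loss_scale=8.0, up_scale_window=2)
print(scaler.update([1.0, float("inf")]), scaler.loss_scale)  # False 4.0
print(scaler.update([0.5]), scaler.update([0.25]), scaler.loss_scale)  # True True 8.0
```

Clamping between a minimum and maximum scale keeps repeated overflows or long stable runs from driving the scale to zero or to infinity.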

Example:

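from onnxruntime.training import ORTTrainerOptions, amp, optim

opts = ORTTrainerOptions({
    'batch' : {
        'gradient_accumulation_steps' : 128
    },
    'device' : {
        'id' : 'cuda:0',
        'mem_limit' : 2*1024*1024*1024,  # 2 GiB
    },
    # the warmup schedulers take the total number of training steps; 1000 is a placeholder
    'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(total_steps=1000),
    'mixed_precision' : {
        'enabled': True,
        # DynamicLossScaler is the concrete scaler; LossScaler is the base class to extend
        'loss_scaler': amp.DynamicLossScaler(loss_scale=float(1 << 16))
    }
})
fp16_enabled = opts.mixed_precision.enabled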

Internal ORTTrainer

class onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=None, options=None)

Bases: object

PyTorch frontend for ONNX Runtime training

Entry point that exposes the C++ backend of ORT through a PyTorch frontend.

Parameters:
  • model (torch.nn.Module or onnx.ModelProto) – either a PyTorch or ONNX model. When a PyTorch model and loss_fn are specified, model and loss_fn are combined. When an ONNX model is provided, the loss is identified by the flag is_loss=True in one of the model_desc.outputs entries.

  • model_desc (dict) – model input and output description. This is used to identify inputs and outputs and their shapes, so that ORT can generate the back propagation graph, plan memory allocation for training, and perform optimizations. model_desc must be consistent with the training model and have the following (dict) schema: { 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}. name is a string representing the name of an input or output of the model. For model_desc['inputs'] entries, name must match the input names of the original PyTorch model’s torch.nn.Module.forward() method. For ONNX models, both the names and the order of the inputs must match. For model_desc['outputs'] entries, the order must match the original PyTorch model’s outputs as returned by its torch.nn.Module.forward() method. For ONNX models, both the names and the order of the outputs must match. shape is a list of strings or integers that describes the shape of the input/output. Each dimension size can be either a string or an int: a string means the dimension size is dynamic, while an integer means the dimension is static. An empty list implies a scalar. Lastly, is_loss is a boolean (default is False) that flags whether this output is considered a loss. The ORT backend needs to know which output is the loss in order to generate the back propagation graph. A loss output must be specified either when loss_fn is specified or when the loss is embedded in the model. Note that only one loss output is supported per model.

  • optim_config (optim._OptimizerConfig) – optimizer config. One of optim.AdamConfig, optim.LambConfig or optim.SGDConfig.

  • loss_fn (callable, default is None) – a PyTorch loss function. It takes two inputs [prediction, label] and outputs a scalar loss tensor. If provided, loss_fn is combined with the PyTorch model to form a combined PyTorch model. Inputs to the combined PyTorch model are the concatenation of the model’s inputs and loss_fn’s label input. Outputs of the combined PyTorch model are the concatenation of loss_fn’s loss output and the model’s outputs.

  • options (ORTTrainerOptions, default is None) – options for additional features.
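The model_desc rules above (tuple layout, string-or-int dimensions, an empty list for scalars, and a single loss output) can be checked before constructing the trainer. The sketch below is a hypothetical pre-flight helper, not ORT's own validation; the name validate_model_desc is invented, and it treats a missing loss output as an error because a loss output is required for training and only one is supported per model.

```python
def validate_model_desc(model_desc):
    """Check a model_desc dict against the rules described for ORTTrainer.

    Hypothetical helper: returns the name of the single loss output or raises ValueError.
    """
    if set(model_desc) != {"inputs", "outputs"}:
        raise ValueError("model_desc must contain exactly 'inputs' and 'outputs'")

    def check_entry(name, shape):
        if not isinstance(name, str):
            raise ValueError(f"name must be a string, got {name!r}")
        # str = dynamic dimension, int = static dimension, [] = scalar.
        if not isinstance(shape, list) or not all(isinstance(d, (str, int)) for d in shape):
            raise ValueError(f"shape of '{name}' must be a list of str/int")

    for name, shape in model_desc["inputs"]:
        check_entry(name, shape)

    loss_outputs = []
    for name, shape, *rest in model_desc["outputs"]:
        check_entry(name, shape)
        is_loss = rest[0] if rest else False  # is_loss defaults to False
        if is_loss:
            loss_outputs.append(name)

    if len(loss_outputs) != 1:
        raise ValueError(f"expected exactly one loss output, found {len(loss_outputs)}")
    return loss_outputs[0]


desc = {
    "inputs": [("input_ids", ["batch", "max_seq_len_in_batch"]),
               ("next_sentence_label", ["batch", 1])],
    "outputs": [("loss", [], True)],
}
print(validate_model_desc(desc))  # loss
```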

Example

from onnxruntime.training import ORTTrainer, optim

model = ...
loss_fn = ...
model_desc = {
    "inputs": [
        ("input_ids", ["batch", "max_seq_len_in_batch"]),
        ("attention_mask", ["batch", "max_seq_len_in_batch"]),
        ("token_type_ids", ["batch", "max_seq_len_in_batch"]),
        ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]),
        ("next_sentence_label", ["batch", 1])
    ],
    "outputs": [
        ("loss", [], True),
    ],
}
optim_config = optim.LambConfig(params=[{'params': ['model_param0'], 'alpha': 0.8, 'beta': 0.7},
                                        {'params': ['model_param1', 'model_param_2'], 'alpha': 0.0}],
                                alpha=0.9, beta=0.999)
ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn)

Deprecated since version 1.10.

Use ORTModule instead.

eval_step(*args, **kwargs)

Evaluation step method

Parameters:
  • *args – Arbitrary arguments that are used as model input (data only)

  • **kwargs – Arbitrary keyword arguments that are used as model input (data only)

Returns:

ordered list with model outputs as described by ORTTrainer.model_desc

save_as_onnx(path)

Persists the ONNX model to path

The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. The graph includes full information, including inference and training metadata.

Parameters:

path (str) – Full path, including filename, to save the ONNX model in the filesystem

Raises:
  • RuntimeWarning – raised when neither train_step nor eval_step has been called at least once

  • ValueError – raised when path is not a valid path

train_step(*args, **kwargs)

Train step method

After the forward pass, an ordered list with all outputs described in ORTTrainer.model_desc is returned. Additional information relevant to the train step is maintained by ORTTrainer._train_step_info. See TrainStepInfo for details.

Parameters:
  • *args – Arbitrary arguments that are used as model input (data only)

  • **kwargs – Arbitrary keyword arguments that are used as model input (data only)

Returns:

ordered list with model outputs as described by ORTTrainer.model_desc

state_dict(pytorch_format=False)

Returns a dictionary with model, train step info and optionally, optimizer states

The returned dictionary contains the following information:

  • Model and optimizer states

  • Required ORTTrainerOptions settings

  • Distributed training information, such as but not limited to ZeRO

  • Train step info settings

Structure of the returned dictionary:

  • When pytorch_format = False

    schema:
    {
        "model":
        {
            type: dict,
            schema:
            {
                "full_precision":
                {
                    type: dict,
                    schema:
                    {
                        model_weight_name:
                        {
                            type: array
                        }
                    }
                }
            }
        },
        "optimizer":
        {
            type: dict,
            schema:
            {
                model_weight_name:
                {
                    type: dict,
                    schema:
                    {
                        "Moment_1":
                        {
                            type: array
                        },
                        "Moment_2":
                        {
                            type: array
                        },
                        "Update_Count":
                        {
                            type: array,
                            optional: True # present if optimizer is adam, absent otherwise
                        }
                    }
                },
                "shared_optimizer_state":
                {
                    type: dict,
                    optional: True, # present if the optimizer has shared state, absent otherwise.
                    schema:
                    {
                        "step":
                        {
                            type: array,
                        }
                    }
                }
            }
        },
        "trainer_options":
        {
            type: dict,
            schema:
            {
                "mixed_precision":
                {
                    type: bool
                },
                "zero_stage":
                {
                    type: int
                },
                "world_rank":
                {
                    type: int
                },
                "world_size":
                {
                    type: int
                },
                "optimizer_name":
                {
                    type: str
                },
                "data_parallel_size":
                {
                    type: int
                },
                "horizontal_parallel_size":
                {
                    type: int
                }
            }
        },
        "partition_info":
        {
            type: dict,
            optional: True, # present if states partitioned, else absent
            schema:
            {
                model_weight_name:
                {
                    type: dict,
                    schema:
                    {
                        "original_dim":
                        {
                            type: array
                        },
                        "megatron_row_partition":
                        {
                            type: int
                        }
                    }
                }
            }
        },
        "train_step_info":
        {
            type: dict,
            schema:
            {
                "optimization_step":
                {
                    type: int
                },
                "step":
                {
                    type: int
                }
            }
        }
    }

  • When pytorch_format = True

    schema:
    {
        model_weight_name:
        {
            type: tensor
        }
    }
Parameters:

pytorch_format – boolean flag to select either ONNX Runtime or PyTorch state schema

Returns:

A dictionary with ORTTrainer state
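With pytorch_format=False, the per-weight optimizer states and the optional "shared_optimizer_state" entry live side by side in the same "optimizer" mapping, so code that walks the weights has to skip that key. The helpers below are a hypothetical sketch over plain dicts (lists stand in for arrays). pytorch_style_weights assumes that the PyTorch-format result corresponds to the full-precision model weights, which the two schemas suggest but do not state.

```python
def pytorch_style_weights(ort_state):
    """Map each model weight name to its full-precision value."""
    return dict(ort_state["model"]["full_precision"])


def per_weight_optimizer_states(ort_state):
    """Return {weight_name: states}, excluding the optimizer-wide entry."""
    optimizer = ort_state.get("optimizer", {})
    return {
        name: states
        for name, states in optimizer.items()
        if name != "shared_optimizer_state"  # holds e.g. "step", not a weight
    }


# Toy state dict following the pytorch_format=False schema.
state = {
    "model": {"full_precision": {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}},
    "optimizer": {
        "fc.weight": {"Moment_1": [0.0, 0.0], "Moment_2": [0.0, 0.0]},
        "fc.bias": {"Moment_1": [0.0], "Moment_2": [0.0]},
        "shared_optimizer_state": {"step": [10]},
    },
}

print(sorted(per_weight_optimizer_states(state)))  # ['fc.bias', 'fc.weight']
```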

load_state_dict(state_dict, strict=True)

Loads state_dict containing model/optimizer states into ORTTrainer

The state_dict dictionary may contain the following information:

  • Model and optimizer states

  • Required ORTTrainerOptions settings

  • Distributed training information, such as but not limited to ZeRO

Parameters:
  • state_dict – state dictionary containing both model and optimizer states. The structure of this dictionary should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False

  • strict – boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict
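One way to picture strict=True is as a recursive comparison of the nested keys of the provided state_dict against those returned by ORTTrainer.state_dict. The sketch below is hypothetical: the helper names and the choice of ValueError are illustrative, and ORT's actual check may compare different levels of the dictionary or raise a different exception.

```python
def key_mismatches(expected, provided, prefix=""):
    """List keys missing from or unexpected in `provided`, recursing into dicts."""
    problems = [f"missing: {prefix}{k}" for k in expected.keys() - provided.keys()]
    problems += [f"unexpected: {prefix}{k}" for k in provided.keys() - expected.keys()]
    for k in expected.keys() & provided.keys():
        if isinstance(expected[k], dict) and isinstance(provided[k], dict):
            problems += key_mismatches(expected[k], provided[k], f"{prefix}{k}.")
    return sorted(problems)


def check_state_dict_keys(expected, provided, strict=True):
    """Raise if strict and the key sets differ; otherwise return the mismatches."""
    problems = key_mismatches(expected, provided)
    if strict and problems:
        raise ValueError("state_dict keys do not match: " + "; ".join(problems))
    return problems


reference = {"model": {"full_precision": {"fc.weight": [0.1], "fc.bias": [0.0]}}}
incoming = {"model": {"full_precision": {"fc.weight": [0.3]}}}
print(check_state_dict_keys(reference, incoming, strict=False))
# ['missing: model.full_precision.fc.bias']
```

With strict=False the mismatches are only reported, which mirrors the idea of loading whatever matches while tolerating differences.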

save_checkpoint(path, user_dict={}, include_optimizer_states=True)

Persists ORTTrainer state dictionary on disk along with user_dict.

Saves the state_dict along with the user_dict to a file specified by path.

Parameters:
  • path – string representation of a file path, or a Python file-like object. If a file already exists at path, an exception is raised.

  • user_dict – custom data to be saved along with the state_dict. This data will be returned to the user when load_checkpoint is called.

  • include_optimizer_states – boolean flag indicating whether or not to persist the optimizer states. If include_optimizer_states is False, only model states will be loaded on load_checkpoint.

load_checkpoint(*paths, strict=True)

Loads the saved checkpoint state dictionary into the ORTTrainer

Reads the saved checkpoint files specified by paths from disk and loads the state dictionary onto the ORTTrainer. Aggregates the checkpoint files if aggregation is required.

Parameters:
  • paths – one or more files represented as strings where the checkpoint is saved

  • strict – boolean flag to strictly enforce that the saved checkpoint state_dict keys match the keys from ORTTrainer.state_dict

Returns:

dictionary that the user had saved when calling save_checkpoint
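The save/load contract described for save_checkpoint and load_checkpoint (refuse to overwrite an existing file, optionally drop optimizer states, hand user_dict back on load) can be sketched with the standard library. This is a hypothetical illustration only: the function names are invented, it uses pickle, which is not necessarily ORT's checkpoint format, and it handles a single file rather than aggregating several checkpoint files.

```python
import os
import pickle
import tempfile


def save_checkpoint_sketch(path, state_dict, user_dict=None, include_optimizer_states=True):
    """Persist state_dict and user_dict; refuse to overwrite an existing file."""
    state = dict(state_dict)
    if not include_optimizer_states:
        state.pop("optimizer", None)  # keep model states only
    # Mode "xb" raises FileExistsError if the file is already there.
    with open(path, "xb") as f:
        pickle.dump({"state_dict": state, "user_dict": user_dict or {}}, f)


def load_checkpoint_sketch(path):
    """Return (state_dict, user_dict) written by save_checkpoint_sketch."""
    with open(path, "rb") as f:
        payload = pickle.load(f)
    return payload["state_dict"], payload["user_dict"]


ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "checkpoint.pkl")
save_checkpoint_sketch(
    ckpt_path,
    {"model": {"full_precision": {"fc.weight": [0.1]}}, "optimizer": {}},
    user_dict={"epoch": 3},
    include_optimizer_states=False,
)
restored_state, restored_user = load_checkpoint_sketch(ckpt_path)
print(restored_user, "optimizer" in restored_state)  # {'epoch': 3} False
```

Opening with mode "xb" makes the existence check and file creation a single step, avoiding a race between checking for the file and writing it. Since unpickling can execute arbitrary code, a format like this should only be loaded from trusted files.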

Internal ORTTrainer

class onnxruntime.capi.ort_trainer.ORTTrainer(model, loss_fn, model_desc, training_optimizer_name, map_optimizer_attributes, learning_rate_description, device, gradient_accumulation_steps=1, world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], _opset_version=14, _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, use_memory_efficient_gradient=False, run_symbolic_shape_infer=False, enable_adasum=False, optimized_model_filepath='')

Bases: object

train_step(*args, **kwargs)

inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale.

outputs: if fetches is not provided, outputs are the loss and (if in mixed precision mode and finishing gradient accumulation) all_finite. If fetches is provided, outputs contain the outputs requested by fetches.

fetches: names of requested outputs

eval_step(*args, **kwargs)

inputs: model inputs and/or labels.

outputs: if fetches is not provided, outputs are the loss and (if in mixed precision mode and finishing gradient accumulation) all_finite. If fetches is provided, outputs contain the outputs requested by fetches.

fetches: names of requested outputs