ORTTrainer
Schema
- class onnxruntime.capi.ort_trainer.IODescription(name, shape, dtype=None, num_classes=None)
ORTTrainerOptions
- class onnxruntime.training.ORTTrainerOptions(options={})
Settings used by ONNX Runtime training backend
The parameters are hierarchically organized into semantic groups to simplify configuration. Each group covers a feature area, such as distributed training.
Input validation is performed on the input dict during instantiation to ensure that only supported parameters and values are passed in. Invalid input results in a ValueError exception with details about the error.
- Parameters:
Supported schema for kwargs:
schema = {
    'batch' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'gradient_accumulation_steps' : { 'type' : 'integer', 'min' : 1, 'default' : 1 }
        },
    },
    'device' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'id' : { 'type' : 'string', 'default' : 'cuda' },
            'mem_limit' : { 'type' : 'integer', 'min' : 0, 'default' : 0 }
        }
    },
    'distributed': {
        'type': 'dict', 'default': {}, 'required': False,
        'schema': {
            'world_rank': { 'type': 'integer', 'min': 0, 'default': 0 },
            'world_size': { 'type': 'integer', 'min': 1, 'default': 1 },
            'local_rank': { 'type': 'integer', 'min': 0, 'default': 0 },
            'data_parallel_size': { 'type': 'integer', 'min': 1, 'default': 1 },
            'horizontal_parallel_size': { 'type': 'integer', 'min': 1, 'default': 1 },
            'pipeline_parallel' : {
                'type': 'dict', 'default': {}, 'required': False,
                'schema': {
                    'pipeline_parallel_size': { 'type': 'integer', 'min': 1, 'default': 1 },
                    'num_pipeline_micro_batches': { 'type': 'integer', 'min': 1, 'default': 1 },
                    'pipeline_cut_info_string': { 'type': 'string', 'default': '' },
                    'sliced_schema': {
                        'type': 'dict', 'default': {},
                        'keysrules': {'type': 'string'},
                        'valuesrules': { 'type': 'list', 'schema': {'type': 'integer'} }
                    },
                    'sliced_axes': {
                        'type': 'dict', 'default': {},
                        'keysrules': {'type': 'string'},
                        'valuesrules': {'type': 'integer'}
                    },
                    'sliced_tensor_names': { 'type': 'list', 'schema': {'type': 'string'}, 'default': [] }
                }
            },
            'allreduce_post_accumulation': { 'type': 'boolean', 'default': False },
            'deepspeed_zero_optimization': {
                'type': 'dict', 'default': {}, 'required': False,
                'schema': {
                    'stage': { 'type': 'integer', 'min': 0, 'max': 1, 'default': 0 },
                }
            },
            'enable_adasum': { 'type': 'boolean', 'default': False }
        }
    },
    'lr_scheduler' : { 'type' : 'optim.lr_scheduler', 'nullable' : True, 'default' : None },
    'mixed_precision' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'enabled' : { 'type' : 'boolean', 'default' : False },
            'loss_scaler' : { 'type' : 'amp.loss_scaler', 'nullable' : True, 'default' : None }
        }
    },
    'graph_transformer': {
        'type': 'dict', 'required': False, 'default': {},
        'schema': {
            'attn_dropout_recompute': { 'type': 'boolean', 'default': False },
            'gelu_recompute': { 'type': 'boolean', 'default': False },
            'transformer_layer_recompute': { 'type': 'boolean', 'default': False },
            'number_recompute_layers': { 'type': 'integer', 'min': 0, 'default': 0 },
            'propagate_cast_ops_config': {
                'type': 'dict', 'required': False, 'default': {},
                'schema': {
                    'propagate_cast_ops_strategy': { 'type': 'onnxruntime.training.PropagateCastOpsStrategy', 'default': PropagateCastOpsStrategy.FLOOD_FILL },
                    'propagate_cast_ops_level': { 'type': 'integer', 'default': 1 },
                    'propagate_cast_ops_allow': { 'type': 'list', 'schema': {'type': 'string'}, 'default': [] }
                }
            },
            'allow_layer_norm_mod_precision': { 'type': 'boolean', 'default': False }
        }
    },
    'utils' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'frozen_weights' : { 'type' : 'list', 'default' : [] },
            'grad_norm_clip' : { 'type' : 'boolean', 'default' : True },
            'memory_efficient_gradient' : { 'type' : 'boolean', 'default' : False },
            'run_symbolic_shape_infer' : { 'type' : 'boolean', 'default' : False }
        }
    },
    'debug' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'deterministic_compute' : { 'type' : 'boolean', 'default' : False },
            'check_model_export' : { 'type' : 'boolean', 'default' : False },
            'graph_save_paths' : {
                'type' : 'dict', 'default': {}, 'required': False,
                'schema': {
                    'model_after_graph_transforms_path': { 'type': 'string', 'default': '' },
                    'model_with_gradient_graph_path': { 'type': 'string', 'default': '' },
                    'model_with_training_graph_path': { 'type': 'string', 'default': '' },
                    'model_with_training_graph_after_optimization_path': { 'type': 'string', 'default': '' },
                }
            },
        }
    },
    '_internal_use' : {
        'type' : 'dict', 'required': False, 'default' : {},
        'schema' : {
            'enable_internal_postprocess' : { 'type' : 'boolean', 'default' : True },
            'extra_postprocess' : { 'type' : 'callable', 'nullable' : True, 'default' : None },
            'onnx_opset_version': { 'type': 'integer', 'min' : 12, 'max' : 14, 'default': 14 },
            'enable_onnx_contrib_ops' : { 'type' : 'boolean', 'default' : True }
        },
    },
    'provider_options': { 'type': 'dict', 'default': {}, 'required': False, 'schema': {} },
    'session_options': { 'type': 'SessionOptions', 'nullable': True, 'default': None },
}
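The validation and default-filling behaviour described above can be illustrated with a small recursive validator. This is a hypothetical sketch, not ONNX Runtime's implementation. The rule names in the schema (for example 'keysrules' and 'valuesrules') appear to follow the conventions of the Cerberus validation library, but the sketch handles only the 'type', 'min', 'max', 'default' and nested 'schema' rules. The names SUBSET and validate are invented for illustration.

```python
# Hypothetical sketch: a tiny validator for a subset of the schema rules above.
# It is NOT the ONNX Runtime implementation; it only illustrates defaults and ValueError.

_TYPES = {'dict': dict, 'integer': int, 'string': str, 'boolean': bool}

# A small excerpt of the schema above, used for illustration only.
SUBSET = {
    'batch': {
        'type': 'dict', 'default': {},
        'schema': {
            'gradient_accumulation_steps': {'type': 'integer', 'min': 1, 'default': 1},
        },
    },
    'device': {
        'type': 'dict', 'default': {},
        'schema': {
            'id': {'type': 'string', 'default': 'cuda'},
            'mem_limit': {'type': 'integer', 'min': 0, 'default': 0},
        },
    },
}


def validate(options, schema, path=''):
    """Return a new dict with defaults filled in; raise ValueError on invalid input."""
    unknown = set(options) - set(schema)
    if unknown:
        raise ValueError(f'unknown option(s) at {path or "<root>"}: {sorted(unknown)}')
    result = {}
    for key, rule in schema.items():
        name = f'{path}.{key}' if path else key
        value = options.get(key, rule.get('default'))
        expected = _TYPES[rule['type']]  # only the types listed in _TYPES are supported
        # bool is a subclass of int in Python, so reject it explicitly for 'integer'.
        if not isinstance(value, expected) or (expected is int and isinstance(value, bool)):
            raise ValueError(f'{name} must be of type {rule["type"]}, got {value!r}')
        if 'min' in rule and value < rule['min']:
            raise ValueError(f'{name} must be >= {rule["min"]}, got {value}')
        if 'max' in rule and value > rule['max']:
            raise ValueError(f'{name} must be <= {rule["max"]}, got {value}')
        if rule['type'] == 'dict' and 'schema' in rule:
            value = validate(value, rule['schema'], name)  # recurse into the nested group
        result[key] = value
    return result
```

For example, validate({'batch': {'gradient_accumulation_steps': 4}}, SUBSET) fills in the whole 'device' group from its defaults, while a value of 0 for gradient_accumulation_steps raises ValueError because it violates 'min': 1.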
- Keyword Arguments:
batch (dict) – batch related settings
batch.gradient_accumulation_steps (int, default is 1) – number of steps to accumulate before doing collective gradient reduction
device (dict) – compute device related settings
device.id (string, default is 'cuda') – device to run training on
device.mem_limit (int, default is 0) – maximum memory size (in bytes) used by device.id
distributed (dict) – distributed training options.
distributed.world_rank (int, default is 0) – rank ID used for data/horizontal parallelism
distributed.world_size (int, default is 1) – number of ranks participating in parallelism
distributed.data_parallel_size (int, default is 1) – number of ranks participating in data parallelism
distributed.horizontal_parallel_size (int, default is 1) – number of ranks participating in horizontal parallelism
distributed.pipeline_parallel (dict) – options that only apply to pipeline parallelism
distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1) – number of ranks participating in pipeline parallelism
distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1) – number of micro-batches. The input batch is divided into micro-batches, and the graph is run on each of them.
distributed.pipeline_parallel.pipeline_cut_info_string (string, default is '') – string of cut IDs used to partition the pipeline
distributed.allreduce_post_accumulation (bool, default is False) – True enables overlapping AllReduce with computation, while False postpones AllReduce until all gradients are ready
distributed.deepspeed_zero_optimization (dict) – DeepSpeed ZeRO options.
distributed.deepspeed_zero_optimization.stage (int, default is 0) – select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled.
distributed.enable_adasum (bool, default is False) – enable Adasum algorithm for AllReduce
lr_scheduler (optim._LRScheduler, default is None) – specifies the learning rate scheduler
mixed_precision (dict) – mixed precision training options
mixed_precision.enabled (bool, default is False) – enable mixed precision (fp16)
mixed_precision.loss_scaler (amp.LossScaler, default is None) – specifies a loss scaler to be used for fp16. If not specified, DynamicLossScaler is used with default values. Users can also instantiate DynamicLossScaler and override its parameters. Lastly, a completely new implementation can be specified by extending the LossScaler class from scratch.
graph_transformer (dict) – graph transformer related configurations
graph_transformer.attn_dropout_recompute (bool, default is False) – enable recomputing attention dropout to save memory
graph_transformer.gelu_recompute (bool, default is False) – enable recomputing Gelu activation output to save memory
graph_transformer.transformer_layer_recompute (bool, default is False) – enable layer-wise recomputation of transformer layers to save memory
graph_transformer.number_recompute_layers (int, default is 0) – number of layers to which transformer_layer_recompute is applied; by default (0), recompute is applied to all layers except the last one
graph_transformer.propagate_cast_ops_config (dict) – cast propagation options
- graph_transformer.propagate_cast_ops_config.strategy (PropagateCastOpsStrategy, default is FLOOD_FILL)
Specifies the cast propagation optimization strategy: NONE, INSERT_AND_REDUCE or FLOOD_FILL. The NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations may still change cast operations locally. For example, to fuse Transpose and MatMul nodes, the Transpose-MatMul fusion optimization could interchange Transpose and Cast if a Cast node exists between Transpose and MatMul. The INSERT_AND_REDUCE strategy inserts and reduces cast operations around nodes with allowed opcodes. The FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes and, unlike INSERT_AND_REDUCE, does not touch opcodes outside the expanded float16 regions.
- graph_transformer.propagate_cast_ops_config.level (int, default is 1)
Cast operations are moved only if propagate_cast_ops_level is non-negative. If the level is positive, a predetermined list of opcodes considered safe to move before/after a Cast operation is used; if it is zero, propagate_cast_ops_allow is used instead.
- graph_transformer.propagate_cast_ops_config.allow (list of str, default is [])
List of opcodes considered safe to move before/after a Cast operation when propagate_cast_ops_level is zero.
utils (dict) – miscellaneous options
utils.frozen_weights (list of str, default is []) – list of model parameter names to skip during training (their weights do not change)
utils.grad_norm_clip (bool, default is True) – enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
utils.memory_efficient_gradient (bool, default is False) – enables use of the memory-aware gradient builder.
utils.run_symbolic_shape_infer (bool, default is False) – runs symbolic shape inference on the model
debug (dict) – debug options
debug.deterministic_compute (bool, default is False) – forces compute to be deterministic across runs
debug.check_model_export (bool, default is False) – compares PyTorch model outputs with ONNX model outputs in inference before the first train step to ensure successful model export
debug.graph_save_paths (dict) – paths used for dumping ONNX graphs for debugging purposes
debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") – path to export the ONNX graph after training-related graph transforms have been applied. No output when it is empty.
debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") – path to export the ONNX graph with the gradient graph added. No output when it is empty.
debug.graph_save_paths.model_with_training_graph_path (str, default is "") – path to export the training ONNX graph with forward, gradient and optimizer nodes. No output when it is empty.
debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") – path to export the optimized training graph. No output when it is empty.
_internal_use (dict) – internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True) – enable internal post-processing of the ONNX model
_internal_use.extra_postprocess (callable, default is None) – a functor to post-process the ONNX model and return a new ONNX model. It does not override _internal_use.enable_internal_postprocess, but complements it.
_internal_use.onnx_opset_version (int, default is 14) – ONNX opset version used during model exporting.
_internal_use.enable_onnx_contrib_ops (bool, default is True) – enable PyTorch to export nodes as contrib ops in ONNX. This flag may be removed anytime in the future.
session_options (onnxruntime.SessionOptions) – The SessionOptions instance that TrainingSession will use.
provider_options (dict) – the provider_options for customized execution providers. It is a dict mapping an execution provider (EP) name to a dict of key-value pairs, like {'EP1' : {'key1' : 'val1'}, ...}
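The level/allow rules of graph_transformer.propagate_cast_ops_config can be summarized as a small decision function. This is a hypothetical sketch: opcodes_safe_to_move is an invented name, and PREDETERMINED_OPCODES is a placeholder, since the actual predetermined opcode list is internal to ONNX Runtime and is not given in this documentation.

```python
# Hypothetical sketch of the propagate_cast_ops level/allow rules.
# PREDETERMINED_OPCODES is a made-up placeholder, not ONNX Runtime's real list.
PREDETERMINED_OPCODES = frozenset({'Add', 'Mul', 'Transpose'})  # illustrative only


def opcodes_safe_to_move(level, allow):
    """Return the opcodes a Cast may be moved across, or None if propagation is off."""
    if level < 0:
        return None  # negative level: Cast operations are not moved
    if level > 0:
        return PREDETERMINED_OPCODES  # positive level: use the predetermined list
    return frozenset(allow)  # level == 0: use the user-supplied allow list
```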
Example:
    from onnxruntime.training import ORTTrainerOptions, amp, optim

    opts = ORTTrainerOptions({
        'batch' : {
            'gradient_accumulation_steps' : 128
        },
        'device' : {
            'id' : 'cuda:0',
            'mem_limit' : 2*1024*1024*1024,
        },
        # total_steps is required by LinearWarmupLRScheduler; 1000 is an illustrative value
        'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(total_steps=1000),
        'mixed_precision' : {
            'enabled': True,
            'loss_scaler': amp.DynamicLossScaler(loss_scale=float(1 << 16))
        }
    })
    fp16_enabled = opts.mixed_precision.enabled
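When mixed_precision.loss_scaler is not specified, a DynamicLossScaler is used. The class below is a hypothetical, framework-free sketch of the general dynamic loss-scaling technique: shrink the scale when the scaled gradients overflow, and grow it after a window of overflow-free steps. It is not ONNX Runtime's amp.DynamicLossScaler; the class name, parameter names and defaults are assumptions, and gradients are modelled as plain Python floats so the example runs without PyTorch.

```python
import math


class SimpleDynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling; not onnxruntime.training.amp.DynamicLossScaler."""

    def __init__(self, loss_scale=float(1 << 16), up_scale_window=2000,
                 min_loss_scale=1.0, max_loss_scale=float(1 << 24)):
        self.loss_scale = loss_scale
        self.up_scale_window = up_scale_window
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._good_steps = 0  # consecutive steps whose gradients were all finite

    def update(self, gradients):
        """Adjust the scale from one step's scaled gradients; return True if the step is usable."""
        all_finite = all(math.isfinite(g) for g in gradients)
        if not all_finite:
            # Overflow: the step should be skipped and the scale halved (bounded below).
            self.loss_scale = max(self.loss_scale / 2, self.min_loss_scale)
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.up_scale_window:
            # A full window without overflow: try a larger scale (bounded above).
            self.loss_scale = min(self.loss_scale * 2, self.max_loss_scale)
            self._good_steps = 0
        return True
```

In a training loop, the loss would be multiplied by loss_scale before back propagation and the gradients divided by it before the optimizer step; a step for which update() returns False would be skipped.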
Internal ORTTrainer
- class onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=None, options=None)
Bases: object
PyTorch frontend for ONNX Runtime training
Entry point that exposes the C++ backend of ORT as a PyTorch frontend.
- Parameters:
model (torch.nn.Module or onnx.ModelProto) – either a PyTorch or an ONNX model. When a PyTorch model and loss_fn are specified, model and loss_fn are combined. When an ONNX model is provided, the loss is identified by the flag is_loss=True in one of the model_desc.outputs entries.
model_desc (dict) – model input and output description. This is used to identify inputs and outputs and their shapes, so that ORT can generate the back propagation graph, plan memory allocation for training, and perform optimizations. model_desc must be consistent with the training model and have the following (dict) schema: { 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}.
name is a string representing the name of an input or output of the model. For model_desc['inputs'] entries, name must match the input names of the original PyTorch model's torch.nn.Module.forward() method. For ONNX models, both the names and the order of the inputs must match. For model_desc['outputs'] entries, the order must match the original PyTorch model's outputs as returned by its torch.nn.Module.forward() method. For ONNX models, both the names and the order of the outputs must match.
shape is a list of strings or integers that describes the shape of the input/output. Each dimension size can be either a string or an int: a string means the dimension size is dynamic, while an integer means a static dimension. An empty list implies a scalar.
Lastly, is_loss is a boolean (default is False) that flags whether this output is considered a loss. The ORT backend needs to know which output is the loss in order to generate the back propagation graph. The loss output must be specified either when loss_fn is specified or when the loss is embedded in the model. Note that only one loss output is supported per model.
optim_config (optim._OptimizerConfig) – optimizer config. One of optim.AdamConfig, optim.LambConfig or optim.SGDConfig.
loss_fn (callable, default is None) – a PyTorch loss function. It takes two inputs [prediction, label] and outputs a scalar loss tensor. If provided, loss_fn is combined with the PyTorch model to form a combined PyTorch model. Inputs to the combined PyTorch model are the concatenation of the model's inputs and loss_fn's label input. Outputs of the combined PyTorch model are the concatenation of loss_fn's loss output and the model's outputs.
options (ORTTrainerOptions, default is None) – options for additional features.
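The input/output concatenation rule for a combined model and loss_fn can be sketched with plain callables. This is a hypothetical illustration, not how ORTTrainer actually builds the combined PyTorch model: combine is an invented name, plain callables stand in for torch.nn.Module so the example runs without PyTorch, and it assumes the first model output is the prediction passed to loss_fn.

```python
# Hypothetical sketch of combining a model with a loss function so that
# inputs = model inputs + [label] and outputs = [loss] + model outputs.

def combine(model, loss_fn):
    """Return a callable taking (*model_inputs, label) and returning (loss, *model_outputs)."""
    def combined(*inputs):
        *model_inputs, label = inputs  # the label input of loss_fn is appended last
        outputs = model(*model_inputs)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        loss = loss_fn(outputs[0], label)  # assumption: the first model output is the prediction
        return (loss, *outputs)
    return combined
```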
Example
    from onnxruntime.training import ORTTrainer, optim

    model = ...
    loss_fn = ...
    model_desc = {
        "inputs": [
            ("input_ids", ["batch", "max_seq_len_in_batch"]),
            ("attention_mask", ["batch", "max_seq_len_in_batch"]),
            ("token_type_ids", ["batch", "max_seq_len_in_batch"]),
            ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]),
            ("next_sentence_label", ["batch", 1])
        ],
        "outputs": [
            ("loss", [], True),
        ],
    }
    optim_config = optim.LambConfig(param_groups = [
                                        { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7},
                                        { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0}
                                    ],
                                    alpha=0.9, beta=0.999)
    ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn)
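The model_desc rules listed in the parameter description (string dimensions are dynamic, integer dimensions are static, an empty list is a scalar, and only one loss output is supported per model) can be checked with a small helper. check_model_desc and _check_shape are hypothetical names, not part of ONNX Runtime, and the helper assumes a loss output is required, as it is when loss_fn is given or the loss is embedded in the model.

```python
# Hypothetical helper (not part of ONNX Runtime) that checks a model_desc dict
# against the documented rules.

def check_model_desc(model_desc):
    """Return the name of the loss output; raise ValueError if model_desc breaks a rule."""
    for name, shape in model_desc['inputs']:
        _check_shape(name, shape)
    loss_names = []
    for name, shape, *rest in model_desc['outputs']:
        is_loss = rest[0] if rest else False  # is_loss defaults to False
        _check_shape(name, shape)
        if is_loss:
            loss_names.append(name)
    if len(loss_names) != 1:
        raise ValueError(f'expected exactly one loss output, found {loss_names}')
    return loss_names[0]


def _check_shape(name, shape):
    """Each dimension must be a str (dynamic) or an int (static); [] denotes a scalar."""
    if not isinstance(name, str):
        raise ValueError(f'name must be a string, got {name!r}')
    for dim in shape:
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(dim, bool) or not isinstance(dim, (str, int)):
            raise ValueError(f'{name}: dimension {dim!r} must be a str (dynamic) or int (static)')
```

Applied to the model_desc in the example, it returns "loss".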
Deprecated since version 1.10: Use ORTModule instead.
- eval_step(*args, **kwargs)
Evaluation step method
- Parameters:
*args – Arbitrary arguments that are used as model input (data only)
**kwargs – Arbitrary keyword arguments that are used as model input (data only)
- Returns:
ordered list with model outputs as described by ORTTrainer.model_desc
- save_as_onnx(path)
Persists the ONNX model into path
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. The graph includes full information, including inference and training metadata.
- Parameters:
path (str) – Full path, including filename, to save the ONNX model in the filesystem
- Raises:
RuntimeWarning – raised when neither train_step nor eval_step has been called at least once
ValueError – raised when path is not a valid path
- train_step(*args, **kwargs)
Train step method
After the forward pass, an ordered list with all outputs described at ORTTrainer.model_desc is returned. Additional information relevant to the train step is maintained by ORTTrainer._train_step_info. See TrainStepInfo for details.
- Parameters:
*args – Arbitrary arguments that are used as model input (data only)
**kwargs – Arbitrary keyword arguments that are used as model input (data only)
- Returns:
ordered list with model outputs as described by ORTTrainer.model_desc
- state_dict(pytorch_format=False)
Returns a dictionary with model, train step info and optionally, optimizer states
The returned dictionary contains the following information:
Model and optimizer states
Required ORTTrainerOptions settings
Distributed training information, such as but not limited to ZeRO
Train step info settings
Structure of the returned dictionary:
When pytorch_format = False
schema:
{
    "model": {
        type: dict,
        schema: {
            "full_precision": {
                type: dict,
                schema: {
                    model_weight_name: {
                        type: array
                    }
                }
            }
        }
    },
    "optimizer": {
        type: dict,
        schema: {
            model_weight_name: {
                type: dict,
                schema: {
                    "Moment_1": { type: array },
                    "Moment_2": { type: array },
                    "Update_Count": {
                        type: array,
                        optional: True  # present if optimizer is adam, absent otherwise
                    }
                }
            },
            "shared_optimizer_state": {
                type: dict,
                optional: True,  # present if optimizer is shared, absent otherwise
                schema: {
                    "step": { type: array }
                }
            }
        }
    },
    "trainer_options": {
        type: dict,
        schema: {
            "mixed_precision": { type: bool },
            "zero_stage": { type: int },
            "world_rank": { type: int },
            "world_size": { type: int },
            "optimizer_name": { type: str },
            "data_parallel_size": { type: int },
            "horizontal_parallel_size": { type: int }
        }
    },
    "partition_info": {
        type: dict,
        optional: True,  # present if states are partitioned, else absent
        schema: {
            model_weight_name: {
                type: dict,
                schema: {
                    "original_dim": { type: array },
                    "megatron_row_partition": { type: int }
                }
            }
        }
    },
    "train_step_info": {
        type: dict,
        schema: {
            "optimization_step": { type: int },
            "step": { type: int }
        }
    }
}
- When pytorch_format = True
schema: { model_weight_name: { type: tensor } }
- Parameters:
pytorch_format – boolean flag to select either ONNX Runtime or PyTorch state schema
- Returns:
A dictionary with ORTTrainer state
- load_state_dict(state_dict, strict=True)
Loads state_dict containing model/optimizer states into ORTTrainer
The state_dict dictionary may contain the following information:
Model and optimizer states
Required ORTTrainerOptions settings
Distributed training information, such as but not limited to ZeRO
- Parameters:
state_dict – state dictionary containing both model and optimizer states. The structure of this dictionary should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False
strict – boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict
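The strict flag compares the keys of the incoming state_dict with those of ORTTrainer.state_dict. The helper below is a hypothetical sketch of such a comparison over nested dicts; _flatten_keys and check_state_dict_keys are invented names, and raising ValueError on a mismatch is an assumption, not ONNX Runtime's documented behaviour.

```python
# Hypothetical sketch of a strict key comparison between two nested state dicts.

def _flatten_keys(d, prefix=''):
    """Yield dotted key paths for every leaf of a nested dict."""
    for key, value in d.items():
        path = f'{prefix}.{key}' if prefix else str(key)
        if isinstance(value, dict) and value:
            yield from _flatten_keys(value, path)
        else:
            yield path


def check_state_dict_keys(reference, state_dict, strict=True):
    """Return (missing, unexpected) key paths; raise ValueError on any mismatch if strict."""
    ref_keys = set(_flatten_keys(reference))
    new_keys = set(_flatten_keys(state_dict))
    missing = sorted(ref_keys - new_keys)
    unexpected = sorted(new_keys - ref_keys)
    if strict and (missing or unexpected):
        raise ValueError(f'missing keys: {missing}; unexpected keys: {unexpected}')
    return missing, unexpected
```

With strict=False the mismatches are only reported, which mirrors the usual reason for relaxing the check: loading a partial state, such as model weights without optimizer states.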
- save_checkpoint(path, user_dict={}, include_optimizer_states=True)
Persists ORTTrainer state dictionary on disk along with user_dict.
Saves the state_dict along with the user_dict to a file specified by path.
- Parameters:
path – string representation of a file path or a Python file-like object. If a file already exists at path, an exception is raised.
user_dict – custom data to be saved along with the state_dict. This data will be returned to the user when load_checkpoint is called.
include_optimizer_states – boolean flag indicating whether or not to persist the optimizer states. If include_optimizer_states=False, only model states will be loaded by load_checkpoint.
- load_checkpoint(*paths, strict=True)
Loads the saved checkpoint state dictionary into the ORTTrainer
Reads the saved checkpoint files specified by paths from disk and loads the state dictionary onto the ORTTrainer. Aggregates the checkpoint files if aggregation is required.
- Parameters:
paths – one or more files represented as strings where the checkpoint is saved
strict – boolean flag to strictly enforce that the saved checkpoint state_dict keys match the keys from ORTTrainer.state_dict
- Returns:
dictionary that the user had saved when calling save_checkpoint
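The save/load contract described for save_checkpoint and load_checkpoint (refuse to overwrite an existing file, optionally drop optimizer states, and hand user_dict back on load) can be sketched with the standard library. This is a hypothetical sketch: save_checkpoint_sketch and load_checkpoint_sketch are invented names, the pickle-based file layout is an assumption rather than ONNX Runtime's checkpoint format, only file paths (not file-like objects) are supported, and aggregation of multiple checkpoint files is not shown.

```python
import os
import pickle


def save_checkpoint_sketch(path, state_dict, user_dict=None, include_optimizer_states=True):
    """Persist state_dict and user_dict to a new file; refuse to overwrite."""
    if os.path.exists(path):
        raise FileExistsError(f'checkpoint already exists: {path}')
    state = dict(state_dict)  # shallow copy so the caller's dict is not modified
    if not include_optimizer_states:
        state.pop('optimizer', None)  # keep model (and other) states only
    with open(path, 'xb') as f:  # 'x' mode also fails if the file appeared meanwhile
        pickle.dump({'state_dict': state, 'user_dict': user_dict or {}}, f)


def load_checkpoint_sketch(path):
    """Return (state_dict, user_dict) as saved by save_checkpoint_sketch."""
    with open(path, 'rb') as f:
        payload = pickle.load(f)
    return payload['state_dict'], payload['user_dict']
```

Because unpickling can execute arbitrary code, a pickle-based format like this should only be loaded from trusted files.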
Internal ORTTrainer
- class onnxruntime.capi.ort_trainer.ORTTrainer(model, loss_fn, model_desc, training_optimizer_name, map_optimizer_attributes, learning_rate_description, device, gradient_accumulation_steps=1, world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], _opset_version=14, _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, use_memory_efficient_gradient=False, run_symbolic_shape_infer=False, enable_adasum=False, optimized_model_filepath='')
Bases: object
- train_step(*args, **kwargs)
inputs: model inputs, labels, learning rate, and, in mixed precision mode, loss_scale.
outputs: if fetches is not provided, the outputs are the loss and (in mixed precision mode, when finishing gradient accumulation) all_finite. If fetches is provided, the outputs contain the values requested in fetches.
fetches: names of requested outputs
- eval_step(*args, **kwargs)
inputs: model inputs and/or labels.
outputs: if fetches is not provided, the outputs are the loss and (in mixed precision mode, when finishing gradient accumulation) all_finite. If fetches is provided, the outputs contain the values requested in fetches.
fetches: names of requested outputs
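The output-selection rule shared by the legacy train_step and eval_step can be sketched as a small function. This is a hypothetical illustration: select_outputs is an invented name, and representing the session results as a dict keyed by output name (with 'loss' and 'all_finite' keys) is an assumption made so the rule can be shown without ONNX Runtime.

```python
# Hypothetical sketch of the legacy output-selection rule; the dict of named
# outputs and its 'loss'/'all_finite' keys are assumptions for illustration.

def select_outputs(outputs, fetches=None, mixed_precision=False, finishing_accumulation=False):
    """Return the list of values the legacy train_step/eval_step would report."""
    if fetches:
        return [outputs[name] for name in fetches]  # exactly the requested outputs, in order
    selected = [outputs['loss']]
    if mixed_precision and finishing_accumulation:
        selected.append(outputs['all_finite'])  # overflow flag for the loss scaler
    return selected
```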