Source code for onnxruntime.capi.ort_trainer

import io
import os
import warnings
from distutils.version import LooseVersion

import numpy as np
import onnx
import torch
import torch.nn
import torch.onnx
from onnx import helper, numpy_helper

import onnxruntime as ort
import onnxruntime.capi.pt_patch
from import SymbolicShapeInference

from import postprocess
from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files


[docs]class IODescription: def __init__(self, name, shape, dtype=None, num_classes=None): self.name_ = name self.shape_ = shape self.dtype_ = dtype self.num_classes_ = num_classes
[docs]class ModelDescription: def __init__(self, inputs, outputs): self.inputs_ = inputs self.outputs_ = outputs
def resolve_symbolic_dimensions(inputs, input_descs, output_descs): import copy output_descs_copy = copy.deepcopy(output_descs) resolved_dims = {} for input, input_desc in zip(inputs, input_descs): for i, axis in enumerate(input_desc.shape_): if isinstance(axis, str): resolved_dims[axis] = input.size()[i] for output_desc in output_descs_copy: for i, axis in enumerate(output_desc.shape_): if isinstance(axis, str): output_desc.shape_[i] = resolved_dims[axis] if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs): raise RuntimeError("Cannot run model with unknown output dimensions") return output_descs_copy def generate_sample(desc, device=None): # symbolic dimensions are described with strings. set symbolic dimensions to be 1 size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] if desc.num_classes_: return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) else: return torch.randn(size, dtype=desc.dtype_).to(device) def get_device_index(device): if type(device) == str: # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 device = torch.device(device) return 0 if device.index is None else device.index def input_get_device_index(input): if isinstance(input, (list, tuple)): device_index = get_device_index(input[0].device) else: device_index = get_device_index(input.device) return device_index def get_all_gradients_finite_arg_name(session): all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in] if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: raise RuntimeError( "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ from the training session." ) return all_fp16_or_fp32_gradients_finite_node_args[0].name def get_group_accumulated_gradients_output_node_arg_name(session): # TODO: get the constant string via pybind. # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' accumulated_gradients_output_node_args = [ x for x in session._outputs_meta if "Group_Accumulated_Gradients" in ] if len(accumulated_gradients_output_node_args) != 1: raise RuntimeError( "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ from the training session." ) return accumulated_gradients_output_node_args[0].name def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): for input, input_desc in zip(inputs, input_descs): device_index = input_get_device_index(input) iobinding.bind_input( input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype), list(input.size()), input.data_ptr(), ) output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) torch_outputs = {} for output_desc in output_descs_resolved: torch_tensor = torch.zeros( output_desc.shape_, device=device, dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, ) iobinding.bind_output( output_desc.name_, torch_tensor.device.type, get_device_index(device), dtype_torch_to_numpy(torch_tensor.dtype), list(torch_tensor.size()), torch_tensor.data_ptr(), ) torch_outputs[output_desc.name_] = torch_tensor session.run_with_iobinding(iobinding, run_options) return torch_outputs def FuseSofmaxNLLToSoftmaxCE(onnx_model): nll_count = 0 while True: nll_count = nll_count + 1 nll_loss_node = None nll_loss_node_index = 0 for nll_loss_node_index, node in enumerate(onnx_model.graph.node): if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": nll_loss_node = node break if nll_loss_node is None: break softmax_node = None softmax_node_index = 0 label_input_name = None weight_input_name = None for softmax_node_index, node in enumerate(onnx_model.graph.node): if node.op_type == "LogSoftmax": # has to be connected to nll_loss if len(nll_loss_node.input) > 2: weight_input_name = nll_loss_node.input[2] if node.output[0] == nll_loss_node.input[0]: softmax_node = node label_input_name = nll_loss_node.input[1] break elif node.output[0] == nll_loss_node.input[1]: softmax_node = node label_input_name = nll_loss_node.input[0] break else: if softmax_node is not None: break if softmax_node is None: break # delete nll_loss and LogSoftmax nodes in order if nll_loss_node_index < softmax_node_index: del onnx_model.graph.node[softmax_node_index] del onnx_model.graph.node[nll_loss_node_index] else: del onnx_model.graph.node[nll_loss_node_index] del onnx_model.graph.node[softmax_node_index] probability_output_name = softmax_node.output[0] node = onnx_model.graph.node.add() inputs = ( [softmax_node.input[0], label_input_name, weight_input_name] if weight_input_name else [softmax_node.input[0], label_input_name] ) node.CopyFrom( onnx.helper.make_node( "SparseSoftmaxCrossEntropy", inputs, [nll_loss_node.output[0], probability_output_name], "nll_loss_node_" + str(nll_count), ) ) return onnx_model def delete_input_with_name(input, name): index = 0 for i in input: if == name: del input[index] break index = index + 1 # reference: # # # also must map to types accepted by: # MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) def dtype_torch_to_numpy(torch_dtype): if torch_dtype == torch.float64 or torch_dtype == torch.double: return np.float64 elif torch_dtype == torch.float32 or torch_dtype == torch.float: return np.float32 elif torch_dtype == torch.float16 or torch_dtype == torch.half: return np.float16 elif torch_dtype == torch.int64 or torch_dtype == torch.long: return np.longlong elif torch_dtype == torch.int32 or torch_dtype == return np.int32 elif torch_dtype == torch.int16 or torch_dtype == torch.short: return np.int16 elif torch_dtype == torch.bool: return np.bool else: raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) class model_loss_cls(torch.nn.Module): def __init__(self, model, loss_fn): super(model_loss_cls, self).__init__() self.model_ = model self.loss_fn_ = loss_fn def forward(self, *inputs): # here we assume input can be unpacked into input and label input, label = inputs[:-1], inputs[-1] preds = self.model_(*input) return self.loss_fn_(preds, label), preds class WrapModel(torch.nn.Module): def __init__(self, model, loss_fn, input_names): super(WrapModel, self).__init__() self.model_ = model self.loss_fn_ = loss_fn self.input_names_ = input_names def forward(self, *inputs): import inspect # *inputs is given by torch trace. It is in the order of input_names. # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. sig = inspect.signature(self.model_.forward) ordered_list_keys = list(sig.parameters.keys()) input_dict = {} for key in sig.parameters.keys(): if key in self.input_names_: input_dict[key] = inputs[self.input_names_.index(key)] model_out = self.model_(**input_dict) if self.loss_fn_ is None: return model_out label = inputs[-1] preds = model_out return self.loss_fn_(preds, label), preds def wrap_for_input_match(model, loss_fn, input_names): import inspect sig = inspect.signature(model.forward) ordered_list_keys = list(sig.parameters.keys()) if loss_fn: sig_loss = inspect.signature(loss_fn) if len(sig_loss.parameters) != 2: raise RuntimeError("loss function should take two arguments - predict and label.") # label shall be the second input to loss_fn. ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] # name match is needed only when input_names are a subset # of expected inputs (inputs to model and loss_fn combined). if len(input_names) > len(ordered_list_keys): # this is likely the case where input arguments are packed. # TODO: to unpack the input argument. return model_loss_cls(model, loss_fn) if loss_fn else model elif len(input_names) == len(ordered_list_keys): # in this case, we do not require name match. return model_loss_cls(model, loss_fn) if loss_fn else model if not all(x in ordered_list_keys for x in input_names): # model desc has name(s) not matching the model signature. We cannot do anything in this case. # better to warning the user. return model_loss_cls(model, loss_fn) if loss_fn else model # if input_names match ordered_list_keys, there is not need for wrapping match = True for i, input_name in enumerate(input_names): if input_name != ordered_list_keys[i]: match = False break if match: return model_loss_cls(model, loss_fn) if loss_fn else model model = WrapModel(model, loss_fn, input_names) return model def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): # example: {input0:{0:'batch'}, input1:{0:'batch'}} dynamic_axes = {} for input in model_desc.inputs_: symbolic_axis = {} for i, axis in enumerate(input.shape_): if isinstance(axis, str): symbolic_axis[i] = axis if len(symbolic_axis): dynamic_axes[input.name_] = symbolic_axis for output in model_desc.outputs_: symbolic_axis = {} for i, axis in enumerate(output.shape_): if isinstance(axis, str): symbolic_axis[i] = axis if len(symbolic_axis): dynamic_axes[output.name_] = symbolic_axis input_names = [input.name_ for input in model_desc.inputs_] output_names = [output.name_ for output in model_desc.outputs_] if isinstance(inputs, torch.Tensor): inputs = [inputs] if isinstance(inputs, dict): sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] elif isinstance(inputs, (list, tuple)): sample_inputs = [ for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] else: raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") # pytorch onnx exporter/trace does not try to match argument names. # e.g. for models with optional inputs, it requires all inputs be present. # this is a problem because the model graph depends on inputs provided. model = wrap_for_input_match(model, loss_fn, input_names) model.eval() with torch.no_grad(): import copy # Deepcopy inputs, since input values may change after model run. sample_inputs_copy = copy.deepcopy(sample_inputs) try: # Deepcopy model, in case model is stateful and changes after model run. model_copy = copy.deepcopy(model) except Exception: model_copy = model warnings.warn( "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." " Compute will continue, but unexpected results may occur!" ) sample_outputs = model_copy(*sample_inputs_copy) if isinstance(sample_outputs, torch.Tensor): sample_outputs = [sample_outputs] for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): output_desc.dtype_ = sample_output.dtype model.train() f = io.BytesIO() # Other export options to use(this is for backward compatibility). other_export_options = {} other_export_options["training"] = True # This option was added after 1.4 release. if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( "1.10.0" ): other_export_options["enable_onnx_checker"] = False # This option was added after 1.6 release. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): other_export_options["training"] = torch.onnx.TrainingMode.TRAINING # Deepcopy inputs, since input values may change after model run. import copy sample_inputs_copy = copy.deepcopy(sample_inputs) # Enable contrib ops export from PyTorch from import pytorch_export_contrib_ops pytorch_export_contrib_ops.register() torch.onnx._export( model, tuple(sample_inputs_copy), f, input_names=input_names, output_names=output_names, opset_version=opset_version, dynamic_axes=dynamic_axes, do_constant_folding=False, **other_export_options, ) onnx_model = onnx.load_model_from_string(f.getvalue()) # Remove 'model_.' prefix introduced by model wrapper for initializers. if isinstance(model, WrapModel) or isinstance(model, model_loss_cls): replace_name_dict = {} for n in onnx_model.graph.initializer: if"model_."): replace_name_dict[] =[len("model_.") :] = replace_name_dict[] for n in onnx_model.graph.node: for i, name in enumerate(n.input): if name in replace_name_dict: n.input[i] = replace_name_dict[name] return onnx_model def create_ort_training_session_with_optimizer( model, device, training_optimizer_name, lr_params_feed_name, map_optimizer_attributes, world_rank=-1, world_size=1, gradient_accumulation_steps=1, bind_parameters=False, use_mixed_precision=False, allreduce_post_accumulation=False, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION, use_deterministic_compute=False, use_memory_efficient_gradient=False, enable_adasum=False, optimized_model_filepath="", ): output_name = model.graph.output[0].name ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = output_name ort_parameters.use_mixed_precision = use_mixed_precision ort_parameters.world_rank = world_rank ort_parameters.world_size = world_size ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip ort_parameters.set_gradients_as_graph_outputs = False ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient ort_parameters.enable_adasum = enable_adasum output_types = {} for output in model.graph.output: output_types[] = output.type.tensor_type # pybind does not allow to add directly to ort_parameters.weights_to_train. # Have to work around by using a temporary weights_to_train. torch_params = {} optimizer_attributes_map = {} optimizer_int_attributes_map = {} unused_frozen_weights = [n for n in frozen_weights if n not in [ for i in model.graph.initializer]] if unused_frozen_weights: raise RuntimeError("{} in frozen_weights not found in model weights.".format(unused_frozen_weights)) weights_to_train = set() for initializer in model.graph.initializer: if in frozen_weights: continue weights_to_train.add( if map_optimizer_attributes is not None: attributes = map_optimizer_attributes( optimizer_attributes_map[] = {} optimizer_int_attributes_map[] = {} for k, v in attributes.items(): if isinstance(v, float): optimizer_attributes_map[][k] = v elif isinstance(v, int): optimizer_int_attributes_map[][k] = v else: raise ValueError("Optimizer attributes must be either float or int.") else: optimizer_attributes_map[] = {} optimizer_int_attributes_map[] = {} if bind_parameters: for initializer in model.graph.initializer: torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) delete_input_with_name(model.graph.input, model.graph.input.extend( [helper.make_tensor_value_info(, initializer.data_type, initializer.dims)] ) torch_params[] = torch_tensor del model.graph.initializer[:] ort_parameters.weights_to_train = weights_to_train ort_parameters.training_optimizer_name = training_optimizer_name ort_parameters.lr_params_feed_name = lr_params_feed_name ort_parameters.optimizer_attributes_map = optimizer_attributes_map ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map sessionOptions = ort.SessionOptions() sessionOptions.use_deterministic_compute = use_deterministic_compute if len(optimized_model_filepath) > 0: sessionOptions.optimized_model_filepath = optimized_model_filepath session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions) train_io_binding = session.io_binding() eval_io_binding = session.io_binding() if bind_parameters: for param in torch_params.keys(): torch_tensor = torch_params[param] train_io_binding.bind_input( param, torch_tensor.device.type, get_device_index(torch_tensor.device), dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()), torch_tensor.data_ptr(), ) eval_io_binding.bind_input( param, torch_tensor.device.type, get_device_index(torch_tensor.device), dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()), torch_tensor.data_ptr(), ) return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types def save_checkpoint( model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True ): if checkpoint_state_dict == None: checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} else: checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir) checkpoint_name = get_checkpoint_name( checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size ) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if os.path.exists(checkpoint_file): warnings.warn("{} already exists, overwriting.".format(checkpoint_file)), checkpoint_file) def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if is_partitioned: assert_msg = ( "Couldn't find checkpoint file {}." + "Optimizer partitioning is enabled using ZeRO. Please make sure that the " + "checkpoint file exists for rank {} of {}." ).format(checkpoint_file, model.world_rank, model.world_size) else: assert_msg = "Couldn't find checkpoint file {}.".format(checkpoint_file) assert os.path.exists(checkpoint_file), assert_msg checkpoint_state = torch.load(checkpoint_file, map_location="cpu") model.load_state_dict(checkpoint_state["model"], strict=strict) del checkpoint_state["model"] return checkpoint_state def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) ckpt_agg = CombineZeroCheckpoint(checkpoint_files) aggregate_state_dict = ckpt_agg.aggregate_checkpoints() model.load_state_dict(aggregate_state_dict, strict=strict) # aggregate other keys in the state_dict. # Values will be overwritten for matching keys among workers all_checkpoint_states = dict() for checkpoint_file in checkpoint_files: checkpoint_state = torch.load(checkpoint_file, map_location="cpu") del checkpoint_state["model"] all_checkpoint_states.update(checkpoint_state) return all_checkpoint_states def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) is_partitioned = False if len(checkpoint_files) > 1: warnings.warn( f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + "Attempting to load ZeRO checkpoint." ) is_partitioned = True if (not model.deepspeed_zero_stage_) and is_partitioned: return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) else: return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict)
[docs]class ORTTrainer: def __init__( self, model, loss_fn, model_desc, training_optimizer_name, map_optimizer_attributes, learning_rate_description, device, gradient_accumulation_steps=1, world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION, _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, use_memory_efficient_gradient=False, run_symbolic_shape_infer=False, enable_adasum=False, optimized_model_filepath="", ): super(ORTTrainer, self).__init__() """ Initialize ORTTrainer. Args: model: one of - a PyTorch model (class that inherits from torch.nn.Module) - a combined PyTorch model and loss function. Inputs to this combined PyTorch model are a concatenation of the model's input and the loss function's label input. Outputs are a concatenation of the loss function's output and the model's output. - a combined ONNX model and loss function. loss_fn: one of - a PyTorch loss function if 'model' is a PyTorch model. A loss function takes two inputs (prediction, label) and outputs a loss tensor. - None if model is already combined with a loss function. model_desc: Specify input/output shapes, types, and names. Must be consistent with the training model. training_optimizer_name: one of - 'SGDOptimizer' - 'AdamOptimizer' - 'LambOptimizer' map_optimizer_attributes: for optimizers with weight-dependent parameters. A callable that maps weight name to a set of optimization parameters. Defaults to None. learning_rate_description: the name, shape and type of the learning rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32). Because learning_rate is an input to the training model, Learning_Rate_Name must be specified so that there is no name conflict within the model. device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:<int_idx>'). gradient_accumulation_steps: number of training steps to accumulate gradients before averaging and applying them. Defaults to 1. world_rank: rank id used for distributed training. Defaults to 0. world_size: number of ranks participating in distributed training. Defaults to 1. use_mixed_precision: flag to enable mixed precision (aka fp16). Defaults to False. allreduce_post_accumulation: controls whether overlaping gradient computation is applied with allreduce. Defaults to False. global_step: training step that is used as input to 'get_lr_this_step'. Defaults to 0. get_lr_this_step: functor used as learning rate scheduler. It uses 'global_step' as input. Defaults to None. loss_scaler: updates loss scale automatically when 'use_mixed_precision' is specified. Defaults to None. deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported. Defaults to 0 (disabled). enable_grad_norm_clip: enables gradient norm clipping. Defaults to True. frozen_weights: list of model parameters to be frozen (not trained). Defaults to []. _enable_internal_postprocess: whether to run or not the internal postprocesses. Defaults to True _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. Defaults to None use_memory_efficient_gradient: use memory aware gradient builder. Defaults to False run_symbolic_shape_infer: run symbolic shape inference Defaults to False optimized_model_filepath: path to output the optimized training graph. Defaults to "" (no output). .. deprecated:: 1.10 Use :class:`ORTModule` instead. """ warnings.warn( "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", FutureWarning, ) warnings.warn( "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" ) self.is_train = True self.torch_model_ = None self.onnx_model_ = None self._enable_internal_postprocess = _enable_internal_postprocess self._extra_postprocess = _extra_postprocess if isinstance(model, torch.nn.Module): self.torch_model_ = model self.loss_fn_ = loss_fn self._torch_state_dict_keys = list(model.state_dict().keys()) else: self._torch_state_dict_keys = [] self.onnx_model_ = model if loss_fn is not None: warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn self.loss_fn_ = None if self._enable_internal_postprocess: postprocess.run_postprocess(self.onnx_model_) if self._extra_postprocess: self._extra_postprocess(self.onnx_model_) self.model_desc_ = model_desc self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] self.world_rank = world_rank self.world_size = world_size self.use_mixed_precision = use_mixed_precision self.session = None self.device_ = device self.gradient_accumulation_steps = gradient_accumulation_steps # we use self.current_step to count calls to train_step. It is used for gradient accumulation. # gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps. # gradients are updated when self.current_step is divisible by gradient_accumulation_steps. self.current_step = 0 # we use self.global_step_ to count optimizations being performed. # it is used to calculate learning rate if self.get_lr_this_step_ is provided. self.global_step_ = global_step self.get_lr_this_step_ = get_lr_this_step self.loss_scaler_ = loss_scaler if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") self.training_optimizer_name_ = training_optimizer_name self.learning_rate_description_ = learning_rate_description self.map_optimizer_attributes_ = map_optimizer_attributes self.allreduce_post_accumulation_ = allreduce_post_accumulation self.deepspeed_zero_stage_ = deepspeed_zero_stage self.enable_grad_norm_clip_ = enable_grad_norm_clip self.frozen_weights_ = frozen_weights self.opset_version_ = _opset_version self.state_dict_ = None self._use_deterministic_compute = _use_deterministic_compute self.use_memory_efficient_gradient = use_memory_efficient_gradient self.run_symbolic_shape_infer = run_symbolic_shape_infer self.enable_adasum = enable_adasum self.optimized_model_filepath = optimized_model_filepath # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. # see prepare_input_and_fetches for more details. self.loss_scale_input_name = "default_loss_scale_input_name" self._init_session() def _init_session(self): if self.onnx_model_ is None: return self._verify_fully_optimized_model(self.onnx_model_) if self.run_symbolic_shape_infer: self.onnx_model_ = SymbolicShapeInference.infer_shapes( self.onnx_model_, auto_merge=True, guess_output_rank=True ) # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. # for example, load_state_dict will be called before returing the function, and it calls _init_session again del self.session ( self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types, ) = create_ort_training_session_with_optimizer( self.onnx_model_, self.device_, self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_, self.world_rank, self.world_size, self.gradient_accumulation_steps, bind_parameters=False, use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_, deepspeed_zero_stage=self.deepspeed_zero_stage_, enable_grad_norm_clip=self.enable_grad_norm_clip_, frozen_weights=self.frozen_weights_, opset_version=self.opset_version_, use_deterministic_compute=self._use_deterministic_compute, use_memory_efficient_gradient=self.use_memory_efficient_gradient, enable_adasum=self.enable_adasum, optimized_model_filepath=self.optimized_model_filepath, ) self.loss_scale_input_name = self.session.loss_scale_input_name if self.use_mixed_precision: self.input_desc_with_lr_and_loss_scale = [ *self.input_desc_with_lr, IODescription(self.loss_scale_input_name, [], torch.float32), ] # ORT backend has modified model output dtype from float32 to float16. for o_desc in self.model_desc_.outputs_: if ( self.use_mixed_precision and o_desc.dtype_ == torch.float32 and not self.session.is_output_fp32_node(o_desc.name_) ): o_desc.eval_dtype_ = torch.float16 else: o_desc.eval_dtype_ = o_desc.dtype_ # gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output. # add a matching output to drive gradient accumulation. if self.gradient_accumulation_steps > 1: self.output_desc_with_group_accumulated_gradients = [ *self.model_desc_.outputs_, IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), ] if self.use_mixed_precision: # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine # if the gradient is usable. self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ *self.model_desc_.outputs_, IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), ] if self.state_dict_: self.load_state_dict(self.state_dict_, self.strict_) self.state_dict_ = None def _init_onnx_model(self, inputs): if self.onnx_model_ is not None: return if self.torch_model_ is not None: # NOTE: pt model is moved to cpu to conserve gpu memory. self.torch_model_.cpu() # torch buffers created using 'register_buffer' are not meant to be trainable. torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) self.frozen_weights_ = self.frozen_weights_ + torch_buffers self.onnx_model_ = convert_model_loss_fn_to_onnx( self.torch_model_, self.loss_fn_, self.model_desc_, torch.device("cpu"), inputs, opset_version=self.opset_version_, ) if self._enable_internal_postprocess: postprocess.run_postprocess(self.onnx_model_) if self._extra_postprocess: self._extra_postprocess(self.onnx_model_) self._init_session() def train(self): self.is_train = True def eval(self): self.is_train = False def _update_onnx_model_initializers(self, state_tensors): # replace the initializers with new value new_weights = [] replace_indices = [] for i, w in enumerate(self.onnx_model_.graph.initializer): if in state_tensors: new_weights.append(numpy_helper.from_array(state_tensors[], replace_indices.append(i) replace_indices.sort(reverse=True) for w_i in replace_indices: del self.onnx_model_.graph.initializer[w_i] self.onnx_model_.graph.initializer.extend(new_weights) def state_dict(self, include_optimizer_state=True): if not self.session: warnings.warn( "ONNXRuntime training session is not initialized yet. " "Please run train_step or eval_step at least once before calling state_dict()." ) return {} # extract trained weights session_state = self.session.get_state() torch_state = {} for name in session_state: torch_state[name] = torch.from_numpy(session_state[name]) # extract untrained weights and buffer for n in self.onnx_model_.graph.initializer: if not in torch_state: torch_state[] = torch.from_numpy(numpy_helper.to_array(n)) # Need to remove redundant initializers and name suffices to map back to original torch state names if not include_optimizer_state and self._torch_state_dict_keys: return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state} return torch_state def load_state_dict(self, state_dict, strict=False): # Note: It may happen ONNX model has not yet been initialized # In this case we cache a reference to desired state and delay the restore until after initialization # Unexpected behavior will result if the user changes the reference before initialization if not self.session: self.state_dict_ = state_dict self.strict_ = strict return # update onnx model from loaded state dict cur_initializers_names = [ for n in self.onnx_model_.graph.initializer] new_initializers = {} for name in state_dict: if name in cur_initializers_names: new_initializers[name] = state_dict[name].numpy() elif strict: raise RuntimeError("Checkpoint tensor: {} is not present in the model.".format(name)) self._update_onnx_model_initializers(new_initializers) # create new session based on updated onnx model self.state_dict_ = None self._init_session() # load training state session_state = {name: state_dict[name].numpy() for name in state_dict} self.session.load_state(session_state, strict) def save_as_onnx(self, path): if not self.session: warnings.warn( "ONNXRuntime training session is not initialized yet. " "Please run train_step or eval_step at least once before calling save_as_onnx()." ) return state_tensors = self.session.get_state() self._update_onnx_model_initializers(state_tensors) with open(path, "wb") as f: f.write(self.onnx_model_.SerializeToString()) def _prepare_input_and_fetches( self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs ): fetches = None if type(args) == tuple and len(args) == 1 and type(args[0]) == list: input = tuple(args[0]) else: input = args for input_desc in input_desc_with_: if input_desc.name_ in kwargs: input = input + (kwargs[input_desc.name_],) if internal_learning_rate is not None: input = input + (internal_learning_rate,) if internal_loss_scale is not None: input = input + (internal_loss_scale,) elif self.use_mixed_precision: # loss_scale input name is needed to call train_step, for example: # kwargs[model.loss_scale_input_name] = loss_scale # outputs = model.train_step(*args, **kwargs) # However, when first time train_step is called model.loss_scale_input_name is not set. # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate # the loss_scale. if "default_loss_scale_input_name" in kwargs.keys(): input = input + (kwargs["default_loss_scale_input_name"],) fetches = None if "fetches" in kwargs: fetches = kwargs["fetches"] return input, fetches
[docs] def train_step(self, *args, **kwargs): """ inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. outputs: if fetches is not provided, outputs are loss and (if in mixed mode and is finishing gradient accumulation) all_finite. if fetches is provided, outputs contains these requested with fetches. fetches: names of requested outputs """ # inputs to the ONNX model includes inputs to the original PyTorch model # plus learning rate and loss_scale if self.use_mixed_precision is True. # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators, # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. # In this case, changes to the training script is minimized. # 2. without internal learning rate and loss scale (in fp16 cases) generators, # *args and **kwargs passed in from the training script shall contains # inputs to the PyTorch model plus learning_rate and loss_scale. # it optionally contains the fetches. # localized arguments (*args) contains inputs to the ONNX model. # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches learning_rate, loss_scale = None, None if self.get_lr_this_step_ is not None: # $args, **kwargs contains inputs to the pytorch model lr_this_step = self.get_lr_this_step_(self.global_step_) learning_rate = torch.tensor([lr_this_step]) if self.loss_scaler_ is not None and self.use_mixed_precision: loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) if self.onnx_model_ is None: sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) self._init_onnx_model(sample_input) if self.use_mixed_precision: input, fetches = self._prepare_input_and_fetches( self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs ) assert len(self.input_desc_with_lr_and_loss_scale) == len(input) input_descs = self.input_desc_with_lr_and_loss_scale else: input, fetches = self._prepare_input_and_fetches( self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs ) assert len(self.input_desc_with_lr) == len(input) input_descs = self.input_desc_with_lr self.current_step += 1 # handle gradient accumulation in fully optimized mode run_options = None has_if_all_finite = False if fetches: output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] elif self.current_step % self.gradient_accumulation_steps != 0: run_options = ort.RunOptions() run_options.only_execute_path_to_fetches = True output_desc = self.output_desc_with_group_accumulated_gradients elif self.use_mixed_precision: has_if_all_finite = True output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite else: output_desc = self.model_desc_.outputs_ if not isinstance(input, (list, tuple)): input = (input,) session_run_results = ort_training_session_run_helper( self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options ) if has_if_all_finite: # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce # because all_fp32_gradients_finite is still in the feed. self.train_io_binding.clear_binding_outputs() all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_] if self.loss_scaler_ is not None: self.loss_scaler_.update_loss_scale(all_finite) if all_finite: # optimization has done, increase self.global_step_ self.global_step_ = self.global_step_ + 1 elif self.current_step % self.gradient_accumulation_steps == 0: # optimization has done, increase self.global_step_ self.global_step_ = self.global_step_ + 1 if fetches is not None: results = [session_run_results[fetch] for fetch in fetches] elif has_if_all_finite and self.loss_scaler_ is None: # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. results = [ session_run_results[output_desc.name_] for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite ] else: results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] return results[0] if len(results) == 1 else results
def __call__(self, *args, **kwargs): if self.is_train: return self.train_step(*args, **kwargs) else: return self.eval_step(*args, **kwargs)
[docs] def eval_step(self, *args, **kwargs): """ inputs: model inputs and/or labels. outputs: if 'fetches' is not provided, outputs are loss and (if in mixed mode and is finishing gradient accumulation) all_finite. if fetches is provided, outputs contains these requested with fetches. fetches: names of requested outputs """ # with model_loss_cls, the last input is label, first output is loss input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) if self.onnx_model_ is None: if self.torch_model_ is not None: self._init_onnx_model(input) else: raise RuntimeError( "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." ) input_desc = self.model_desc_.inputs_[0 : len(input)] if fetches is None: output_desc = self.model_desc_.outputs_ else: output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] if not isinstance(input, (list, tuple)): input = (input,) run_options = ort.RunOptions() run_options.only_execute_path_to_fetches = True run_options.training_mode = False session_run_results = ort_training_session_run_helper( self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options ) if len(session_run_results) == 1: return session_run_results[list(session_run_results.keys())[0]] else: return [session_run_results[output_desc.name_] for output_desc in output_desc]
def _verify_fully_optimized_model(self, model): assert len(model.graph.output) > 0 # model's first output must be the loss tensor if ( model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT16 and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().DOUBLE and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX64 and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX128 and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().BFLOAT16 ): raise RuntimeError( "the first output of a model to run with fully optimized ORT backend must be float types." ) if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: raise RuntimeError( "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." )
class LossScaler: def __init__( self, loss_scale_input_name, is_dynamic_scale, loss_scale=float(1 << 16), up_scale_window=2000, min_loss_scale=1.0, max_loss_scale=float(1 << 24), ): super(LossScaler, self).__init__() self.loss_scale_input_name_ = loss_scale_input_name self.is_dynamic_scale_ = is_dynamic_scale self.initial_loss_scale_ = loss_scale self.up_scale_window_ = up_scale_window self.min_loss_scale_ = min_loss_scale self.max_loss_scale_ = max_loss_scale self.loss_scale_ = loss_scale self.stable_steps_ = 0 def update_loss_scale(self, is_all_finite): if not self.is_dynamic_scale_: return if is_all_finite: self.stable_steps_ += 1 if self.stable_steps_ >= self.up_scale_window_: self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2) self.stable_steps_ = 0 else: self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2) self.stable_steps_ = 0 def reset(self): self.loss_scale_ = self.initial_loss_scale_ self.stable_steps_ = 0