# Source code for mlprodict.onnxrt.ops_cpu.op_dequantize_linear
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
from ..shape_object import ShapeObject
class DequantizeLinear(OpRun):
    """
    Runtime for ONNX operator *DequantizeLinear*.

    Computes ``y = (x - x_zero_point) * x_scale`` in float32.
    ``x_scale`` and the optional ``x_zero_point`` are either scalars
    (per-tensor quantization) or 1-D vectors (per-axis quantization).
    A 1-D parameter is broadcast along dimension ``axis`` of ``x``.
    """

    # Default value for the 'axis' attribute, handed to OpRun through
    # expected_attributes (OpRun presumably exposes it as self.axis).
    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DequantizeLinear.atts,
                       **options)

    def _broadcast_along_axis(self, param, x):
        """
        Reshapes a quantization parameter so that it broadcasts
        against *x* along dimension ``self.axis``.

        :param param: scalar (0-D) or 1-D array (scale or zero point)
        :param x: quantized input tensor
        :return: *param* unchanged if it is a scalar, otherwise a
            reshaped view with ``len(param)`` at position ``self.axis``
            and 1 everywhere else
        """
        if len(param.shape) == 0:
            return param
        new_shape = [1 for _ in x.shape]
        new_shape[self.axis] = len(param)
        return param.reshape(new_shape)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Dequantizes ``args[0]``.

        :param args: ``(x, x_scale)`` or ``(x, x_scale, x_zero_point)``;
            ``x_zero_point`` may also be None when omitted
        :return: tuple with one float32 array shaped like ``x``
        :raises RuntimeError: if ``x_scale`` has more than one dimension
            or if ``x_zero_point`` does not share the dtype of ``x``
        """
        x, x_scale = args[0], args[1]
        if len(x_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")
        x_zero_point = args[2] if len(args) > 2 else None

        # Convert to float before subtracting: subtracting two unsigned
        # integer arrays would wrap around.
        y = x.astype(numpy.float32)
        if x_zero_point is not None:
            if x_zero_point.dtype != x.dtype:
                raise RuntimeError(  # pragma: no cover
                    "Type mismatch {} != {} in DequantizeLinear.".format(
                        x.dtype, x_zero_point.dtype))
            y = y - self._broadcast_along_axis(x_zero_point, x)
        # Scale and zero point are broadcast independently so that a
        # scalar one never forces a reshape of a per-axis one.
        y = y * self._broadcast_along_axis(x_scale, x)
        return (y.astype(numpy.float32), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Returns the output shape: same shape as the first input,
        dtype float32.
        """
        return (ShapeObject(args[0].shape, dtype=numpy.float32), )