Source code for mlprodict.onnxrt.ops_cpu.op_constant_of_shape

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
from ..shape_object import ShapeObject


[docs]class ConstantOfShape(OpRun): atts = {'value': numpy.array([0], dtype=numpy.float32)}
[docs] def __init__(self, onnx_node, desc=None, **options): OpRun.__init__(self, onnx_node, desc=desc, expected_attributes=ConstantOfShape.atts, **options) self.cst = (self.value[0] if isinstance(self.value, numpy.ndarray) else self.value) if not isinstance(self.cst, (float, numpy.float32, numpy.float64, numpy.int64, numpy.int32, numpy.bool, numpy.float16)): raise TypeError( # pragma: no cover "cst must be a real not {}".format(type(self.cst)))
[docs] def _run(self, data): # pylint: disable=W0221 res = numpy.full(tuple(data), self.cst) return (res, )
[docs] def _infer_shapes(self, data): # pylint: disable=W0221 # pref = str(hex(id(self))[2:]) return (ShapeObject(None, self.cst.dtype), )
[docs] def to_python(self, inputs): lines = ['cst = value[0] if isinstance(value, numpy.ndarray) else value', 'return numpy.full(tuple(%s), cst)' % inputs[0]] return ("import numpy", "\n".join(lines))