Source code for mlprodict.onnxrt.ops_cpu.op_conv

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
from ..shape_object import ShapeObjectFct
from .op_conv_ import ConvFloat, ConvDouble  # pylint: disable=E0611,E0401


# Runtime implementation of the *Conv* operator. The heavy lifting is
# delegated to two compiled helpers, one used for float32 inputs and the
# other for every remaining dtype (see ``_run``).
class Conv(OpRun):

    # Attribute names expected on the ONNX node, with their default values.
    atts = {'auto_pad': 'NOTSET', 'group': 1,
            'dilations': [1, 1], 'kernel_shape': [],
            'pads': [], 'strides': [1, 1]}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and prepares its compiled runtimes.

        :param onnx_node: ONNX node this operator executes
        :param desc: forwarded unchanged to ``OpRun.__init__``
        :param options: extra keyword options forwarded to
            ``OpRun.__init__``
        """
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Conv.atts, **options)
        self._init()

    def _init(self):
        """
        Creates the two compiled runtimes (``rt32_`` and ``rt64_``)
        and initialises each of them with the operator attributes.
        """
        # NOTE(review): auto_pad, dilations, group, kernel_shape, pads and
        # strides are presumably set on the instance by OpRun.__init__ from
        # ``expected_attributes`` -- confirm against OpRun.

        def to_int64(values):
            # The list-like attributes are handed over as int64 arrays.
            return numpy.array(values, dtype=numpy.int64)

        self.rt32_ = ConvFloat()
        self.rt64_ = ConvDouble()
        for runtime in (self.rt32_, self.rt64_):
            # Fresh arrays are built for every runtime, nothing is shared.
            runtime.init(
                self.auto_pad,
                to_int64(self.dilations),
                self.group,
                to_int64(self.kernel_shape),
                to_int64(self.pads),
                to_int64(self.strides))

    def _run(self, X, W, B=None):  # pylint: disable=W0221
        """
        Runs the operator on *X*, *W* and the optional *B*.

        :param X: first input, must not be None
        :param W: second input, passed as is to the compiled runtime
        :param B: optional third input, passed as is (None when absent)
        :return: tuple holding the single result of the compiled runtime
        :raises ValueError: if *X* is None
        """
        if X is None:
            raise ValueError(
                "X cannot be None for operator %r, ONNX=%r" % (
                    type(self), self.onnx_node))
        # Only the dtype of X drives the choice: float32 goes to the
        # float runtime, anything else goes to the double runtime.
        runtime = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        return (runtime.compute(X, W, B), )

    def _infer_shapes(self, X, W, B=None):  # pylint: disable=W0221
        """
        Describes the output shape as a function of the input shapes.

        :param X: shape object of the first input, its ``dtype`` is
            reused for the result
        :param W: shape object of the second input
        :param B: shape object of the optional third input
        :return: tuple holding a single ``ShapeObjectFct``
        """

        def compute_shape(xshape, wshape, bshape):
            # The output shape is obtained by actually running the float32
            # runtime on arrays of ones having the requested shapes.
            dummy_x = numpy.ones(xshape, dtype=numpy.float32)
            dummy_w = numpy.ones(wshape, dtype=numpy.float32)
            if bshape is not None:
                dummy_b = numpy.ones(bshape, dtype=numpy.float32)
            else:
                dummy_b = None
            output = self.rt32_.compute(dummy_x, dummy_w, dummy_b)
            return output.shape

        shape_fct = ShapeObjectFct(
            compute_shape, X, W, B, name="Conv", dtype=X.dtype)
        return (shape_fct, )