Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

12def _pad_impl(data, raw_pads, mode, constant_values=0.0): 

13 input_rank = data.ndim 

14 if input_rank * 2 != raw_pads.size: 

15 raise RuntimeError( # pragma: no cover 

16 'The number of elements in raw_pads should be 2 * data_rank') 

17 

18 half = raw_pads.shape[0] // 2 

19 pad_width = tuple((raw_pads[i], raw_pads[i + half]) 

20 for i in range(0, half)) 

21 

22 if mode == 'constant': 

23 return numpy.pad(data, pad_width=pad_width, mode=mode, 

24 constant_values=constant_values) 

25 return numpy.pad(data, pad_width=pad_width, mode=mode) 

26 

27 

def onnx_pad(data, pads, constant_value=None, mode='constant'):
    """
    Implements :epkg:`numpy:pad` based on ONNX signature.

    :param data: data to pad
    :param pads: tensor of integers indicating the number of
        padding elements to add or remove (if negative) at the
        beginning and end of each axis. For 2D input tensor, it
        is the number of pixels. `pads` should be a 1D tensor of
        shape `[2 * input_rank]`. `pads` format should be:
        `[x1_begin, x2_begin,...,x1_end, x2_end,...]`, where `xi_begin` is
        the number of pad values added at the beginning of axis `i`
        and `xi_end`, the number of pad values added at the end of axis `i`.
    :param constant_value: A scalar value to be used if the mode chosen is
        `constant` (by default it is 0, empty string or False, i.e. the
        zero value of `data.dtype`).
    :param mode: Supported modes: `constant` (default), `reflect`, `edge`
    :return: tensor after padding
    """
    if constant_value is None:
        # Only None selects the default. Testing truthiness would raise on
        # multi-element arrays. numpy.zeros gives the dtype's own zero
        # (0, '' or False), so string tensors are padded with ''.
        constant_value = numpy.zeros(1, dtype=data.dtype)
    return _pad_impl(data, pads, mode=mode, constant_values=constant_value)

50 

51 

class Pad(OpRun):
    """
    Runtime implementation of the ONNX *Pad* operator. The padding itself
    is delegated to `_pad_impl`; attribute *mode* selects the padding mode.
    """

    atts = {'mode': b'constant'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Pad.atts,
                       **options)
        # The attribute is held as bytes (see default b'constant') while
        # numpy.pad expects a str mode.
        self.mode_ = self.mode.decode('ascii')

    def _run(self, data, pads, constant_value=None):  # pylint: disable=W0221
        if constant_value is None:
            # ONNX default is the dtype's zero (0, empty string or False);
            # a plain int 0 would be written as '0' into string tensors.
            constant_value = numpy.zeros(1, dtype=data.dtype)
        return (_pad_impl(data, pads, mode=self.mode_,
                          constant_values=constant_value), )

    def _infer_shapes(self, data, pads, constant_value=None):  # pylint: disable=E0202,W0221
        # The output shape depends on the values of `pads`, so it is left
        # unknown here; only the dtype is propagated.
        return (ShapeObject(None, data.dtype), )

    def _infer_types(self, data, pads, constant_value=None):  # pylint: disable=E0202,W0221
        # The output type is the input type.
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator and reports no temporary buffer (temp=0)
        # alongside the outputs.
        res = self.run(*args)
        return (dict(temp=0), ) + res